Source code for sedona.spark.sql.dataframe_api

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import inspect
import itertools
import typing
from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union

from pyspark import SparkContext
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f

try:
    from pyspark.sql.connect.column import Column as ConnectColumn
    from pyspark.sql.utils import is_remote
except ImportError:
    # Spark < 3.4 has no Spark Connect modules: define inert stand-ins so the
    # rest of this module can reference both names unconditionally.
    def is_remote():
        """Report that no Spark Connect session is in use."""
        return False

    class ConnectColumn:
        """Placeholder type used when Spark Connect cannot be imported."""

else:
    from sedona.spark.sql.connect import call_sedona_function_connect

ColumnOrName = Union[Column, ConnectColumn, str]
ColumnOrNameOrNumber = Union[Column, ConnectColumn, str, float, int]


def _convert_argument_to_java_column(arg: Any) -> Column:
    """Turn a single function argument into the ``_jc`` handle of a Column.

    :param arg: A Column, a column name, an iterable of such arguments, or any
        other value that should be passed as a literal.
    :type arg: Any
    :return: The ``_jc`` attribute of the Column built from ``arg``.
    """
    if isinstance(arg, Column):
        return arg._jc

    # a string is a column name; it must be tested before the Iterable branch
    # because str is itself an Iterable
    if isinstance(arg, str):
        return f.col(arg)._jc

    if not isinstance(arg, Iterable):
        return f.lit(arg)._jc

    # any other iterable becomes an array column of its converted elements
    elements = [Column(_convert_argument_to_java_column(item)) for item in arg]
    return f.array(*elements)._jc
def call_sedona_function(
    object_name: str, function_name: str, args: Union[Any, Tuple[Any]]
) -> Column:
    """Call a Sedona function and wrap the result in a Column.

    :param object_name: Name of the attribute looked up on the JVM view that
        holds the function.
    :type object_name: str
    :param function_name: Name of the function to call.
    :type function_name: str
    :param args: A single argument or an iterable of arguments.
    :type args: Union[Any, Tuple[Any]]
    :raises ValueError: If ``SparkContext._jvm`` is None (non-remote mode only).
    :return: Column wrapping the result of the call.
    :rtype: Column
    """
    # A Column (like a str) is itself an Iterable, so those must be recognised
    # explicitly as one argument rather than a collection of arguments.
    is_single_argument = isinstance(
        args, (str, Column, ConnectColumn)
    ) or not isinstance(args, Iterable)
    if is_single_argument:
        args = [args]

    # in spark-connect environments use the connect API
    if is_remote():
        return call_sedona_function_connect(function_name, args)

    # conversion stays lazy so that nothing touches Spark before the JVM check
    java_columns = (_convert_argument_to_java_column(argument) for argument in args)

    jvm = SparkContext._jvm
    if jvm is None:
        raise ValueError(
            "No active spark context was detected. Unable to call sedona function."
        )

    sedona_function = getattr(getattr(jvm, object_name), function_name)
    return Column(sedona_function(*java_columns))
def _is_union_annotation(annotation: Any) -> bool:
    """Tell whether an annotation is a union.

    Both spellings are recognised: ``typing.Union[X, Y]`` (which includes
    ``Optional[X]``) and, on Python 3.10+, the ``X | Y`` form.

    :param annotation: Type annotation to inspect.
    :type annotation: Any
    :return: True if the annotation is a union of several types.
    :rtype: bool
    """
    import types  # local import: only this helper needs the module

    origin = typing.get_origin(annotation)
    if origin is typing.Union:
        return True
    # types.UnionType does not exist before Python 3.10
    pep604_union = getattr(types, "UnionType", None)
    return pep604_union is not None and origin is pep604_union


def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
    """Convert a type annotation into a tuple of types.

    For most types this will be a tuple with a single element, but for a
    union a tuple with one element per union member is returned.

    :param annotated_type: Type annotation to convert to a tuple.
    :type annotated_type: Type
    :return: Tuple of all types covered by the type annotation.
    :rtype: Tuple[Type, ...]
    """
    if _is_union_annotation(annotated_type):
        valid_types = typing.get_args(annotated_type)
    else:
        valid_types = (annotated_type,)

    # functions accepting a Column should also accept the Spark Connect sort of Column
    if Column in valid_types:
        valid_types = valid_types + (ConnectColumn,)

    return valid_types


def _strip_extra_from_class_name(class_name):
    """Reduce the ``str()`` of a type to its short, unqualified name.

    ``"<class 'pkg.mod.Name'>"`` becomes ``"Name"``. Strings without the
    ``<class '...'>`` wrapper (for example ``"typing.List[int]"``) only lose
    the module prefix of their leading name, giving ``"List[int]"``.

    :param class_name: String representation of a type.
    :type class_name: str
    :return: Short name of the type.
    :rtype: str
    """
    prefix, suffix = "<class '", "'>"
    if class_name.startswith(prefix) and class_name.endswith(suffix):
        class_name = class_name[len(prefix) : -len(suffix)]

    # only the part before any "[" is a dotted path; keep generic parameters as-is
    head, bracket, parameters = class_name.partition("[")
    return head.split(".")[-1] + bracket + parameters


def _get_readable_name_for_type(type: Type) -> str:
    """Get a human readable name for a type annotation used on a function's parameter.

    :param type: Type annotation for a parameter.
    :type type: Type
    :return: Human readable name for the type annotation.
    :rtype: str
    """
    if _is_union_annotation(type):
        members = ", ".join(
            _strip_extra_from_class_name(str(member))
            for member in typing.get_args(type)
        )
        return f"Union[{members}]"
    return _strip_extra_from_class_name(str(type))


def _get_bound_arguments(f: Callable, *args, **kwargs) -> inspect.BoundArguments:
    """Bind the passed arguments to f with actual parameter names, including defaults.

    :param f: Function to bind arguments for.
    :type f: Callable
    :raises TypeError: If the arguments do not match the signature of f
        (raised by ``inspect.Signature.bind``).
    :return: Bound arguments; ``.arguments`` maps parameter names to values.
    :rtype: inspect.BoundArguments
    """
    bound_args = inspect.signature(f).bind(*args, **kwargs)
    bound_args.apply_defaults()
    return bound_args


def _check_bound_arguments(
    bound_args: inspect.BoundArguments,
    type_annotations: Mapping[str, Type],
    function_name: str,
) -> None:
    """Check bound arguments against type annotations and raise a ValueError if any do not match.

    Parameters without an annotation are skipped. For ``*args`` and
    ``**kwargs`` parameters every collected value is checked individually,
    because the annotation of a variadic parameter describes each element.

    :param bound_args: Bound arguments to check.
    :type bound_args: inspect.BoundArguments
    :param type_annotations: Mapping of parameter name to type annotation.
    :type type_annotations: Mapping[str, Type]
    :param function_name: Name of the function that is being checked for, used in the exception if raised.
    :type function_name: str
    :raises ValueError: If a bound argument does not match the parameter type.
    """
    parameters = bound_args.signature.parameters
    for bound_arg_name, bound_arg in bound_args.arguments.items():
        if bound_arg_name not in type_annotations:
            # nothing to validate against for an unannotated parameter
            continue

        annotated_type = type_annotations[bound_arg_name]
        valid_type_list = _get_type_list(annotated_type)

        kind = parameters[bound_arg_name].kind
        if kind is inspect.Parameter.VAR_POSITIONAL:
            values_to_check = bound_arg
        elif kind is inspect.Parameter.VAR_KEYWORD:
            values_to_check = bound_arg.values()
        else:
            values_to_check = (bound_arg,)

        for bound_arg_value in values_to_check:
            if not isinstance(bound_arg_value, valid_type_list):
                raise ValueError(
                    f"Incorrect argument type: {bound_arg_name} for {function_name} should be {_get_readable_name_for_type(annotated_type)} but received {_strip_extra_from_class_name(str(type(bound_arg_value)))}."
                )
def validate_argument_types(f: Callable) -> Callable:
    """Validates types of arguments passed to a dataframe API style function.

    Arguments will need to be either strings, columns, or match the typehints of f.
    This function is meant to be used as a decorator.

    :param f: Function to validate for.
    :type f: Callable
    :return: f wrapped with type validation checks.
    :rtype: Callable
    """

    @functools.wraps(f)
    def validated_function(*args, **kwargs) -> Column:
        # Columns and strings are always legal, so the type hints only need
        # to be consulted when some argument is neither of those.
        only_columns_or_names = all(
            isinstance(value, (Column, ConnectColumn, str))
            for value in itertools.chain(args, kwargs.values())
        )
        if not only_columns_or_names:
            _check_bound_arguments(
                _get_bound_arguments(f, *args, **kwargs),
                typing.get_type_hints(f),
                f.__name__,
            )

        return f(*args, **kwargs)

    return validated_function