# Source code for sedona.spark.geopandas.tools.sjoin

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import pyspark.pandas as ps
from pyspark.pandas.internal import SPARK_DEFAULT_INDEX_NAME, InternalFrame
from pyspark.pandas.utils import scol_for
from pyspark.sql.functions import expr

from sedona.spark.geopandas import GeoDataFrame, GeoSeries

# Pre-compiled regex pattern for suffix validation. The lsuffix/rsuffix values
# are interpolated into generated column names and selectExpr SQL below, so
# they are restricted to identifier characters (letter/underscore followed by
# letters, digits, or underscores) to avoid breaking the generated SQL.
SUFFIX_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def _frame_join(
    left_df: GeoDataFrame,
    right_df: GeoDataFrame,
    how="inner",
    predicate="intersects",
    lsuffix="left",
    rsuffix="right",
    distance=None,
    on_attribute=None,
):
    """Join the GeoDataFrames at the DataFrame level.

    Parameters
    ----------
    left_df : GeoDataFrame
        Left dataset to join
    right_df : GeoDataFrame
        Right dataset to join
    how : str, default 'inner'
        Join type: 'inner', 'left', 'right'
    predicate : str, default 'intersects'
        Spatial predicate to use
    lsuffix : str, default 'left'
        Suffix for left overlapping columns
    rsuffix : str, default 'right'
        Suffix for right overlapping columns
    distance : float, optional
        Distance parameter for dwithin predicate
    on_attribute : list, optional
        Additional columns to join on

    Note: Unlike GeoPandas, Sedona does not preserve key order for performance reasons. Consider using .sort_index() after the join, if you need to preserve the order.

    Returns
    -------
    GeoDataFrame or GeoSeries
        Joined result
    """
    # Map geopandas predicate names to the corresponding Sedona SQL functions.
    predicate_map = {
        "intersects": "ST_Intersects",
        "contains": "ST_Contains",
        "within": "ST_Within",
        "touches": "ST_Touches",
        "crosses": "ST_Crosses",
        "overlaps": "ST_Overlaps",
        "dwithin": "ST_DWithin",
        "covers": "ST_Covers",
        "covered_by": "ST_CoveredBy",
        # "contains_properly": "ST_ContainsProperly",  # not supported by Sedona yet
    }

    if predicate not in predicate_map:
        raise ValueError(
            f"Predicate '{predicate}' not supported. Available: {list(predicate_map.keys())}"
        )

    # Fail fast on invalid arguments before any Spark work is planned.
    if how not in ("inner", "left", "right"):
        raise ValueError(f"Join type '{how}' not supported")
    if predicate == "dwithin" and distance is None:
        raise ValueError("Distance parameter is required for 'dwithin' predicate")

    spatial_func = predicate_map[predicate]

    # Get the internal Spark DataFrames backing the pandas-on-Spark frames.
    left_sdf = left_df._internal.spark_frame
    right_sdf = right_df._internal.spark_frame

    # Both sides must have an active geometry column set.
    left_geom_col = left_df.active_geometry_name
    right_geom_col = right_df.active_geometry_name

    if not left_geom_col:
        raise ValueError("Left dataframe geometry column not set")
    if not right_geom_col:
        raise ValueError("Right dataframe geometry column not set")

    # Prefix every column with l_/r_ so the two sides cannot collide in the
    # joined frame; the geometry columns get fixed names used by the join
    # condition below. Internal ("__"-prefixed) columns are dropped.
    left_geom_expr = f"`{left_geom_col}` as l_geometry"
    right_geom_expr = f"`{right_geom_col}` as r_geometry"

    left_cols = [left_geom_expr] + [
        f"`{field.name}` as l_{field.name}"
        for field in left_sdf.schema.fields
        if field.name != left_geom_col and not field.name.startswith("__")
    ]
    right_cols = [right_geom_expr] + [
        f"`{field.name}` as r_{field.name}"
        for field in right_sdf.schema.fields
        if field.name != right_geom_col and not field.name.startswith("__")
    ]

    # Carry each side's index along so the result can adopt the proper one.
    left_geo_df = left_sdf.selectExpr(
        *left_cols, f"`{SPARK_DEFAULT_INDEX_NAME}` as index_{lsuffix}"
    )
    right_geo_df = right_sdf.selectExpr(
        *right_cols, f"`{SPARK_DEFAULT_INDEX_NAME}` as index_{rsuffix}"
    )

    # Build the spatial join condition ('dwithin' takes an extra distance arg).
    if predicate == "dwithin":
        spatial_condition = f"{spatial_func}(l_geometry, r_geometry, {distance})"
    else:
        spatial_condition = f"{spatial_func}(l_geometry, r_geometry)"

    # Add attribute equality conditions on top of the spatial predicate.
    join_condition = spatial_condition
    if on_attribute:
        for attr in on_attribute:
            join_condition += f" AND l_{attr} = r_{attr}"

    # 'how' was validated above; Spark accepts 'inner', 'left' and 'right'
    # directly as join types, so a single join call covers all three cases.
    spatial_join_df = left_geo_df.alias("l").join(
        right_geo_df.alias("r"), expr(join_condition), how
    )

    # Pick which index to use for the resulting df's index based on 'how'
    index_col = f"index_{lsuffix}" if how in ("inner", "left") else f"index_{rsuffix}"

    # Handle column naming with suffixes
    final_columns = []

    # Add geometry column (always from left for geopandas compatibility)
    final_columns.append("l_geometry as geometry")

    # Data columns are everything except geometry and index on each side.
    left_data_cols = [
        col
        for col in left_geo_df.columns
        if col not in ["l_geometry", f"index_{lsuffix}"]
    ]
    right_data_cols = [
        col
        for col in right_geo_df.columns
        if col not in ["r_geometry", f"index_{rsuffix}"]
    ]

    final_columns.append(f"{index_col} as {SPARK_DEFAULT_INDEX_NAME}")

    # The index that is not used as the result index is kept as a data column.
    if index_col != f"index_{lsuffix}":
        final_columns.append(f"index_{lsuffix}")

    for col_name in left_data_cols:
        base_name = col_name[2:]  # Remove "l_" prefix
        right_col = f"r_{base_name}"

        if right_col in right_data_cols:
            # Column exists in both - apply suffixes
            final_columns.append(f"{col_name} as {base_name}_{lsuffix}")
        else:
            # Column only in left
            final_columns.append(f"{col_name} as {base_name}")

    if index_col != f"index_{rsuffix}":
        final_columns.append(f"index_{rsuffix}")

    for col_name in right_data_cols:
        base_name = col_name[2:]  # Remove "r_" prefix
        left_col = f"l_{base_name}"

        if left_col in left_data_cols:
            # Column exists in both - apply suffixes
            final_columns.append(f"{col_name} as {base_name}_{rsuffix}")
        else:
            # Column only in right
            final_columns.append(f"{col_name} as {base_name}")

    # Select final columns
    result_df = spatial_join_df.selectExpr(*final_columns)
    # Note, we do not .orderBy(SPARK_DEFAULT_INDEX_NAME) to avoid a performance hit

    data_spark_columns = [
        scol_for(result_df, col)
        for col in result_df.columns
        if col != SPARK_DEFAULT_INDEX_NAME
    ]

    # Rebuild a pandas-on-Spark InternalFrame with the chosen index column.
    internal = InternalFrame(
        spark_frame=result_df,
        index_spark_columns=[scol_for(result_df, SPARK_DEFAULT_INDEX_NAME)],
        data_spark_columns=data_spark_columns,
    )
    return GeoDataFrame(ps.DataFrame(internal))


def sjoin(
    left_df: GeoDataFrame,
    right_df: GeoDataFrame,
    how="inner",
    predicate="intersects",
    lsuffix="left",
    rsuffix="right",
    distance=None,
    on_attribute=None,
    **kwargs,
) -> GeoDataFrame:
    """Spatial join of two GeoDataFrames.

    Parameters
    ----------
    left_df, right_df : GeoDataFrames
    how : string, default 'inner'
        The type of join:

        * 'left': use keys from left_df; retain only left_df geometry column
        * 'right': use keys from right_df; retain only right_df geometry column
        * 'inner': use intersection of keys from both dfs; retain only
          left_df geometry column
    predicate : string, default 'intersects'
        Binary predicate. Valid values are determined by the spatial index used.
        You can check the valid values in left_df or right_df as
        ``left_df.sindex.valid_query_predicates`` or
        ``right_df.sindex.valid_query_predicates``
        Replaces deprecated ``op`` parameter.
    lsuffix : string, default 'left'
        Suffix to apply to overlapping column names (left GeoDataFrame).
    rsuffix : string, default 'right'
        Suffix to apply to overlapping column names (right GeoDataFrame).
    distance : number or array_like, optional
        Distance(s) around each input geometry within which to query the tree
        for the 'dwithin' predicate. If array_like, must be one-dimensional
        with length equal to length of left GeoDataFrame.
        Required if ``predicate='dwithin'``.
    on_attribute : string, list or tuple
        Column name(s) to join on as an additional join restriction on top of
        the spatial predicate. These must be found in both DataFrames.
        If set, observations are joined only if the predicate applies and
        values in specified columns match.

    Returns
    -------
    GeoDataFrame
        The joined GeoDataFrame.

    Examples
    --------
    >>> groceries_w_communities = geopandas.sjoin(groceries, chicago)
    >>> groceries_w_communities.head()  # doctest: +SKIP
       OBJECTID       community                           geometry
    0        16          UPTOWN  MULTIPOINT ((-87.65661 41.97321))
    1        18     MORGAN PARK  MULTIPOINT ((-87.68136 41.69713))
    2        22  NEAR WEST SIDE  MULTIPOINT ((-87.63918 41.86847))
    3        23  NEAR WEST SIDE  MULTIPOINT ((-87.65495 41.87783))
    4        27         CHATHAM  MULTIPOINT ((-87.62715 41.73623))
    [5 rows x 95 columns]

    Notes
    -----
    Every operation in GeoPandas is planar, i.e. the potential third dimension is not taken into account.
    """
    # Reject unknown keyword arguments explicitly (mirrors CPython's error).
    if kwargs:
        first = next(iter(kwargs.keys()))
        raise TypeError(f"sjoin() got an unexpected keyword argument '{first}'")

    # Normalize a scalar/tuple on_attribute into a list (or None).
    on_attribute = _maybe_make_list(on_attribute)

    # Validate frame types, join type, and suffixes before doing any work.
    _basic_checks(left_df, right_df, how, lsuffix, rsuffix, on_attribute=on_attribute)

    joined = _frame_join(
        left_df,
        right_df,
        how=how,
        predicate=predicate,
        lsuffix=lsuffix,
        rsuffix=rsuffix,
        distance=distance,
        on_attribute=on_attribute,
    )

    return joined
def _maybe_make_list(obj):
    """Normalize *obj* into a list: tuples are converted, other non-None
    scalars are wrapped, and None / existing lists pass through unchanged."""
    if isinstance(obj, tuple):
        return list(obj)
    if obj is not None and not isinstance(obj, list):
        return [obj]
    return obj


def _basic_checks(left_df, right_df, how, lsuffix, rsuffix, on_attribute=None):
    """Checks the validity of join input parameters.

    `how` must be one of the valid options.
    `'index_'` concatenated with `lsuffix` or `rsuffix` must not already
    exist as columns in the left or right data frames.

    Parameters
    ------------
    left_df : GeoDataFrame or GeoSeries
    right_df : GeoDataFrame or GeoSeries
    how : str, one of 'left', 'right', 'inner'
        join type
    lsuffix : str
        left index suffix
    rsuffix : str
        right index suffix
    on_attribute : list, default None
        list of column names to merge on along with geometry
    """
    if not isinstance(left_df, GeoDataFrame):
        raise ValueError(f"'left_df' should be GeoDataFrame, got {type(left_df)}")

    if not isinstance(right_df, GeoDataFrame):
        raise ValueError(f"'right_df' should be GeoDataFrame, got {type(right_df)}")

    allowed_hows = ["inner", "left", "right"]
    if how not in allowed_hows:
        raise ValueError(f'`how` was "{how}" but is expected to be in {allowed_hows}')

    # Check if on_attribute columns exist in both datasets
    if on_attribute:
        for attr in on_attribute:
            if hasattr(left_df, "columns") and attr not in left_df.columns:
                raise ValueError(f"Column '{attr}' not found in left dataset")
            if hasattr(right_df, "columns") and attr not in right_df.columns:
                raise ValueError(f"Column '{attr}' not found in right dataset")

    # Check for reserved column names that would conflict
    if lsuffix == rsuffix:
        raise ValueError("lsuffix and rsuffix cannot be the same")

    # Validate suffix format (should not contain special characters that would break SQL)
    if not SUFFIX_PATTERN.match(lsuffix):
        raise ValueError(f"lsuffix '{lsuffix}' contains invalid characters")
    if not SUFFIX_PATTERN.match(rsuffix):
        raise ValueError(f"rsuffix '{rsuffix}' contains invalid characters")


def _to_geo_series(df: ps.Series) -> GeoSeries:
    """
    Get the first Series from the DataFrame.

    Parameters:
    - df: The input DataFrame.

    Returns:
    - GeoSeries: The first Series from the DataFrame.
    """
    return GeoSeries(data=df)