# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
from typing import TYPE_CHECKING, Callable, List
# We may be able to achieve streaming rather than complete materialization by
# using the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
from sedona.spark.sql.st_functions import ST_AsEWKB
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType
from sedona.spark.sql.types import GeometryType
from pyspark.sql.pandas.types import (
from_arrow_type,
)
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
if TYPE_CHECKING:
import geopandas as gpd
def dataframe_to_arrow(df, crs=None):
"""
Collect a DataFrame as a PyArrow Table
In the output Table, geometry will be encoded as a GeoArrow extension type.
    The resulting table is compatible with `lonboard.viz()`,
    `geopandas.GeoDataFrame.from_arrow()`, or any other library that understands
    GeoArrow extension types.

    :param df: A Spark DataFrame
    :param crs: A CRS-like object (e.g., a `pyproj.CRS` or a string interpretable
        by `pyproj.CRS`). If provided, this overrides any CRS present in the
        output geometries. If omitted, the CRS is inferred from the output
        geometries when exactly one distinct CRS is present.
    :return: A `pyarrow.Table` whose geometry columns carry the GeoArrow
        extension type (or GeoArrow field metadata as a fallback).
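
    Example (a minimal sketch, assuming a Sedona-enabled SparkSession named
    `sedona` and that GeoPandas >= 1.0 is available for
    `GeoDataFrame.from_arrow`)::

        df = sedona.sql("SELECT ST_Point(1.0, 2.0) AS geom")
        table = dataframe_to_arrow(df, crs="EPSG:4326")
        gdf = geopandas.GeoDataFrame.from_arrow(table)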
"""
import pyarrow as pa
col_is_geometry = [isinstance(f.dataType, GeometryType) for f in df.schema.fields]
if not any(col_is_geometry):
return dataframe_to_arrow_raw(df)
df_columns = list(df)
df_column_names = df.schema.fieldNames()
for i, is_geom in enumerate(col_is_geometry):
if is_geom:
df_columns[i] = ST_AsEWKB(df_columns[i]).alias(df_column_names[i])
df_projected = df.select(*df_columns)
table = dataframe_to_arrow_raw(df_projected)
try:
# Using geoarrow-types is the preferred mechanism for Arrow output.
# Using the extension type ensures that the type and its metadata will
# propagate through all pyarrow transformations.
import geoarrow.types as gat
try_register_extension_types()
spec = gat.wkb()
new_cols = [
wrap_geoarrow_extension(col, spec, crs) if is_geom else col
for is_geom, col in zip(col_is_geometry, table.columns)
]
return pa.table(new_cols, table.column_names)
except ImportError:
# In the event that we don't have access to GeoArrow extension types,
# we can still add field metadata that will propagate through some types
# of operations (e.g., writing this table to a file or passing it to
# DuckDB as long as no intermediate transformations were applied).
new_fields = [
(
wrap_geoarrow_field(table.schema.field(i), table[i], crs)
if is_geom
else table.schema.field(i)
)
for i, is_geom in enumerate(col_is_geometry)
]
return table.from_arrays(table.columns, schema=pa.schema(new_fields))
def dataframe_to_arrow_raw(df):
"""Backport of toArrow() (available in Spark 4.0)"""
from pyspark.sql.dataframe import DataFrame
assert isinstance(df, DataFrame)
jconf = df.sparkSession._jconf
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
schema = to_arrow_schema(df.schema)
import pyarrow as pa
self_destruct = jconf.arrowPySparkSelfDestructEnabled()
batches = df._collect_as_arrow(split_batches=self_destruct)
# The zero row case can use from_batches() with schema (nothing to cast)
if not batches:
return pa.Table.from_batches([], schema)
    # When batches were returned, cast them to the expected schema. This mirrors
    # the Spark implementation, where presumably the schemas of the collected
    # batches may not exactly match `schema` (hence the cast).
table = pa.Table.from_batches(batches).cast(schema)
# Ensure only the table has a reference to the batches, so that
# self_destruct (if enabled) is effective
del batches
return table
def wrap_geoarrow_extension(col, spec, crs):
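    """Wrap an EWKB array as a GeoArrow `geoarrow.wkb` extension array.

    If `crs` is None, the CRS is inferred from the SRIDs embedded in the EWKB
    values; otherwise a CRS-like object is coerced to a `pyproj.CRS`. The column
    is then wrapped in the extension type described by `spec` so that the
    geometry type and CRS survive pyarrow transformations.
    """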
if crs is None:
crs = unique_srid_from_ewkb(col)
elif not hasattr(crs, "to_json"):
import pyproj
crs = pyproj.CRS(crs)
return spec.override(crs=crs).to_pyarrow().wrap_array(col)
def wrap_geoarrow_field(field, col, crs):
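    """Attach GeoArrow field metadata to a WKB column.

    This is the fallback used when GeoArrow extension types are unavailable:
    the field is tagged with `ARROW:extension:name` and
    `ARROW:extension:metadata` (including the CRS, when one is known) so that
    readers inspecting field metadata can still recognize the geometry column.
    """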
if crs is None:
crs = unique_srid_from_ewkb(col)
if crs is not None:
metadata = f'"crs": {crs_to_json(crs)}'
else:
metadata = ""
return field.with_metadata(
{
"ARROW:extension:name": "geoarrow.wkb",
"ARROW:extension:metadata": "{" + metadata + "}",
}
)
def crs_to_json(crs):
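    """Return the PROJJSON representation of a CRS-like object as a string."""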
if hasattr(crs, "to_json"):
return crs.to_json()
else:
import pyproj
return pyproj.CRS(crs).to_json()
def try_register_extension_types():
"""Try to register extension types using geoarrow-types
Do this defensively, because it can fail if the extension type was
registered in some other way (notably: old versions of geoarrow-pyarrow,
which is a dependency of Kepler).
"""
from geoarrow.types.type_pyarrow import register_extension_types
try:
register_extension_types()
except RuntimeError:
pass
def unique_srid_from_ewkb(obj):
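    """Infer a CRS from the SRIDs embedded in an array of EWKB values.

    Returns a `pyproj.CRS` (interpreting the SRID as an EPSG code) when exactly
    one distinct SRID is present; returns None for empty input, mixed byte
    orders, when no value carries an SRID, or when more than one distinct SRID
    is present.
    """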
import pyarrow as pa
import pyarrow.compute as pc
if len(obj) == 0:
return None
# Output shouldn't have mixed endian here
endian = pc.binary_slice(obj, 0, 1).unique()
if len(endian) != 1:
return None
# WKB Z high byte is 0x80
    # WKB M high byte is 0x40
# EWKB SRID high byte is 0x20
# High bytes where the SRID is set would be
# [0x20, 0x20 | 0x40, 0x20 | 0x80, 0x20 | 0x40 | 0x80]
# == [0x20, 0x60, 0xa0, 0xe0]
is_little_endian = endian[0].as_py() == b"\x01"
high_byte = (
pc.binary_slice(obj, 4, 5) if is_little_endian else pc.binary_slice(obj, 1, 2)
)
has_srid = pc.is_in(high_byte, pa.array([b"\x20", b"\x60", b"\xa0", b"\xe0"]))
unique_srids = (
pc.if_else(has_srid, pc.binary_slice(obj, 5, 9), None).unique().drop_null()
)
if len(unique_srids) != 1:
return None
srid_bytes = unique_srids[0].as_py()
endian = "little" if is_little_endian else "big"
epsg_code = int.from_bytes(srid_bytes, endian)
import pyproj
return pyproj.CRS(f"EPSG:{epsg_code}")
def _dedup_names(names: List[str]) -> List[str]:
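    """Rename duplicated column names by appending a counter suffix.

    For example, ["a", "a", "b"] becomes ["a_0", "a_1", "b"]; fully unique
    names are returned unchanged.
    """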
if len(set(names)) == len(names):
return names
else:
def _gen_dedup(_name: str) -> Callable[[], str]:
_i = itertools.count()
return lambda: f"{_name}_{next(_i)}"
def _gen_identity(_name: str) -> Callable[[], str]:
return lambda: _name
gen_new_name = {
name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
for name, group in itertools.groupby(sorted(names))
}
return [gen_new_name[name]() for name in names]
# Backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
def _deduplicate_field_names(dt: DataType) -> DataType:
if isinstance(dt, StructType):
dedup_field_names = _dedup_names(dt.names)
return StructType(
[
StructField(
dedup_field_names[i],
_deduplicate_field_names(field.dataType),
nullable=field.nullable,
)
for i, field in enumerate(dt.fields)
]
)
elif isinstance(dt, ArrayType):
return ArrayType(
_deduplicate_field_names(dt.elementType), containsNull=dt.containsNull
)
elif isinstance(dt, MapType):
return MapType(
_deduplicate_field_names(dt.keyType),
_deduplicate_field_names(dt.valueType),
valueContainsNull=dt.valueContainsNull,
)
else:
return dt
def infer_schema(gdf: "gpd.GeoDataFrame") -> StructType:
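    """Infer a Spark schema from a GeoPandas GeoDataFrame.

    Geometry columns are mapped to Sedona's `GeometryType`; the remaining
    columns are converted via a PyArrow schema and `from_arrow_type`. Raises
    ValueError if the GeoDataFrame contains no geometry column.
    """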
import pyarrow as pa
fields = gdf.dtypes.reset_index().values.tolist()
geom_fields = []
index = 0
for name, dtype in fields:
if dtype == "geometry":
geom_fields.append((index, name))
continue
index += 1
if not geom_fields:
raise ValueError("No geometry field found in the GeoDataFrame")
pa_schema = pa.Schema.from_pandas(
gdf.drop([name for _, name in geom_fields], axis=1)
)
spark_schema = []
for field in pa_schema:
field_type = field.type
spark_type = from_arrow_type(field_type)
spark_schema.append(StructField(field.name, spark_type, True))
for index, geom_field in geom_fields:
spark_schema.insert(index, StructField(geom_field, GeometryType(), True))
return StructType(spark_schema)
# Modified backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
def create_spatial_dataframe(spark: SparkSession, gdf: "gpd.GeoDataFrame") -> DataFrame:
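    """Create a Spark DataFrame from a GeoPandas GeoDataFrame.

    The GeoDataFrame is sliced according to the Arrow max-records-per-batch
    setting, serialized as Arrow record batches, and streamed to the JVM;
    geometry columns are exposed through Sedona's `GeometryType` (see
    `infer_schema`).

    Example (a minimal sketch, assuming `spark` is a Sedona-enabled
    SparkSession and `gdf` is a GeoDataFrame with at least one geometry
    column)::

        df = create_spatial_dataframe(spark, gdf)
        df.printSchema()
    """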
from pyspark.sql.pandas.types import (
to_arrow_type,
)
def reader_func(temp_filename):
return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
def create_iter_server():
return spark._jvm.ArrowIteratorServer()
schema = infer_schema(gdf)
timezone = spark._jconf.sessionLocalTimeZone()
step = spark._jconf.arrowMaxRecordsPerBatch()
step = step if step > 0 else len(gdf)
pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
arrow_data = [
[
(c, to_arrow_type(t) if t is not None else None, t)
for (_, c), t in zip(pdf_slice.items(), spark_types)
]
for pdf_slice in pdf_slices
]
safecheck = spark._jconf.arrowSafeTypeConversion()
ser = ArrowStreamPandasSerializer(timezone, safecheck)
jiter = spark._sc._serialize_to_jvm(
arrow_data, ser, reader_func, create_iter_server
)
jsparkSession = spark._jsparkSession
jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
df = DataFrame(jdf, spark)
df._schema = schema
return df