Source code for sedona.spark.core.SpatialRDD.spatial_rdd

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import List, Optional, Union

import attr
from py4j.java_gateway import get_field, get_method
from pyspark import RDD, SparkContext, StorageLevel
from pyspark.sql import SparkSession

from sedona.spark.core.enums.grid_type import GridType, GridTypeJvm
from sedona.spark.core.enums.index_type import IndexType, IndexTypeJvm
from sedona.spark.core.enums.spatial import SpatialType
from sedona.spark.core.geom.envelope import Envelope
from sedona.spark.core.jvm.translate import (
    JvmSedonaPythonConverter,
    SedonaPythonConverter,
)
from sedona.spark.core.SpatialRDD.spatial_rdd_factory import SpatialRDDFactory
from sedona.spark.utils.decorators import require
from sedona.spark.utils.jvm import JvmStorageLevel
from sedona.spark.utils.spatial_rdd_parser import SedonaPickler
from sedona.spark.utils.types import crs


[docs] @attr.s class SpatialPartitioner: name = attr.ib() jvm_partitioner = attr.ib()
[docs] @classmethod def from_java_class_name(cls, jvm_partitioner) -> "SpatialPartitioner": if jvm_partitioner is not None: jvm_full_name = jvm_partitioner.toString() full_class_name = jvm_full_name.split("@")[0] partitioner = full_class_name.split(".")[-1] else: partitioner = None return cls(partitioner, jvm_partitioner)
[docs] def getGrids(self) -> List[Envelope]: jvm_grids = get_method(self.jvm_partitioner, "getGrids")() number_of_grids = jvm_grids.size() envelopes = [ Envelope.from_jvm_instance(jvm_grids[index]) for index in range(number_of_grids) ] return envelopes
[docs] @attr.s class JvmSpatialRDD: jsrdd = attr.ib() sc = attr.ib(type=SparkContext) tp = attr.ib(type=SpatialType)
[docs] def saveAsObjectFile(self, location: str): self.jsrdd.saveAsObjectFile(location)
[docs] def persist(self, storage_level: StorageLevel): new_jsrdd = self.jsrdd.persist( JvmStorageLevel(self.sc._jvm, storage_level).jvm_instance ) self.jsrdd = new_jsrdd
[docs] def count(self): return self.jsrdd.count()
[docs] def cache(self): return self.persist(StorageLevel.MEMORY_ONLY)
[docs] def unpersist(self): return self.jsrdd.unpersist()
[docs] @attr.s class JvmGrids: jgrid = attr.ib() sc = attr.ib(type=SparkContext)
[docs] class SpatialRDD:
[docs] def __init__(self, sc: Optional[SparkContext] = None): self._do_init(sc) self._srdd = self._jvm_spatial_rdd()
def _do_init(self, sc: Optional[SparkContext] = None): if sc is None: session = SparkSession._instantiatedSession if session is None or session._sc._jsc is None: raise TypeError("Please initialize spark session") else: sc = session._sc self._sc = sc self._jvm = sc._jvm self._jsc = self._sc._jsc self._spatial_partitioned = False self._is_analyzed = False
[docs] def analyze(self) -> bool: """ Analyze SpatialRDD :return: bool, """ self._srdd.analyze() self._is_analyzed = True return self._is_analyzed
[docs] def CRSTransform(self, sourceEpsgCRSCode: crs, targetEpsgCRSCode: crs) -> bool: """ Function transforms coordinates from one crs to another one :param sourceEpsgCRSCode: crs, Coordinate Reference System to transform from :param targetEpsgCRSCode: crs, Coordinate Reference System to transform to :return: bool, True if transforming was correct """ return self._srdd.CRSTransform(sourceEpsgCRSCode, targetEpsgCRSCode)
[docs] def flipCoordinates(self): return self._srdd.flipCoordinates()
[docs] def MinimumBoundingRectangle(self): raise NotImplementedError()
@property def approximateTotalCount(self) -> int: """ :return: """ return get_field(self._srdd, "approximateTotalCount")
[docs] def boundary(self) -> Envelope: """ :return: """ jvm_boundary = self._srdd.boundary() envelope = Envelope.from_jvm_instance(jvm_boundary) return envelope
@property def boundaryEnvelope(self) -> Envelope: """ :return: """ if not self._is_analyzed: raise TypeError("Please use analyze before") java_boundary_envelope = get_field(self._srdd, "boundaryEnvelope") return Envelope.from_jvm_instance(java_boundary_envelope)
[docs] def buildIndex( self, indexType: Union[str, IndexType], buildIndexOnSpatialPartitionedRDD: bool ) -> bool: """ :param indexType: :param buildIndexOnSpatialPartitionedRDD: :return: """ if self._spatial_partitioned or not buildIndexOnSpatialPartitionedRDD: if type(indexType) == str: index_type = IndexTypeJvm(self._jvm, IndexType.from_string(indexType)) elif type(indexType) == IndexType: index_type = IndexTypeJvm(self._jvm, indexType) else: raise TypeError("indexType should be str or IndexType") return self._srdd.buildIndex( index_type.jvm_instance, buildIndexOnSpatialPartitionedRDD ) else: raise AttributeError("Please run spatial partitioning before")
[docs] def countWithoutDuplicates(self) -> int: """ :return: """ return self._srdd.countWithoutDuplicates()
[docs] def countWithoutDuplicatesSPRDD(self) -> int: """ :return: """ return self._srdd.countWithoutDuplicatesSPRDD()
@property def fieldNames(self) -> List[str]: """ :return: """ try: field_names = list(get_field(self._srdd, "fieldNames")) except TypeError: field_names = [] return field_names
[docs] def getCRStransformation(self): """ :return: """ return self._srdd.getCRStransformation()
[docs] def getPartitioner(self) -> SpatialPartitioner: """ :return: """ return SpatialPartitioner.from_java_class_name(self._srdd.getPartitioner())
@require(["GeoSerializerData"]) def getRawSpatialRDD(self): """ :return: """ serialized_spatial_rdd = SedonaPythonConverter( self._jvm ).translate_spatial_rdd_to_python(self._srdd.getRawSpatialRDD()) if not hasattr(self, "_raw_spatial_rdd"): RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path) setattr( self, "_raw_spatial_rdd", RDD(serialized_spatial_rdd, self._sc, SedonaPickler()), ) else: self._raw_spatial_rdd._jrdd = serialized_spatial_rdd return getattr(self, "_raw_spatial_rdd")
[docs] def getSampleNumber(self) -> int: """ :return: """ return self._srdd.getSampleNumber()
[docs] def getSourceEpsgCode(self) -> str: """ Function which returns source EPSG code when it is assigned. If not an empty String is returned. :return: str, source epsg code. """ return self._srdd.getSourceEpsgCode()
[docs] def getTargetEpsgCode(self) -> str: """ Function which returns target EPSG code when it is assigned. If not an empty String is returned. :return: str, target epsg code. """ return self._srdd.getTargetEpgsgCode()
@property def grids(self) -> Optional[List[Envelope]]: """ Returns grids for SpatialRDD, it is a list of Envelopes. >> spatial_rdd.grids >> [Envelope(minx=10.0, maxx=12.0, miny=10.0, maxy=12.0)] :return: """ jvm_grids = self.jvm_grids.jgrid if jvm_grids: number_of_grids = jvm_grids.size() envelopes = [ Envelope.from_jvm_instance(jvm_grids[index]) for index in range(number_of_grids) ] return envelopes else: return None @property def jvm_grids(self) -> JvmGrids: jvm_grids = get_field(self._srdd, "grids") return JvmGrids(jgrid=jvm_grids, sc=self._sc) @jvm_grids.setter def jvm_grids(self, jvm_grid: JvmGrids): self._srdd.grids = jvm_grid.jgrid @property def indexedRDD(self): """ :return: """ jrdd = get_field(self._srdd, "indexedRDD") if not hasattr(self, "_indexed_rdd"): RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path) RDD.count = lambda x: x._jrdd.count() setattr(self, "_indexed_rdd", RDD(jrdd, self._sc)) else: self._indexed_rdd._jrdd = jrdd return getattr(self, "_indexed_rdd") @indexedRDD.setter def indexedRDD(self, indexed_rdd: RDD): """ :return: """ self._indexed_rdd = indexed_rdd @property def indexedRawRDD(self): jrdd = get_field(self._srdd, "indexedRawRDD") if not hasattr(self, "_indexed_raw_rdd"): RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path) RDD.count = lambda x: x._jrdd.count() setattr(self, "_indexed_raw_rdd", RDD(jrdd, self._sc)) else: self._indexed_raw_rdd._jrdd = jrdd return getattr(self, "_indexed_raw_rdd") @indexedRawRDD.setter def indexedRawRDD(self, indexed_raw_rdd: RDD): self._indexed_raw_rdd = indexed_raw_rdd @property def rawSpatialRDD(self): """ :return: """ return self.getRawSpatialRDD() @rawSpatialRDD.setter def rawSpatialRDD(self, spatial_rdd): if isinstance(spatial_rdd, SpatialRDD): self._srdd = spatial_rdd._srdd self._sc = spatial_rdd._sc self._jvm = spatial_rdd._jvm self._spatial_partitioned = spatial_rdd._spatial_partitioned elif isinstance(spatial_rdd, RDD): jrdd = JvmSedonaPythonConverter(self._jvm).translate_python_rdd_to_java( spatial_rdd._jrdd ) self._srdd.setRawSpatialRDD(jrdd) else: self._srdd.setRawSpatialRDD(spatial_rdd)
[docs] def saveAsGeoJSON(self, path: str): """ :param path: :return: """ return self._srdd.saveAsGeoJSON(path)
[docs] def saveAsWKB(self, path: str): """ :param path: :return: """ return self._srdd.saveAsWKB(path)
[docs] def saveAsWKT(self, path: str): """ :param path: :return: """ return self._srdd.saveAsWKT(path)
[docs] def setRawSpatialRDD(self, jrdd): """ :return: """ return self._srdd.setRawSpatialRDD(jrdd)
[docs] def setSampleNumber(self, sampleNumber: int) -> bool: """ :return: """ return self._srdd.setSampleNumber(sampleNumber)
@property def spatialPartitionedRDD(self): from sedona.spark.utils.spatial_rdd_parser import SedonaPickler """ :return: """ serialized_spatial_rdd = SedonaPythonConverter( self._jvm ).translate_spatial_rdd_to_python( get_field(self._srdd, "spatialPartitionedRDD") ) if not hasattr(self, "_spatial_partitioned_rdd"): setattr( self, "_spatial_partitioned_rdd", RDD(serialized_spatial_rdd, self._sc, SedonaPickler()), ) else: self._spatial_partitioned_rdd._jrdd = serialized_spatial_rdd return getattr(self, "_spatial_partitioned_rdd")
[docs] def spatialPartitioning( self, partitioning: Union[str, GridType, SpatialPartitioner, List[Envelope]], num_partitions: Optional[int] = None, ) -> bool: """ Calculate partitions and assign items in this RDD to a partition. :param partitioning: Partitioning type or existing SpatialPartitioner (e.g., one obtained from another SpatialRDD to align partitions among input data) :param num_partitions: If partitioning is a GridType, the target number of partitions into which the RDD should be split. :return: True on success """ return self._spatial_partitioning_impl( partitioning, num_partitions, self._srdd.spatialPartitioning )
[docs] def spatialPartitioningWithoutDuplicates( self, partitioning: Union[str, GridType, SpatialPartitioner, List[Envelope]], num_partitions: Optional[int] = None, ) -> bool: """ Calculate partitions and assign items in this RDD to a partition without introducing duplicates. This is not the desired behaviour for executing joins but is the correct option when partitioning in preparation for a distributed write. :param partitioning: Partitioning type or existing SpatialPartitioner (e.g., one obtained from another SpatialRDD to align partitions among input data) :param num_partitions: If partitioning is a GridType, the target number of partitions into which the RDD should be split. :return: True on success """ return self._spatial_partitioning_impl( partitioning, num_partitions, self._srdd.spatialPartitioningWithoutDuplicates, )
def _spatial_partitioning_impl( self, partitioning: Union[str, GridType, SpatialPartitioner, List[Envelope]], num_partitions: Optional[int], java_method, ) -> bool: if type(partitioning) == str: grid = GridTypeJvm(self._jvm, GridType.from_str(partitioning)).jvm_instance elif type(partitioning) == GridType: grid = GridTypeJvm(self._jvm, partitioning).jvm_instance elif type(partitioning) == SpatialPartitioner: grid = partitioning.jvm_partitioner elif type(partitioning) == list: if isinstance(partitioning[0], Envelope): bytes_data = Envelope.serialize_for_java(partitioning) jvm_envelopes = self._jvm.EnvelopeAdapter.getFromPython(bytes_data) grid = jvm_envelopes else: raise AttributeError("List should consists of Envelopes") else: raise TypeError("Grid does not have correct type") self._spatial_partitioned = True if num_partitions: return java_method(grid, num_partitions) else: return java_method(grid)
[docs] def set_srdd(self, srdd): self._srdd = srdd
[docs] def get_srdd(self): return self._srdd
[docs] def getRawJvmSpatialRDD(self) -> JvmSpatialRDD: return JvmSpatialRDD( jsrdd=self._srdd.getRawSpatialRDD(), sc=self._sc, tp=SpatialType.from_str(self.name), )
@property def rawJvmSpatialRDD(self) -> JvmSpatialRDD: return self.getRawJvmSpatialRDD() @rawJvmSpatialRDD.setter def rawJvmSpatialRDD(self, jsrdd_p: JvmSpatialRDD): if jsrdd_p.tp.value.lower() != self.name: raise TypeError( f"value should be type {self.name} but {jsrdd_p.tp} was found" ) self._sc = jsrdd_p.sc self._jvm = self._sc._jvm self._jsc = self._sc._jsc self.setRawSpatialRDD(jsrdd_p.jsrdd)
[docs] def getJvmSpatialPartitionedRDD(self) -> JvmSpatialRDD: return JvmSpatialRDD( jsrdd=get_field(self._srdd, "spatialPartitionedRDD"), sc=self._sc, tp=SpatialType.from_str(self.name), )
@property def jvmSpatialPartitionedRDD(self) -> JvmSpatialRDD: return self.getJvmSpatialPartitionedRDD() @jvmSpatialPartitionedRDD.setter def jvmSpatialPartitionedRDD(self, jsrdd_p: JvmSpatialRDD): if jsrdd_p.tp.value.lower() != self.name: raise TypeError( f"value should be type {self.name} but {jsrdd_p.tp} was found" ) self._sc = jsrdd_p.sc self._jvm = self._sc._jvm self._jsc = self._sc._jsc self._srdd.jvmSpatialPartitionedRDD = jsrdd_p.jsrdd @property def name(self): name = self.__class__.__name__ return name.replace("RDD", "").lower() @property def _jvm_spatial_rdd(self): spatial_factory = SpatialRDDFactory(self._sc) return spatial_factory.create_spatial_rdd()