# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import List
import numpy as np
from .data_buffer import DataBuffer
[docs]
class SampleModel(ABC):
"""The SampleModel class and its subclasses are defined according to the data structure of
SampleModel class in Java AWT.
"""
TYPE_BANDED = 1
TYPE_PIXEL_INTERLEAVED = 2
TYPE_SINGLE_PIXEL_PACKED = 3
TYPE_MULTI_PIXEL_PACKED = 4
TYPE_COMPONENT_JAI = 5
TYPE_COMPONENT = 6
sample_model_type: int
data_type: int
width: int
height: int
[docs]
def __init__(self, sample_model_type, data_type, width, height):
self.sample_model_type = sample_model_type
self.data_type = data_type
self.width = width
self.height = height
[docs]
@abstractmethod
def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
raise NotImplementedError(
"Abstract method as_numpy was not implemented by subclass"
)
[docs]
class ComponentSampleModel(SampleModel):
pixel_stride: int
scanline_stride: int
bank_indices: List[int]
band_offsets: List[int]
[docs]
def __init__(
self,
data_type,
width,
height,
pixel_stride,
scanline_stride,
bank_indices,
band_offsets,
):
super().__init__(SampleModel.TYPE_COMPONENT, data_type, width, height)
self.pixel_stride = pixel_stride
self.scanline_stride = scanline_stride
self.bank_indices = bank_indices
self.band_offsets = band_offsets
[docs]
def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
if self.scanline_stride == self.width and self.pixel_stride == 1:
# Fast path: no gaps between pixels
band_arrs = []
for bank_index in self.bank_indices:
bank_data = data_buffer.bank_data[bank_index]
offset = self.band_offsets[bank_index]
if offset != 0:
bank_data = bank_data[offset : (offset + self.width * self.height)]
band_arr = bank_data.reshape(self.height, self.width)
band_arrs.append(band_arr)
return np.array(band_arrs)
else:
# Slow path
band_arrs = []
for k in range(len(self.bank_indices)):
bank_index = self.bank_indices[k]
bank_data = data_buffer.bank_data[bank_index]
offset = self.band_offsets[k]
band_pixel_data = []
for y in range(self.height):
for x in range(self.width):
pos = offset + y * self.scanline_stride + x * self.pixel_stride
band_pixel_data.append(bank_data[pos])
arr = np.array(band_pixel_data).reshape(self.height, self.width)
band_arrs.append(arr)
return np.array(band_arrs)
[docs]
class PixelInterleavedSampleModel(SampleModel):
pixel_stride: int
scanline_stride: int
band_offsets: List[int]
[docs]
def __init__(
self, data_type, width, height, pixel_stride, scanline_stride, band_offsets
):
super().__init__(SampleModel.TYPE_PIXEL_INTERLEAVED, data_type, width, height)
self.pixel_stride = pixel_stride
self.scanline_stride = scanline_stride
self.band_offsets = band_offsets
[docs]
def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
num_bands = len(self.band_offsets)
bank_data = data_buffer.bank_data[0]
if (
self.pixel_stride == num_bands
and self.scanline_stride == self.width * num_bands
and self.band_offsets == list(range(0, num_bands))
):
# Fast path: no gapping in between band data, no band reordering
arr = bank_data.reshape(self.height, self.width, num_bands)
return np.transpose(arr, [2, 0, 1])
else:
# Slow path
pixel_data = []
for y in range(self.height):
for x in range(self.width):
begin = y * self.scanline_stride + x * self.pixel_stride
end = begin + num_bands
pixel = bank_data[begin:end][self.band_offsets]
pixel_data.append(pixel)
arr = np.array(pixel_data).reshape(self.height, self.width, num_bands)
return np.transpose(arr, [2, 0, 1])
[docs]
class SinglePixelPackedSampleModel(SampleModel):
scanline_stride: int
bit_masks: List[int]
bit_offsets: List[int]
[docs]
def __init__(self, data_type, width, height, scanline_stride, bit_masks):
super().__init__(SampleModel.TYPE_SINGLE_PIXEL_PACKED, data_type, width, height)
self.scanline_stride = scanline_stride
self.bit_masks = bit_masks
self.bit_offsets = []
for v in self.bit_masks:
self.bit_offsets.append((v & -v).bit_length() - 1)
[docs]
def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
num_bands = len(self.bit_masks)
bank_data = data_buffer.bank_data[0]
pixel_data = []
for y in range(self.height):
for x in range(self.width):
pos = y * self.scanline_stride + x
value = bank_data[pos]
pixel = []
for mask, bit_offset in zip(self.bit_masks, self.bit_offsets):
pixel.append((value & mask) >> bit_offset)
pixel_data.append(pixel)
arr = np.array(pixel_data, dtype=bank_data.dtype).reshape(
self.height, self.width, num_bands
)
return np.transpose(arr, [2, 0, 1])
[docs]
class MultiPixelPackedSampleModel(SampleModel):
num_bits: int
scanline_stride: int
data_bit_offset: int
[docs]
def __init__(
self, data_type, width, height, num_bits, scanline_stride, data_bit_offset
):
super().__init__(SampleModel.TYPE_MULTI_PIXEL_PACKED, data_type, width, height)
self.num_bits = num_bits
self.scanline_stride = scanline_stride
self.data_bit_offset = data_bit_offset
[docs]
def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
bank_data = data_buffer.bank_data[0]
bits_per_value = bank_data.dtype.itemsize * 8
pixel_per_value = bits_per_value / self.num_bits
shift_right = bits_per_value - self.num_bits
mask = ((1 << self.num_bits) - 1) << shift_right
band_data = []
for y in range(self.height):
pos = y * self.scanline_stride + self.data_bit_offset // bits_per_value
value = bank_data[pos]
shift = self.data_bit_offset % bits_per_value
value = value << shift
pixels: List[int] = []
while len(pixels) < self.width:
while shift < bits_per_value and len(pixels) < self.width:
pixels.append((value & mask) >> shift_right)
value = value << self.num_bits
shift += self.num_bits
pos += 1
value = bank_data[pos]
shift = 0
band_data.append(np.array(pixels, dtype=bank_data.dtype))
return np.array(band_data).reshape(1, self.height, self.width)