vkit.pipeline.text_detection.page_text_region_cropping

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Sequence, Mapping, Tuple, DefaultDict, List, Optional
 15from collections import defaultdict
 16import itertools
 17import warnings
 18
 19import attrs
 20from numpy.random import Generator as RandomGenerator
 21from shapely.errors import ShapelyDeprecationWarning
 22from shapely.strtree import STRtree
 23from shapely.geometry import Point as ShapelyPoint
 24import cv2 as cv
 25
 26from vkit.element import Box, Mask, ScoreMap, Image
 27from vkit.mechanism.distortion import rotate
 28from vkit.mechanism.cropper import Cropper
 29from ..interface import PipelineStep, PipelineStepFactory
 30from .page_cropping import PageCroppingStepOutput
 31from .page_text_region import PageTextRegionStepOutput
 32from .page_text_region_label import (
 33    PageCharRegressionLabelTag,
 34    PageCharRegressionLabel,
 35    PageTextRegionLabelStepOutput,
 36)
 37
 38# Shapely version has been explicitly locked under 2.0, hence ignore this warning.
 39warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
 40
 41
 42@attrs.define
 43class PageTextRegionCroppingStepConfig:
 44    core_size: int
 45    pad_size: int
 46    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
 47    num_centroid_points_min: int = 10
 48    num_deviate_points_min: int = 10
 49    pad_value: int = 0
 50    enable_downsample_labeling: bool = True
 51    downsample_labeling_factor: int = 2
 52
 53
 54@attrs.define
 55class PageTextRegionCroppingStepInput:
 56    page_cropping_step_output: PageCroppingStepOutput
 57    page_text_region_step_output: PageTextRegionStepOutput
 58    page_text_region_label_step_output: PageTextRegionLabelStepOutput
 59
 60
 61@attrs.define
 62class DownsampledLabel:
 63    shape: Tuple[int, int]
 64    page_char_mask: Mask
 65    page_char_height_score_map: ScoreMap
 66    page_char_gaussian_score_map: ScoreMap
 67    page_char_regression_labels: Sequence[PageCharRegressionLabel]
 68    core_box: Box
 69
 70
 71@attrs.define
 72class CroppedPageTextRegion:
 73    page_image: Image
 74    page_char_mask: Mask
 75    page_char_height_score_map: ScoreMap
 76    page_char_gaussian_score_map: ScoreMap
 77    page_char_regression_labels: Sequence[PageCharRegressionLabel]
 78    core_box: Box
 79    downsampled_label: Optional[DownsampledLabel]
 80
 81
 82@attrs.define
 83class PageTextRegionCroppingStepOutput:
 84    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
 85
 86
 87class PageTextRegionCroppingStep(
 88    PipelineStep[
 89        PageTextRegionCroppingStepConfig,
 90        PageTextRegionCroppingStepInput,
 91        PageTextRegionCroppingStepOutput,
 92    ]
 93):  # yapf: disable
 94
 95    @classmethod
 96    def build_strtree_for_page_char_regression_labels(
 97        cls,
 98        labels: Sequence[PageCharRegressionLabel],
 99    ):
100        shapely_points: List[ShapelyPoint] = []
101
102        xy_pair_to_labels: DefaultDict[
103            Tuple[int, int],
104            List[PageCharRegressionLabel],
105        ] = defaultdict(list)  # yapf: disable
106
107        for label in labels:
108            assert isinstance(label.label_point_x, int)
109            assert isinstance(label.label_point_y, int)
110            xy_pair = (label.label_point_x, label.label_point_y)
111            shapely_points.append(ShapelyPoint(*xy_pair))
112            xy_pair_to_labels[xy_pair].append(label)
113
114        strtree = STRtree(shapely_points)
115        return strtree, xy_pair_to_labels
116
117    def sample_cropped_page_text_regions(
118        self,
119        page_image: Image,
120        shape_before_rotate: Tuple[int, int],
121        rotate_angle: int,
122        page_char_mask: Mask,
123        page_char_height_score_map: ScoreMap,
124        page_char_gaussian_score_map: ScoreMap,
125        centroid_strtree: STRtree,
126        centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
127        deviate_strtree: STRtree,
128        deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
129        rng: RandomGenerator,
130    ):
131        if rotate_angle != 0:
132            cropper_before_rotate = Cropper.create(
133                shape=shape_before_rotate,
134                core_size=self.config.core_size,
135                pad_size=self.config.pad_size,
136                pad_value=self.config.pad_value,
137                rng=rng,
138            )
139            origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
140            center_point_before_rotate = origin_box_before_rotate.get_center_point()
141
142            rotated_result = rotate.distort(
143                {'angle': rotate_angle},
144                shapable_or_shape=shape_before_rotate,
145                point=center_point_before_rotate,
146            )
147            assert rotated_result.shape == page_image.shape
148            center_point = rotated_result.point
149            assert center_point
150
151            cropper = Cropper.create_from_center_point(
152                shape=page_image.shape,
153                core_size=self.config.core_size,
154                pad_size=self.config.pad_size,
155                pad_value=self.config.pad_value,
156                center_point=center_point,
157            )
158
159        else:
160            cropper = Cropper.create(
161                shape=page_image.shape,
162                core_size=self.config.core_size,
163                pad_size=self.config.pad_size,
164                pad_value=self.config.pad_value,
165                rng=rng,
166            )
167
168        # Pick labels.
169        origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()
170
171        centroid_labels: List[PageCharRegressionLabel] = []
172        for shapely_point in centroid_strtree.query(origin_core_shapely_polygon):
173            if not origin_core_shapely_polygon.intersects(shapely_point):
174                continue
175            assert isinstance(shapely_point, ShapelyPoint)
176            centroid_xy_pair = (int(shapely_point.x), int(shapely_point.y))
177            centroid_labels.extend(centroid_xy_pair_to_labels[centroid_xy_pair])
178
179        deviate_labels: List[PageCharRegressionLabel] = []
180        for shapely_point in deviate_strtree.query(origin_core_shapely_polygon):
181            if not origin_core_shapely_polygon.intersects(shapely_point):
182                continue
183            assert isinstance(shapely_point, ShapelyPoint)
184            deviate_xy_pair = (int(shapely_point.x), int(shapely_point.y))
185            deviate_labels.extend(deviate_xy_pair_to_labels[deviate_xy_pair])
186
187        if len(centroid_labels) < self.config.num_centroid_points_min \
188                or len(deviate_labels) < self.config.num_deviate_points_min:
189            return None
190
191        # Shift labels.
192        offset_y = cropper.target_box.up - cropper.origin_box.up
193        offset_x = cropper.target_box.left - cropper.origin_box.left
194        shifted_centroid_labels = [
195            centroid_label.to_shifted_page_char_regression_label(
196                offset_y=offset_y,
197                offset_x=offset_x,
198            ) for centroid_label in centroid_labels
199        ]
200        shifted_deviate_labels = [
201            deviate_label.to_shifted_page_char_regression_label(
202                offset_y=offset_y,
203                offset_x=offset_x,
204            ) for deviate_label in deviate_labels
205        ]
206
207        # Crop image and score map.
208        page_image = cropper.crop_image(page_image)
209        page_char_mask = cropper.crop_mask(
210            page_char_mask,
211            core_only=True,
212        )
213        page_char_height_score_map = cropper.crop_score_map(
214            page_char_height_score_map,
215            core_only=True,
216        )
217        page_char_gaussian_score_map = cropper.crop_score_map(
218            page_char_gaussian_score_map,
219            core_only=True,
220        )
221
222        downsampled_label: Optional[DownsampledLabel] = None
223        if self.config.enable_downsample_labeling:
224            downsample_labeling_factor = self.config.downsample_labeling_factor
225
226            assert cropper.crop_size % downsample_labeling_factor == 0
227            downsampled_size = cropper.crop_size // downsample_labeling_factor
228            downsampled_shape = (downsampled_size, downsampled_size)
229
230            assert self.config.pad_size % downsample_labeling_factor == 0
231            assert self.config.core_size % downsample_labeling_factor == 0
232            assert cropper.core_box.height == cropper.core_box.width == self.config.core_size
233
234            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
235            downsampled_core_size = self.config.core_size // downsample_labeling_factor
236
237            downsampled_core_begin = downsampled_pad_size
238            downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
239            downsampled_core_box = Box(
240                up=downsampled_core_begin,
241                down=downsampled_core_end,
242                left=downsampled_core_begin,
243                right=downsampled_core_end,
244            )
245
246            downsampled_page_char_mask = page_char_mask.to_box_detached()
247            downsampled_page_char_mask = \
248                downsampled_page_char_mask.to_resized_mask(
249                    resized_height=downsampled_core_size,
250                    resized_width=downsampled_core_size,
251                    cv_resize_interpolation=cv.INTER_AREA,
252                )
253
254            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
255            downsampled_page_char_height_score_map = \
256                downsampled_page_char_height_score_map.to_resized_score_map(
257                    resized_height=downsampled_core_size,
258                    resized_width=downsampled_core_size,
259                    cv_resize_interpolation=cv.INTER_AREA,
260                )
261
262            downsampled_page_char_gaussian_score_map = \
263                page_char_gaussian_score_map.to_box_detached()
264            downsampled_page_char_gaussian_score_map = \
265                downsampled_page_char_gaussian_score_map.to_resized_score_map(
266                    resized_height=downsampled_core_size,
267                    resized_width=downsampled_core_size,
268                    cv_resize_interpolation=cv.INTER_AREA,
269                )
270
271            downsampled_page_char_regression_labels = [
272                label.to_downsampled_page_char_regression_label(
273                    self.config.downsample_labeling_factor
274                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
275            ]
276
277            downsampled_label = DownsampledLabel(
278                shape=downsampled_shape,
279                page_char_mask=downsampled_page_char_mask,
280                page_char_height_score_map=downsampled_page_char_height_score_map,
281                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
282                page_char_regression_labels=downsampled_page_char_regression_labels,
283                core_box=downsampled_core_box,
284            )
285
286        return CroppedPageTextRegion(
287            page_image=page_image,
288            page_char_mask=page_char_mask,
289            page_char_height_score_map=page_char_height_score_map,
290            page_char_gaussian_score_map=page_char_gaussian_score_map,
291            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
292            core_box=cropper.core_box,
293            downsampled_label=downsampled_label,
294        )
295
296    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
297        page_cropping_step_output = input.page_cropping_step_output
298        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
299
300        page_text_region_step_output = input.page_text_region_step_output
301        page_image = page_text_region_step_output.page_image
302        shape_before_rotate = page_text_region_step_output.shape_before_rotate
303        rotate_angle = page_text_region_step_output.rotate_angle
304
305        page_text_region_label_step_output = input.page_text_region_label_step_output
306        page_char_mask = page_text_region_label_step_output.page_char_mask
307        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
308        page_char_gaussian_score_map = \
309            page_text_region_label_step_output.page_char_gaussian_score_map
310        page_char_regression_labels = \
311            page_text_region_label_step_output.page_char_regression_labels
312
313        (
314            centroid_strtree,
315            centroid_xy_pair_to_labels,
316        ) = self.build_strtree_for_page_char_regression_labels([
317            label for label in page_char_regression_labels
318            if label.tag == PageCharRegressionLabelTag.CENTROID
319        ])
320        (
321            deviate_strtree,
322            deviate_xy_pair_to_labels,
323        ) = self.build_strtree_for_page_char_regression_labels([
324            label for label in page_char_regression_labels
325            if label.tag == PageCharRegressionLabelTag.DEVIATE
326        ])
327
328        num_samples = round(
329            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
330        )
331
332        run_count_max = max(3, 2 * num_samples)
333        run_count = 0
334
335        cropped_page_text_regions: List[CroppedPageTextRegion] = []
336
337        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
338            cropped_page_text_region = self.sample_cropped_page_text_regions(
339                page_image=page_image,
340                shape_before_rotate=shape_before_rotate,
341                rotate_angle=rotate_angle,
342                page_char_mask=page_char_mask,
343                page_char_height_score_map=page_char_height_score_map,
344                page_char_gaussian_score_map=page_char_gaussian_score_map,
345                centroid_strtree=centroid_strtree,
346                centroid_xy_pair_to_labels=centroid_xy_pair_to_labels,
347                deviate_strtree=deviate_strtree,
348                deviate_xy_pair_to_labels=deviate_xy_pair_to_labels,
349                rng=rng,
350            )
351            if cropped_page_text_region:
352                cropped_page_text_regions.append(cropped_page_text_region)
353            run_count += 1
354
355        return PageTextRegionCroppingStepOutput(
356            cropped_page_text_regions=cropped_page_text_regions,
357        )
358
359
# Module-level PipelineStepFactory wrapping PageTextRegionCroppingStep.
page_text_region_cropping_step_factory = PipelineStepFactory(
    PageTextRegionCroppingStep
)
class PageTextRegionCroppingStepConfig:
class PageTextRegionCroppingStepConfig:
    """Configuration of ``PageTextRegionCroppingStep``.

    ``core_size``, ``pad_size`` and ``pad_value`` are forwarded to ``Cropper``; the
    other fields control how many crops are sampled, which crops are accepted, and
    whether downsampled labels are produced.
    """

    # Side length of the core region of every crop (Cropper ``core_size``).
    core_size: int
    # Padding around the core (Cropper ``pad_size``).
    pad_size: int
    # Target crop count is round(this factor * number of cropped pages).
    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
    # A crop is rejected when its core holds fewer centroid labels than this.
    num_centroid_points_min: int = 10
    # A crop is rejected when its core holds fewer deviate labels than this.
    num_deviate_points_min: int = 10
    # Forwarded to Cropper as ``pad_value``.
    pad_value: int = 0
    # When True, each crop also carries a ``DownsampledLabel``.
    enable_downsample_labeling: bool = True
    # Downsampling factor; crop size, core_size and pad_size must be divisible by it.
    downsample_labeling_factor: int = 2
PageTextRegionCroppingStepConfig( core_size: int, pad_size: int, num_samples_factor_relative_to_num_cropped_pages: float = 1.0, num_centroid_points_min: int = 10, num_deviate_points_min: int = 10, pad_value: int = 0, enable_downsample_labeling: bool = True, downsample_labeling_factor: int = 2)
 1# Generated by attrs for PageTextRegionCroppingStepConfig: stores each argument on the same-named attribute; optional arguments default to the attrs field defaults (attr_dict[...].default).
 2def __init__(self, core_size, pad_size, num_samples_factor_relative_to_num_cropped_pages=attr_dict['num_samples_factor_relative_to_num_cropped_pages'].default, num_centroid_points_min=attr_dict['num_centroid_points_min'].default, num_deviate_points_min=attr_dict['num_deviate_points_min'].default, pad_value=attr_dict['pad_value'].default, enable_downsample_labeling=attr_dict['enable_downsample_labeling'].default, downsample_labeling_factor=attr_dict['downsample_labeling_factor'].default):
 3    self.core_size = core_size
 4    self.pad_size = pad_size
 5    self.num_samples_factor_relative_to_num_cropped_pages = num_samples_factor_relative_to_num_cropped_pages
 6    self.num_centroid_points_min = num_centroid_points_min
 7    self.num_deviate_points_min = num_deviate_points_min
 8    self.pad_value = pad_value
 9    self.enable_downsample_labeling = enable_downsample_labeling
10    self.downsample_labeling_factor = downsample_labeling_factor

Method generated by attrs for class PageTextRegionCroppingStepConfig.

class PageTextRegionCroppingStepInput:
class PageTextRegionCroppingStepInput:
    """Upstream step outputs consumed by ``PageTextRegionCroppingStep.run``."""

    # This step only reads ``len(cropped_pages)`` to size the number of samples.
    page_cropping_step_output: PageCroppingStepOutput
    # Supplies ``page_image``, ``shape_before_rotate`` and ``rotate_angle``.
    page_text_region_step_output: PageTextRegionStepOutput
    # Supplies the char mask, height/gaussian score maps and regression labels.
    page_text_region_label_step_output: PageTextRegionLabelStepOutput
PageTextRegionCroppingStepInput( page_cropping_step_output: vkit.pipeline.text_detection.page_cropping.PageCroppingStepOutput, page_text_region_step_output: vkit.pipeline.text_detection.page_text_region.PageTextRegionStepOutput, page_text_region_label_step_output: vkit.pipeline.text_detection.page_text_region_label.PageTextRegionLabelStepOutput)
1# Generated by attrs for PageTextRegionCroppingStepInput: all three arguments are required and are stored unchanged on the same-named attributes.
2def __init__(self, page_cropping_step_output, page_text_region_step_output, page_text_region_label_step_output):
3    self.page_cropping_step_output = page_cropping_step_output
4    self.page_text_region_step_output = page_text_region_step_output
5    self.page_text_region_label_step_output = page_text_region_label_step_output

Method generated by attrs for class PageTextRegionCroppingStepInput.

class DownsampledLabel:
class DownsampledLabel:
    """Labels of one cropped region, downsampled by ``downsample_labeling_factor``."""

    # (crop_size // factor, crop_size // factor).
    shape: Tuple[int, int]
    # Core-only maps resized with cv.INTER_AREA to (core_size // factor) squared.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Crop-space labels passed through ``to_downsampled_page_char_regression_label``.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # Core region inside the downsampled ``shape`` (starts at pad_size // factor).
    core_box: Box
DownsampledLabel( shape: Tuple[int, int], page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], core_box: vkit.element.box.Box)
1# Generated by attrs for DownsampledLabel: all six arguments are required and are stored unchanged on the same-named attributes.
2def __init__(self, shape, page_char_mask, page_char_height_score_map, page_char_gaussian_score_map, page_char_regression_labels, core_box):
3    self.shape = shape
4    self.page_char_mask = page_char_mask
5    self.page_char_height_score_map = page_char_height_score_map
6    self.page_char_gaussian_score_map = page_char_gaussian_score_map
7    self.page_char_regression_labels = page_char_regression_labels
8    self.core_box = core_box

Method generated by attrs for class DownsampledLabel.

class CroppedPageTextRegion:
class CroppedPageTextRegion:
    """One sampled crop of the page text region image together with its labels."""

    # Result of ``Cropper.crop_image`` on the page image.
    page_image: Image
    # Label maps cropped with ``core_only=True``.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Centroid labels followed by deviate labels, shifted into crop coordinates.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # ``Cropper.core_box`` of this crop (presumably the core's position in page_image).
    core_box: Box
    # Set only when ``enable_downsample_labeling`` is enabled.
    downsampled_label: Optional[DownsampledLabel]
CroppedPageTextRegion( page_image: vkit.element.image.Image, page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], core_box: vkit.element.box.Box, downsampled_label: Union[vkit.pipeline.text_detection.page_text_region_cropping.DownsampledLabel, NoneType])
1# Generated by attrs for CroppedPageTextRegion: all seven arguments are required and are stored unchanged on the same-named attributes.
2def __init__(self, page_image, page_char_mask, page_char_height_score_map, page_char_gaussian_score_map, page_char_regression_labels, core_box, downsampled_label):
3    self.page_image = page_image
4    self.page_char_mask = page_char_mask
5    self.page_char_height_score_map = page_char_height_score_map
6    self.page_char_gaussian_score_map = page_char_gaussian_score_map
7    self.page_char_regression_labels = page_char_regression_labels
8    self.core_box = core_box
9    self.downsampled_label = downsampled_label

Method generated by attrs for class CroppedPageTextRegion.

class PageTextRegionCroppingStepOutput:
class PageTextRegionCroppingStepOutput:
    """Result of ``PageTextRegionCroppingStep.run``."""

    # Accepted crops; may be fewer than requested when the attempt budget runs out.
    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
PageTextRegionCroppingStepOutput( cropped_page_text_regions: Sequence[vkit.pipeline.text_detection.page_text_region_cropping.CroppedPageTextRegion])
1# Generated by attrs for PageTextRegionCroppingStepOutput: stores the required argument unchanged on the same-named attribute.
2def __init__(self, cropped_page_text_regions):
3    self.cropped_page_text_regions = cropped_page_text_regions

Method generated by attrs for class PageTextRegionCroppingStepOutput.

 88class PageTextRegionCroppingStep(
 89    PipelineStep[
 90        PageTextRegionCroppingStepConfig,
 91        PageTextRegionCroppingStepInput,
 92        PageTextRegionCroppingStepOutput,
 93    ]
 94):  # yapf: disable
 95
 96    @classmethod
 97    def build_strtree_for_page_char_regression_labels(
 98        cls,
 99        labels: Sequence[PageCharRegressionLabel],
100    ):
101        shapely_points: List[ShapelyPoint] = []
102
103        xy_pair_to_labels: DefaultDict[
104            Tuple[int, int],
105            List[PageCharRegressionLabel],
106        ] = defaultdict(list)  # yapf: disable
107
108        for label in labels:
109            assert isinstance(label.label_point_x, int)
110            assert isinstance(label.label_point_y, int)
111            xy_pair = (label.label_point_x, label.label_point_y)
112            shapely_points.append(ShapelyPoint(*xy_pair))
113            xy_pair_to_labels[xy_pair].append(label)
114
115        strtree = STRtree(shapely_points)
116        return strtree, xy_pair_to_labels
117
118    def sample_cropped_page_text_regions(
119        self,
120        page_image: Image,
121        shape_before_rotate: Tuple[int, int],
122        rotate_angle: int,
123        page_char_mask: Mask,
124        page_char_height_score_map: ScoreMap,
125        page_char_gaussian_score_map: ScoreMap,
126        centroid_strtree: STRtree,
127        centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
128        deviate_strtree: STRtree,
129        deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
130        rng: RandomGenerator,
131    ):
132        if rotate_angle != 0:
133            cropper_before_rotate = Cropper.create(
134                shape=shape_before_rotate,
135                core_size=self.config.core_size,
136                pad_size=self.config.pad_size,
137                pad_value=self.config.pad_value,
138                rng=rng,
139            )
140            origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
141            center_point_before_rotate = origin_box_before_rotate.get_center_point()
142
143            rotated_result = rotate.distort(
144                {'angle': rotate_angle},
145                shapable_or_shape=shape_before_rotate,
146                point=center_point_before_rotate,
147            )
148            assert rotated_result.shape == page_image.shape
149            center_point = rotated_result.point
150            assert center_point
151
152            cropper = Cropper.create_from_center_point(
153                shape=page_image.shape,
154                core_size=self.config.core_size,
155                pad_size=self.config.pad_size,
156                pad_value=self.config.pad_value,
157                center_point=center_point,
158            )
159
160        else:
161            cropper = Cropper.create(
162                shape=page_image.shape,
163                core_size=self.config.core_size,
164                pad_size=self.config.pad_size,
165                pad_value=self.config.pad_value,
166                rng=rng,
167            )
168
169        # Pick labels.
170        origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()
171
172        centroid_labels: List[PageCharRegressionLabel] = []
173        for shapely_point in centroid_strtree.query(origin_core_shapely_polygon):
174            if not origin_core_shapely_polygon.intersects(shapely_point):
175                continue
176            assert isinstance(shapely_point, ShapelyPoint)
177            centroid_xy_pair = (int(shapely_point.x), int(shapely_point.y))
178            centroid_labels.extend(centroid_xy_pair_to_labels[centroid_xy_pair])
179
180        deviate_labels: List[PageCharRegressionLabel] = []
181        for shapely_point in deviate_strtree.query(origin_core_shapely_polygon):
182            if not origin_core_shapely_polygon.intersects(shapely_point):
183                continue
184            assert isinstance(shapely_point, ShapelyPoint)
185            deviate_xy_pair = (int(shapely_point.x), int(shapely_point.y))
186            deviate_labels.extend(deviate_xy_pair_to_labels[deviate_xy_pair])
187
188        if len(centroid_labels) < self.config.num_centroid_points_min \
189                or len(deviate_labels) < self.config.num_deviate_points_min:
190            return None
191
192        # Shift labels.
193        offset_y = cropper.target_box.up - cropper.origin_box.up
194        offset_x = cropper.target_box.left - cropper.origin_box.left
195        shifted_centroid_labels = [
196            centroid_label.to_shifted_page_char_regression_label(
197                offset_y=offset_y,
198                offset_x=offset_x,
199            ) for centroid_label in centroid_labels
200        ]
201        shifted_deviate_labels = [
202            deviate_label.to_shifted_page_char_regression_label(
203                offset_y=offset_y,
204                offset_x=offset_x,
205            ) for deviate_label in deviate_labels
206        ]
207
208        # Crop image and score map.
209        page_image = cropper.crop_image(page_image)
210        page_char_mask = cropper.crop_mask(
211            page_char_mask,
212            core_only=True,
213        )
214        page_char_height_score_map = cropper.crop_score_map(
215            page_char_height_score_map,
216            core_only=True,
217        )
218        page_char_gaussian_score_map = cropper.crop_score_map(
219            page_char_gaussian_score_map,
220            core_only=True,
221        )
222
223        downsampled_label: Optional[DownsampledLabel] = None
224        if self.config.enable_downsample_labeling:
225            downsample_labeling_factor = self.config.downsample_labeling_factor
226
227            assert cropper.crop_size % downsample_labeling_factor == 0
228            downsampled_size = cropper.crop_size // downsample_labeling_factor
229            downsampled_shape = (downsampled_size, downsampled_size)
230
231            assert self.config.pad_size % downsample_labeling_factor == 0
232            assert self.config.core_size % downsample_labeling_factor == 0
233            assert cropper.core_box.height == cropper.core_box.width == self.config.core_size
234
235            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
236            downsampled_core_size = self.config.core_size // downsample_labeling_factor
237
238            downsampled_core_begin = downsampled_pad_size
239            downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
240            downsampled_core_box = Box(
241                up=downsampled_core_begin,
242                down=downsampled_core_end,
243                left=downsampled_core_begin,
244                right=downsampled_core_end,
245            )
246
247            downsampled_page_char_mask = page_char_mask.to_box_detached()
248            downsampled_page_char_mask = \
249                downsampled_page_char_mask.to_resized_mask(
250                    resized_height=downsampled_core_size,
251                    resized_width=downsampled_core_size,
252                    cv_resize_interpolation=cv.INTER_AREA,
253                )
254
255            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
256            downsampled_page_char_height_score_map = \
257                downsampled_page_char_height_score_map.to_resized_score_map(
258                    resized_height=downsampled_core_size,
259                    resized_width=downsampled_core_size,
260                    cv_resize_interpolation=cv.INTER_AREA,
261                )
262
263            downsampled_page_char_gaussian_score_map = \
264                page_char_gaussian_score_map.to_box_detached()
265            downsampled_page_char_gaussian_score_map = \
266                downsampled_page_char_gaussian_score_map.to_resized_score_map(
267                    resized_height=downsampled_core_size,
268                    resized_width=downsampled_core_size,
269                    cv_resize_interpolation=cv.INTER_AREA,
270                )
271
272            downsampled_page_char_regression_labels = [
273                label.to_downsampled_page_char_regression_label(
274                    self.config.downsample_labeling_factor
275                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
276            ]
277
278            downsampled_label = DownsampledLabel(
279                shape=downsampled_shape,
280                page_char_mask=downsampled_page_char_mask,
281                page_char_height_score_map=downsampled_page_char_height_score_map,
282                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
283                page_char_regression_labels=downsampled_page_char_regression_labels,
284                core_box=downsampled_core_box,
285            )
286
287        return CroppedPageTextRegion(
288            page_image=page_image,
289            page_char_mask=page_char_mask,
290            page_char_height_score_map=page_char_height_score_map,
291            page_char_gaussian_score_map=page_char_gaussian_score_map,
292            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
293            core_box=cropper.core_box,
294            downsampled_label=downsampled_label,
295        )
296
297    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
298        page_cropping_step_output = input.page_cropping_step_output
299        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
300
301        page_text_region_step_output = input.page_text_region_step_output
302        page_image = page_text_region_step_output.page_image
303        shape_before_rotate = page_text_region_step_output.shape_before_rotate
304        rotate_angle = page_text_region_step_output.rotate_angle
305
306        page_text_region_label_step_output = input.page_text_region_label_step_output
307        page_char_mask = page_text_region_label_step_output.page_char_mask
308        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
309        page_char_gaussian_score_map = \
310            page_text_region_label_step_output.page_char_gaussian_score_map
311        page_char_regression_labels = \
312            page_text_region_label_step_output.page_char_regression_labels
313
314        (
315            centroid_strtree,
316            centroid_xy_pair_to_labels,
317        ) = self.build_strtree_for_page_char_regression_labels([
318            label for label in page_char_regression_labels
319            if label.tag == PageCharRegressionLabelTag.CENTROID
320        ])
321        (
322            deviate_strtree,
323            deviate_xy_pair_to_labels,
324        ) = self.build_strtree_for_page_char_regression_labels([
325            label for label in page_char_regression_labels
326            if label.tag == PageCharRegressionLabelTag.DEVIATE
327        ])
328
329        num_samples = round(
330            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
331        )
332
333        run_count_max = max(3, 2 * num_samples)
334        run_count = 0
335
336        cropped_page_text_regions: List[CroppedPageTextRegion] = []
337
338        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
339            cropped_page_text_region = self.sample_cropped_page_text_regions(
340                page_image=page_image,
341                shape_before_rotate=shape_before_rotate,
342                rotate_angle=rotate_angle,
343                page_char_mask=page_char_mask,
344                page_char_height_score_map=page_char_height_score_map,
345                page_char_gaussian_score_map=page_char_gaussian_score_map,
346                centroid_strtree=centroid_strtree,
347                centroid_xy_pair_to_labels=centroid_xy_pair_to_labels,
348                deviate_strtree=deviate_strtree,
349                deviate_xy_pair_to_labels=deviate_xy_pair_to_labels,
350                rng=rng,
351            )
352            if cropped_page_text_region:
353                cropped_page_text_regions.append(cropped_page_text_region)
354            run_count += 1
355
356        return PageTextRegionCroppingStepOutput(
357            cropped_page_text_regions=cropped_page_text_regions,
358        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

  class Mapping(Generic[KT, VT]):
      def __getitem__(self, key: KT) -> VT:
          ...
      # Etc.

This class can then be used as follows::

  def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
      try:
          return mapping[key]
      except KeyError:
          return default

@classmethod
def build_strtree_for_page_char_regression_labels( cls, labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel]):
 96    @classmethod
 97    def build_strtree_for_page_char_regression_labels(
 98        cls,
 99        labels: Sequence[PageCharRegressionLabel],
100    ):
101        shapely_points: List[ShapelyPoint] = []
102
103        xy_pair_to_labels: DefaultDict[
104            Tuple[int, int],
105            List[PageCharRegressionLabel],
106        ] = defaultdict(list)  # yapf: disable
107
108        for label in labels:
109            assert isinstance(label.label_point_x, int)
110            assert isinstance(label.label_point_y, int)
111            xy_pair = (label.label_point_x, label.label_point_y)
112            shapely_points.append(ShapelyPoint(*xy_pair))
113            xy_pair_to_labels[xy_pair].append(label)
114
115        strtree = STRtree(shapely_points)
116        return strtree, xy_pair_to_labels
def sample_cropped_page_text_regions( self, page_image: vkit.element.image.Image, shape_before_rotate: Tuple[int, int], rotate_angle: int, page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, centroid_strtree: shapely.strtree.STRtree, centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel]], deviate_strtree: shapely.strtree.STRtree, deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel]], rng: numpy.random._generator.Generator):
118    def sample_cropped_page_text_regions(
119        self,
120        page_image: Image,
121        shape_before_rotate: Tuple[int, int],
122        rotate_angle: int,
123        page_char_mask: Mask,
124        page_char_height_score_map: ScoreMap,
125        page_char_gaussian_score_map: ScoreMap,
126        centroid_strtree: STRtree,
127        centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
128        deviate_strtree: STRtree,
129        deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
130        rng: RandomGenerator,
131    ):
132        if rotate_angle != 0:
133            cropper_before_rotate = Cropper.create(
134                shape=shape_before_rotate,
135                core_size=self.config.core_size,
136                pad_size=self.config.pad_size,
137                pad_value=self.config.pad_value,
138                rng=rng,
139            )
140            origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
141            center_point_before_rotate = origin_box_before_rotate.get_center_point()
142
143            rotated_result = rotate.distort(
144                {'angle': rotate_angle},
145                shapable_or_shape=shape_before_rotate,
146                point=center_point_before_rotate,
147            )
148            assert rotated_result.shape == page_image.shape
149            center_point = rotated_result.point
150            assert center_point
151
152            cropper = Cropper.create_from_center_point(
153                shape=page_image.shape,
154                core_size=self.config.core_size,
155                pad_size=self.config.pad_size,
156                pad_value=self.config.pad_value,
157                center_point=center_point,
158            )
159
160        else:
161            cropper = Cropper.create(
162                shape=page_image.shape,
163                core_size=self.config.core_size,
164                pad_size=self.config.pad_size,
165                pad_value=self.config.pad_value,
166                rng=rng,
167            )
168
169        # Pick labels.
170        origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()
171
172        centroid_labels: List[PageCharRegressionLabel] = []
173        for shapely_point in centroid_strtree.query(origin_core_shapely_polygon):
174            if not origin_core_shapely_polygon.intersects(shapely_point):
175                continue
176            assert isinstance(shapely_point, ShapelyPoint)
177            centroid_xy_pair = (int(shapely_point.x), int(shapely_point.y))
178            centroid_labels.extend(centroid_xy_pair_to_labels[centroid_xy_pair])
179
180        deviate_labels: List[PageCharRegressionLabel] = []
181        for shapely_point in deviate_strtree.query(origin_core_shapely_polygon):
182            if not origin_core_shapely_polygon.intersects(shapely_point):
183                continue
184            assert isinstance(shapely_point, ShapelyPoint)
185            deviate_xy_pair = (int(shapely_point.x), int(shapely_point.y))
186            deviate_labels.extend(deviate_xy_pair_to_labels[deviate_xy_pair])
187
188        if len(centroid_labels) < self.config.num_centroid_points_min \
189                or len(deviate_labels) < self.config.num_deviate_points_min:
190            return None
191
192        # Shift labels.
193        offset_y = cropper.target_box.up - cropper.origin_box.up
194        offset_x = cropper.target_box.left - cropper.origin_box.left
195        shifted_centroid_labels = [
196            centroid_label.to_shifted_page_char_regression_label(
197                offset_y=offset_y,
198                offset_x=offset_x,
199            ) for centroid_label in centroid_labels
200        ]
201        shifted_deviate_labels = [
202            deviate_label.to_shifted_page_char_regression_label(
203                offset_y=offset_y,
204                offset_x=offset_x,
205            ) for deviate_label in deviate_labels
206        ]
207
208        # Crop image and score map.
209        page_image = cropper.crop_image(page_image)
210        page_char_mask = cropper.crop_mask(
211            page_char_mask,
212            core_only=True,
213        )
214        page_char_height_score_map = cropper.crop_score_map(
215            page_char_height_score_map,
216            core_only=True,
217        )
218        page_char_gaussian_score_map = cropper.crop_score_map(
219            page_char_gaussian_score_map,
220            core_only=True,
221        )
222
223        downsampled_label: Optional[DownsampledLabel] = None
224        if self.config.enable_downsample_labeling:
225            downsample_labeling_factor = self.config.downsample_labeling_factor
226
227            assert cropper.crop_size % downsample_labeling_factor == 0
228            downsampled_size = cropper.crop_size // downsample_labeling_factor
229            downsampled_shape = (downsampled_size, downsampled_size)
230
231            assert self.config.pad_size % downsample_labeling_factor == 0
232            assert self.config.core_size % downsample_labeling_factor == 0
233            assert cropper.core_box.height == cropper.core_box.width == self.config.core_size
234
235            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
236            downsampled_core_size = self.config.core_size // downsample_labeling_factor
237
238            downsampled_core_begin = downsampled_pad_size
239            downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
240            downsampled_core_box = Box(
241                up=downsampled_core_begin,
242                down=downsampled_core_end,
243                left=downsampled_core_begin,
244                right=downsampled_core_end,
245            )
246
247            downsampled_page_char_mask = page_char_mask.to_box_detached()
248            downsampled_page_char_mask = \
249                downsampled_page_char_mask.to_resized_mask(
250                    resized_height=downsampled_core_size,
251                    resized_width=downsampled_core_size,
252                    cv_resize_interpolation=cv.INTER_AREA,
253                )
254
255            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
256            downsampled_page_char_height_score_map = \
257                downsampled_page_char_height_score_map.to_resized_score_map(
258                    resized_height=downsampled_core_size,
259                    resized_width=downsampled_core_size,
260                    cv_resize_interpolation=cv.INTER_AREA,
261                )
262
263            downsampled_page_char_gaussian_score_map = \
264                page_char_gaussian_score_map.to_box_detached()
265            downsampled_page_char_gaussian_score_map = \
266                downsampled_page_char_gaussian_score_map.to_resized_score_map(
267                    resized_height=downsampled_core_size,
268                    resized_width=downsampled_core_size,
269                    cv_resize_interpolation=cv.INTER_AREA,
270                )
271
272            downsampled_page_char_regression_labels = [
273                label.to_downsampled_page_char_regression_label(
274                    self.config.downsample_labeling_factor
275                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
276            ]
277
278            downsampled_label = DownsampledLabel(
279                shape=downsampled_shape,
280                page_char_mask=downsampled_page_char_mask,
281                page_char_height_score_map=downsampled_page_char_height_score_map,
282                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
283                page_char_regression_labels=downsampled_page_char_regression_labels,
284                core_box=downsampled_core_box,
285            )
286
287        return CroppedPageTextRegion(
288            page_image=page_image,
289            page_char_mask=page_char_mask,
290            page_char_height_score_map=page_char_height_score_map,
291            page_char_gaussian_score_map=page_char_gaussian_score_map,
292            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
293            core_box=cropper.core_box,
294            downsampled_label=downsampled_label,
295        )
def run( self, input: vkit.pipeline.text_detection.page_text_region_cropping.PageTextRegionCroppingStepInput, rng: numpy.random._generator.Generator):
297    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
298        page_cropping_step_output = input.page_cropping_step_output
299        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
300
301        page_text_region_step_output = input.page_text_region_step_output
302        page_image = page_text_region_step_output.page_image
303        shape_before_rotate = page_text_region_step_output.shape_before_rotate
304        rotate_angle = page_text_region_step_output.rotate_angle
305
306        page_text_region_label_step_output = input.page_text_region_label_step_output
307        page_char_mask = page_text_region_label_step_output.page_char_mask
308        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
309        page_char_gaussian_score_map = \
310            page_text_region_label_step_output.page_char_gaussian_score_map
311        page_char_regression_labels = \
312            page_text_region_label_step_output.page_char_regression_labels
313
314        (
315            centroid_strtree,
316            centroid_xy_pair_to_labels,
317        ) = self.build_strtree_for_page_char_regression_labels([
318            label for label in page_char_regression_labels
319            if label.tag == PageCharRegressionLabelTag.CENTROID
320        ])
321        (
322            deviate_strtree,
323            deviate_xy_pair_to_labels,
324        ) = self.build_strtree_for_page_char_regression_labels([
325            label for label in page_char_regression_labels
326            if label.tag == PageCharRegressionLabelTag.DEVIATE
327        ])
328
329        num_samples = round(
330            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
331        )
332
333        run_count_max = max(3, 2 * num_samples)
334        run_count = 0
335
336        cropped_page_text_regions: List[CroppedPageTextRegion] = []
337
338        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
339            cropped_page_text_region = self.sample_cropped_page_text_regions(
340                page_image=page_image,
341                shape_before_rotate=shape_before_rotate,
342                rotate_angle=rotate_angle,
343                page_char_mask=page_char_mask,
344                page_char_height_score_map=page_char_height_score_map,
345                page_char_gaussian_score_map=page_char_gaussian_score_map,
346                centroid_strtree=centroid_strtree,
347                centroid_xy_pair_to_labels=centroid_xy_pair_to_labels,
348                deviate_strtree=deviate_strtree,
349                deviate_xy_pair_to_labels=deviate_xy_pair_to_labels,
350                rng=rng,
351            )
352            if cropped_page_text_region:
353                cropped_page_text_regions.append(cropped_page_text_region)
354            run_count += 1
355
356        return PageTextRegionCroppingStepOutput(
357            cropped_page_text_regions=cropped_page_text_regions,
358        )