vkit.pipeline.text_detection.page_text_region_cropping

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Sequence, Tuple, List, Optional
 15import itertools
 16
 17import attrs
 18from numpy.random import Generator as RandomGenerator
 19from shapely.strtree import STRtree
 20from shapely.geometry import Point as ShapelyPoint
 21import cv2 as cv
 22
 23from vkit.element import Box, Mask, ScoreMap, Image
 24from vkit.mechanism.distortion import rotate
 25from vkit.mechanism.cropper import Cropper
 26from ..interface import PipelineStep, PipelineStepFactory
 27from .page_cropping import PageCroppingStepOutput
 28from .page_text_region import PageTextRegionStepOutput
 29from .page_text_region_label import (
 30    PageCharRegressionLabelTag,
 31    PageCharRegressionLabel,
 32    PageTextRegionLabelStepOutput,
 33)
 34
 35
 36@attrs.define
 37class PageTextRegionCroppingStepConfig:
 38    core_size: int
 39    pad_size: int
 40    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
 41    num_centroid_points_min: int = 10
 42    num_deviate_points_min: int = 10
 43    pad_value: int = 0
 44    enable_downsample_labeling: bool = True
 45    downsample_labeling_factor: int = 2
 46
 47
 48@attrs.define
 49class PageTextRegionCroppingStepInput:
 50    page_cropping_step_output: PageCroppingStepOutput
 51    page_text_region_step_output: PageTextRegionStepOutput
 52    page_text_region_label_step_output: PageTextRegionLabelStepOutput
 53
 54
 55@attrs.define
 56class DownsampledLabel:
 57    shape: Tuple[int, int]
 58    page_char_mask: Mask
 59    page_char_height_score_map: ScoreMap
 60    page_char_gaussian_score_map: ScoreMap
 61    page_char_regression_labels: Sequence[PageCharRegressionLabel]
 62    page_char_bounding_box_mask: Mask
 63    target_core_box: Box
 64
 65
 66@attrs.define
 67class CroppedPageTextRegion:
 68    page_image: Image
 69    page_char_mask: Mask
 70    page_char_height_score_map: ScoreMap
 71    page_char_gaussian_score_map: ScoreMap
 72    page_char_regression_labels: Sequence[PageCharRegressionLabel]
 73    page_char_bounding_box_mask: Mask
 74    target_core_box: Box
 75    downsampled_label: Optional[DownsampledLabel]
 76
 77
 78@attrs.define
 79class PageTextRegionCroppingStepOutput:
 80    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
 81
 82
 83class PageTextRegionCroppingStep(
 84    PipelineStep[
 85        PageTextRegionCroppingStepConfig,
 86        PageTextRegionCroppingStepInput,
 87        PageTextRegionCroppingStepOutput,
 88    ]
 89):  # yapf: disable
 90
 91    @classmethod
 92    def build_strtree_for_page_char_regression_labels(
 93        cls,
 94        labels: Sequence[PageCharRegressionLabel],
 95    ):
 96        shapely_points: List[ShapelyPoint] = []
 97
 98        for label in labels:
 99            # Original resolution.
100            assert not label.is_downsampled
101            # As int.
102            xy_pair = (label.downsampled_label_point_x, label.downsampled_label_point_y)
103            shapely_points.append(ShapelyPoint(*xy_pair))
104
105        strtree = STRtree(shapely_points)
106        return strtree
107
108    def sample_cropped_page_text_regions(
109        self,
110        page_image: Image,
111        shape_before_rotate: Tuple[int, int],
112        rotate_angle: int,
113        page_char_mask: Mask,
114        page_char_height_score_map: ScoreMap,
115        page_char_gaussian_score_map: ScoreMap,
116        page_char_bounding_box_mask: Mask,
117        centroid_strtree: STRtree,
118        centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
119        deviate_strtree: STRtree,
120        deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
121        rng: RandomGenerator,
122    ):
123        if rotate_angle != 0:
124            cropper_before_rotate = Cropper.create_from_random_proposal(
125                shape=shape_before_rotate,
126                core_size=self.config.core_size,
127                pad_size=self.config.pad_size,
128                pad_value=self.config.pad_value,
129                rng=rng,
130            )
131            original_box_before_rotate = cropper_before_rotate.cropper_state.original_box
132            center_point_before_rotate = original_box_before_rotate.get_center_point()
133
134            rotated_result = rotate.distort(
135                {'angle': rotate_angle},
136                shapable_or_shape=shape_before_rotate,
137                point=center_point_before_rotate,
138            )
139            assert rotated_result.shape == page_image.shape
140            center_point = rotated_result.point
141            assert center_point
142
143            cropper = Cropper.create_from_center_point(
144                shape=page_image.shape,
145                core_size=self.config.core_size,
146                pad_size=self.config.pad_size,
147                pad_value=self.config.pad_value,
148                center_point=center_point,
149            )
150
151        else:
152            cropper = Cropper.create_from_random_proposal(
153                shape=page_image.shape,
154                core_size=self.config.core_size,
155                pad_size=self.config.pad_size,
156                pad_value=self.config.pad_value,
157                rng=rng,
158            )
159
160        # Remove labels out of the original core box.
161        original_core_shapely_polygon = cropper.original_core_box.to_shapely_polygon()
162
163        centroid_labels: List[PageCharRegressionLabel] = []
164        for centroid_page_char_regression_label_idx in sorted(
165            centroid_strtree.query(
166                original_core_shapely_polygon,
167                predicate='intersects',
168            )
169        ):
170            centroid_label = \
171                centroid_page_char_regression_labels[centroid_page_char_regression_label_idx]
172            centroid_labels.append(centroid_label)
173
174        preserved_char_indices = set(centroid_label.char_idx for centroid_label in centroid_labels)
175        deviate_labels: List[PageCharRegressionLabel] = []
176        for deviate_page_char_regression_label_idx in sorted(
177            deviate_strtree.query(
178                original_core_shapely_polygon,
179                predicate='intersects',
180            )
181        ):
182            deviate_label = \
183                deviate_page_char_regression_labels[deviate_page_char_regression_label_idx]
184            if deviate_label.char_idx not in preserved_char_indices:
185                # If the centroid is not preserved, ignore this deviate label as well.
186                continue
187            deviate_labels.append(deviate_label)
188
189        if len(centroid_labels) < self.config.num_centroid_points_min \
190                or len(deviate_labels) < self.config.num_deviate_points_min:
191            return None
192
193        # Shift labels.
194        offset_y = cropper.target_box.up - cropper.original_box.up
195        offset_x = cropper.target_box.left - cropper.original_box.left
196        shifted_centroid_labels = [
197            centroid_label.to_shifted_page_char_regression_label(
198                offset_y=offset_y,
199                offset_x=offset_x,
200            ) for centroid_label in centroid_labels
201        ]
202        shifted_deviate_labels = [
203            deviate_label.to_shifted_page_char_regression_label(
204                offset_y=offset_y,
205                offset_x=offset_x,
206            ) for deviate_label in deviate_labels
207        ]
208
209        # Crop image and labeling maps.
210        page_image = cropper.crop_image(page_image)
211        page_char_mask = cropper.crop_mask(
212            page_char_mask,
213            core_only=True,
214        )
215        page_char_height_score_map = cropper.crop_score_map(
216            page_char_height_score_map,
217            core_only=True,
218        )
219        page_char_gaussian_score_map = cropper.crop_score_map(
220            page_char_gaussian_score_map,
221            core_only=True,
222        )
223        page_char_bounding_box_mask = cropper.crop_mask(
224            page_char_bounding_box_mask,
225            core_only=True,
226        )
227
228        downsampled_label: Optional[DownsampledLabel] = None
229        if self.config.enable_downsample_labeling:
230            downsample_labeling_factor = self.config.downsample_labeling_factor
231
232            assert cropper.crop_size % downsample_labeling_factor == 0
233            downsampled_size = cropper.crop_size // downsample_labeling_factor
234            downsampled_shape = (downsampled_size, downsampled_size)
235
236            assert self.config.pad_size % downsample_labeling_factor == 0
237            assert self.config.core_size % downsample_labeling_factor == 0
238            assert cropper.target_core_box.height \
239                == cropper.target_core_box.width \
240                == self.config.core_size
241
242            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
243            downsampled_core_size = self.config.core_size // downsample_labeling_factor
244
245            downsampled_target_core_begin = downsampled_pad_size
246            downsampled_target_core_end = downsampled_target_core_begin + downsampled_core_size - 1
247            downsampled_target_core_box = Box(
248                up=downsampled_target_core_begin,
249                down=downsampled_target_core_end,
250                left=downsampled_target_core_begin,
251                right=downsampled_target_core_end,
252            )
253
254            downsampled_page_char_mask = page_char_mask.to_box_detached()
255            downsampled_page_char_mask = \
256                downsampled_page_char_mask.to_resized_mask(
257                    resized_height=downsampled_core_size,
258                    resized_width=downsampled_core_size,
259                    cv_resize_interpolation=cv.INTER_AREA,
260                )
261
262            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
263            downsampled_page_char_height_score_map = \
264                downsampled_page_char_height_score_map.to_resized_score_map(
265                    resized_height=downsampled_core_size,
266                    resized_width=downsampled_core_size,
267                    cv_resize_interpolation=cv.INTER_AREA,
268                )
269
270            downsampled_page_char_gaussian_score_map = \
271                page_char_gaussian_score_map.to_box_detached()
272            downsampled_page_char_gaussian_score_map = \
273                downsampled_page_char_gaussian_score_map.to_resized_score_map(
274                    resized_height=downsampled_core_size,
275                    resized_width=downsampled_core_size,
276                    cv_resize_interpolation=cv.INTER_AREA,
277                )
278
279            downsampled_page_char_bounding_box_mask = \
280                page_char_bounding_box_mask.to_box_detached()
281            downsampled_page_char_bounding_box_mask = \
282                downsampled_page_char_bounding_box_mask.to_resized_mask(
283                    resized_height=downsampled_core_size,
284                    resized_width=downsampled_core_size,
285                    cv_resize_interpolation=cv.INTER_AREA,
286                )
287
288            downsampled_page_char_regression_labels = [
289                label.to_downsampled_page_char_regression_label(
290                    self.config.downsample_labeling_factor
291                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
292            ]
293
294            downsampled_label = DownsampledLabel(
295                shape=downsampled_shape,
296                page_char_mask=downsampled_page_char_mask,
297                page_char_height_score_map=downsampled_page_char_height_score_map,
298                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
299                page_char_regression_labels=downsampled_page_char_regression_labels,
300                page_char_bounding_box_mask=downsampled_page_char_bounding_box_mask,
301                target_core_box=downsampled_target_core_box,
302            )
303
304        return CroppedPageTextRegion(
305            page_image=page_image,
306            page_char_mask=page_char_mask,
307            page_char_height_score_map=page_char_height_score_map,
308            page_char_gaussian_score_map=page_char_gaussian_score_map,
309            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
310            page_char_bounding_box_mask=page_char_bounding_box_mask,
311            target_core_box=cropper.target_core_box,
312            downsampled_label=downsampled_label,
313        )
314
315    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
316        page_cropping_step_output = input.page_cropping_step_output
317        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
318
319        page_text_region_step_output = input.page_text_region_step_output
320        page_image = page_text_region_step_output.page_image
321        shape_before_rotate = page_text_region_step_output.shape_before_rotate
322        rotate_angle = page_text_region_step_output.rotate_angle
323
324        page_text_region_label_step_output = input.page_text_region_label_step_output
325        page_char_mask = page_text_region_label_step_output.page_char_mask
326        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
327        page_char_gaussian_score_map = \
328            page_text_region_label_step_output.page_char_gaussian_score_map
329        page_char_regression_labels = \
330            page_text_region_label_step_output.page_char_regression_labels
331        page_char_bounding_box_mask = \
332            page_text_region_label_step_output.page_char_bounding_box_mask
333
334        centroid_page_char_regression_labels = [
335            label for label in page_char_regression_labels
336            if label.tag == PageCharRegressionLabelTag.CENTROID
337        ]
338        centroid_strtree = self.build_strtree_for_page_char_regression_labels(
339            centroid_page_char_regression_labels
340        )
341
342        deviate_page_char_regression_labels = [
343            label for label in page_char_regression_labels
344            if label.tag == PageCharRegressionLabelTag.DEVIATE
345        ]
346        deviate_strtree = self.build_strtree_for_page_char_regression_labels(
347            deviate_page_char_regression_labels
348        )
349
350        num_samples = round(
351            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
352        )
353
354        run_count_max = max(3, 2 * num_samples)
355        run_count = 0
356
357        cropped_page_text_regions: List[CroppedPageTextRegion] = []
358
359        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
360            cropped_page_text_region = self.sample_cropped_page_text_regions(
361                page_image=page_image,
362                shape_before_rotate=shape_before_rotate,
363                rotate_angle=rotate_angle,
364                page_char_mask=page_char_mask,
365                page_char_height_score_map=page_char_height_score_map,
366                page_char_gaussian_score_map=page_char_gaussian_score_map,
367                page_char_bounding_box_mask=page_char_bounding_box_mask,
368                centroid_strtree=centroid_strtree,
369                centroid_page_char_regression_labels=centroid_page_char_regression_labels,
370                deviate_strtree=deviate_strtree,
371                deviate_page_char_regression_labels=deviate_page_char_regression_labels,
372                rng=rng,
373            )
374            if cropped_page_text_region:
375                cropped_page_text_regions.append(cropped_page_text_region)
376            run_count += 1
377
378        return PageTextRegionCroppingStepOutput(
379            cropped_page_text_regions=cropped_page_text_regions,
380        )
381
382
# Module-level factory for building PageTextRegionCroppingStep instances.
page_text_region_cropping_step_factory = PipelineStepFactory(
    PageTextRegionCroppingStep,
)
class PageTextRegionCroppingStepConfig:
38class PageTextRegionCroppingStepConfig:
39    core_size: int
40    pad_size: int
41    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
42    num_centroid_points_min: int = 10
43    num_deviate_points_min: int = 10
44    pad_value: int = 0
45    enable_downsample_labeling: bool = True
46    downsample_labeling_factor: int = 2
PageTextRegionCroppingStepConfig( core_size: int, pad_size: int, num_samples_factor_relative_to_num_cropped_pages: float = 1.0, num_centroid_points_min: int = 10, num_deviate_points_min: int = 10, pad_value: int = 0, enable_downsample_labeling: bool = True, downsample_labeling_factor: int = 2)
 2def __init__(self, core_size, pad_size, num_samples_factor_relative_to_num_cropped_pages=attr_dict['num_samples_factor_relative_to_num_cropped_pages'].default, num_centroid_points_min=attr_dict['num_centroid_points_min'].default, num_deviate_points_min=attr_dict['num_deviate_points_min'].default, pad_value=attr_dict['pad_value'].default, enable_downsample_labeling=attr_dict['enable_downsample_labeling'].default, downsample_labeling_factor=attr_dict['downsample_labeling_factor'].default):
 3    self.core_size = core_size
 4    self.pad_size = pad_size
 5    self.num_samples_factor_relative_to_num_cropped_pages = num_samples_factor_relative_to_num_cropped_pages
 6    self.num_centroid_points_min = num_centroid_points_min
 7    self.num_deviate_points_min = num_deviate_points_min
 8    self.pad_value = pad_value
 9    self.enable_downsample_labeling = enable_downsample_labeling
10    self.downsample_labeling_factor = downsample_labeling_factor

Method generated by attrs for class PageTextRegionCroppingStepConfig.

class PageTextRegionCroppingStepInput:
50class PageTextRegionCroppingStepInput:
51    page_cropping_step_output: PageCroppingStepOutput
52    page_text_region_step_output: PageTextRegionStepOutput
53    page_text_region_label_step_output: PageTextRegionLabelStepOutput
PageTextRegionCroppingStepInput( page_cropping_step_output: vkit.pipeline.text_detection.page_cropping.PageCroppingStepOutput, page_text_region_step_output: vkit.pipeline.text_detection.page_text_region.PageTextRegionStepOutput, page_text_region_label_step_output: vkit.pipeline.text_detection.page_text_region_label.PageTextRegionLabelStepOutput)
2def __init__(self, page_cropping_step_output, page_text_region_step_output, page_text_region_label_step_output):
3    self.page_cropping_step_output = page_cropping_step_output
4    self.page_text_region_step_output = page_text_region_step_output
5    self.page_text_region_label_step_output = page_text_region_label_step_output

Method generated by attrs for class PageTextRegionCroppingStepInput.

class DownsampledLabel:
57class DownsampledLabel:
58    shape: Tuple[int, int]
59    page_char_mask: Mask
60    page_char_height_score_map: ScoreMap
61    page_char_gaussian_score_map: ScoreMap
62    page_char_regression_labels: Sequence[PageCharRegressionLabel]
63    page_char_bounding_box_mask: Mask
64    target_core_box: Box
DownsampledLabel( shape: Tuple[int, int], page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], page_char_bounding_box_mask: vkit.element.mask.Mask, target_core_box: vkit.element.box.Box)
2def __init__(self, shape, page_char_mask, page_char_height_score_map, page_char_gaussian_score_map, page_char_regression_labels, page_char_bounding_box_mask, target_core_box):
3    self.shape = shape
4    self.page_char_mask = page_char_mask
5    self.page_char_height_score_map = page_char_height_score_map
6    self.page_char_gaussian_score_map = page_char_gaussian_score_map
7    self.page_char_regression_labels = page_char_regression_labels
8    self.page_char_bounding_box_mask = page_char_bounding_box_mask
9    self.target_core_box = target_core_box

Method generated by attrs for class DownsampledLabel.

class CroppedPageTextRegion:
68class CroppedPageTextRegion:
69    page_image: Image
70    page_char_mask: Mask
71    page_char_height_score_map: ScoreMap
72    page_char_gaussian_score_map: ScoreMap
73    page_char_regression_labels: Sequence[PageCharRegressionLabel]
74    page_char_bounding_box_mask: Mask
75    target_core_box: Box
76    downsampled_label: Optional[DownsampledLabel]
CroppedPageTextRegion( page_image: vkit.element.image.Image, page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], page_char_bounding_box_mask: vkit.element.mask.Mask, target_core_box: vkit.element.box.Box, downsampled_label: Union[vkit.pipeline.text_detection.page_text_region_cropping.DownsampledLabel, NoneType])
 2def __init__(self, page_image, page_char_mask, page_char_height_score_map, page_char_gaussian_score_map, page_char_regression_labels, page_char_bounding_box_mask, target_core_box, downsampled_label):
 3    self.page_image = page_image
 4    self.page_char_mask = page_char_mask
 5    self.page_char_height_score_map = page_char_height_score_map
 6    self.page_char_gaussian_score_map = page_char_gaussian_score_map
 7    self.page_char_regression_labels = page_char_regression_labels
 8    self.page_char_bounding_box_mask = page_char_bounding_box_mask
 9    self.target_core_box = target_core_box
10    self.downsampled_label = downsampled_label

Method generated by attrs for class CroppedPageTextRegion.

class PageTextRegionCroppingStepOutput:
80class PageTextRegionCroppingStepOutput:
81    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
PageTextRegionCroppingStepOutput( cropped_page_text_regions: Sequence[vkit.pipeline.text_detection.page_text_region_cropping.CroppedPageTextRegion])
2def __init__(self, cropped_page_text_regions):
3    self.cropped_page_text_regions = cropped_page_text_regions

Method generated by attrs for class PageTextRegionCroppingStepOutput.

 84class PageTextRegionCroppingStep(
 85    PipelineStep[
 86        PageTextRegionCroppingStepConfig,
 87        PageTextRegionCroppingStepInput,
 88        PageTextRegionCroppingStepOutput,
 89    ]
 90):  # yapf: disable
 91
 92    @classmethod
 93    def build_strtree_for_page_char_regression_labels(
 94        cls,
 95        labels: Sequence[PageCharRegressionLabel],
 96    ):
 97        shapely_points: List[ShapelyPoint] = []
 98
 99        for label in labels:
100            # Original resolution.
101            assert not label.is_downsampled
102            # As int.
103            xy_pair = (label.downsampled_label_point_x, label.downsampled_label_point_y)
104            shapely_points.append(ShapelyPoint(*xy_pair))
105
106        strtree = STRtree(shapely_points)
107        return strtree
108
109    def sample_cropped_page_text_regions(
110        self,
111        page_image: Image,
112        shape_before_rotate: Tuple[int, int],
113        rotate_angle: int,
114        page_char_mask: Mask,
115        page_char_height_score_map: ScoreMap,
116        page_char_gaussian_score_map: ScoreMap,
117        page_char_bounding_box_mask: Mask,
118        centroid_strtree: STRtree,
119        centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
120        deviate_strtree: STRtree,
121        deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
122        rng: RandomGenerator,
123    ):
124        if rotate_angle != 0:
125            cropper_before_rotate = Cropper.create_from_random_proposal(
126                shape=shape_before_rotate,
127                core_size=self.config.core_size,
128                pad_size=self.config.pad_size,
129                pad_value=self.config.pad_value,
130                rng=rng,
131            )
132            original_box_before_rotate = cropper_before_rotate.cropper_state.original_box
133            center_point_before_rotate = original_box_before_rotate.get_center_point()
134
135            rotated_result = rotate.distort(
136                {'angle': rotate_angle},
137                shapable_or_shape=shape_before_rotate,
138                point=center_point_before_rotate,
139            )
140            assert rotated_result.shape == page_image.shape
141            center_point = rotated_result.point
142            assert center_point
143
144            cropper = Cropper.create_from_center_point(
145                shape=page_image.shape,
146                core_size=self.config.core_size,
147                pad_size=self.config.pad_size,
148                pad_value=self.config.pad_value,
149                center_point=center_point,
150            )
151
152        else:
153            cropper = Cropper.create_from_random_proposal(
154                shape=page_image.shape,
155                core_size=self.config.core_size,
156                pad_size=self.config.pad_size,
157                pad_value=self.config.pad_value,
158                rng=rng,
159            )
160
161        # Remove labels out of the original core box.
162        original_core_shapely_polygon = cropper.original_core_box.to_shapely_polygon()
163
164        centroid_labels: List[PageCharRegressionLabel] = []
165        for centroid_page_char_regression_label_idx in sorted(
166            centroid_strtree.query(
167                original_core_shapely_polygon,
168                predicate='intersects',
169            )
170        ):
171            centroid_label = \
172                centroid_page_char_regression_labels[centroid_page_char_regression_label_idx]
173            centroid_labels.append(centroid_label)
174
175        preserved_char_indices = set(centroid_label.char_idx for centroid_label in centroid_labels)
176        deviate_labels: List[PageCharRegressionLabel] = []
177        for deviate_page_char_regression_label_idx in sorted(
178            deviate_strtree.query(
179                original_core_shapely_polygon,
180                predicate='intersects',
181            )
182        ):
183            deviate_label = \
184                deviate_page_char_regression_labels[deviate_page_char_regression_label_idx]
185            if deviate_label.char_idx not in preserved_char_indices:
186                # If the centroid is not preserved, ignore this deviate label as well.
187                continue
188            deviate_labels.append(deviate_label)
189
190        if len(centroid_labels) < self.config.num_centroid_points_min \
191                or len(deviate_labels) < self.config.num_deviate_points_min:
192            return None
193
194        # Shift labels.
195        offset_y = cropper.target_box.up - cropper.original_box.up
196        offset_x = cropper.target_box.left - cropper.original_box.left
197        shifted_centroid_labels = [
198            centroid_label.to_shifted_page_char_regression_label(
199                offset_y=offset_y,
200                offset_x=offset_x,
201            ) for centroid_label in centroid_labels
202        ]
203        shifted_deviate_labels = [
204            deviate_label.to_shifted_page_char_regression_label(
205                offset_y=offset_y,
206                offset_x=offset_x,
207            ) for deviate_label in deviate_labels
208        ]
209
210        # Crop image and labeling maps.
211        page_image = cropper.crop_image(page_image)
212        page_char_mask = cropper.crop_mask(
213            page_char_mask,
214            core_only=True,
215        )
216        page_char_height_score_map = cropper.crop_score_map(
217            page_char_height_score_map,
218            core_only=True,
219        )
220        page_char_gaussian_score_map = cropper.crop_score_map(
221            page_char_gaussian_score_map,
222            core_only=True,
223        )
224        page_char_bounding_box_mask = cropper.crop_mask(
225            page_char_bounding_box_mask,
226            core_only=True,
227        )
228
229        downsampled_label: Optional[DownsampledLabel] = None
230        if self.config.enable_downsample_labeling:
231            downsample_labeling_factor = self.config.downsample_labeling_factor
232
233            assert cropper.crop_size % downsample_labeling_factor == 0
234            downsampled_size = cropper.crop_size // downsample_labeling_factor
235            downsampled_shape = (downsampled_size, downsampled_size)
236
237            assert self.config.pad_size % downsample_labeling_factor == 0
238            assert self.config.core_size % downsample_labeling_factor == 0
239            assert cropper.target_core_box.height \
240                == cropper.target_core_box.width \
241                == self.config.core_size
242
243            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
244            downsampled_core_size = self.config.core_size // downsample_labeling_factor
245
246            downsampled_target_core_begin = downsampled_pad_size
247            downsampled_target_core_end = downsampled_target_core_begin + downsampled_core_size - 1
248            downsampled_target_core_box = Box(
249                up=downsampled_target_core_begin,
250                down=downsampled_target_core_end,
251                left=downsampled_target_core_begin,
252                right=downsampled_target_core_end,
253            )
254
255            downsampled_page_char_mask = page_char_mask.to_box_detached()
256            downsampled_page_char_mask = \
257                downsampled_page_char_mask.to_resized_mask(
258                    resized_height=downsampled_core_size,
259                    resized_width=downsampled_core_size,
260                    cv_resize_interpolation=cv.INTER_AREA,
261                )
262
263            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
264            downsampled_page_char_height_score_map = \
265                downsampled_page_char_height_score_map.to_resized_score_map(
266                    resized_height=downsampled_core_size,
267                    resized_width=downsampled_core_size,
268                    cv_resize_interpolation=cv.INTER_AREA,
269                )
270
271            downsampled_page_char_gaussian_score_map = \
272                page_char_gaussian_score_map.to_box_detached()
273            downsampled_page_char_gaussian_score_map = \
274                downsampled_page_char_gaussian_score_map.to_resized_score_map(
275                    resized_height=downsampled_core_size,
276                    resized_width=downsampled_core_size,
277                    cv_resize_interpolation=cv.INTER_AREA,
278                )
279
280            downsampled_page_char_bounding_box_mask = \
281                page_char_bounding_box_mask.to_box_detached()
282            downsampled_page_char_bounding_box_mask = \
283                downsampled_page_char_bounding_box_mask.to_resized_mask(
284                    resized_height=downsampled_core_size,
285                    resized_width=downsampled_core_size,
286                    cv_resize_interpolation=cv.INTER_AREA,
287                )
288
289            downsampled_page_char_regression_labels = [
290                label.to_downsampled_page_char_regression_label(
291                    self.config.downsample_labeling_factor
292                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
293            ]
294
295            downsampled_label = DownsampledLabel(
296                shape=downsampled_shape,
297                page_char_mask=downsampled_page_char_mask,
298                page_char_height_score_map=downsampled_page_char_height_score_map,
299                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
300                page_char_regression_labels=downsampled_page_char_regression_labels,
301                page_char_bounding_box_mask=downsampled_page_char_bounding_box_mask,
302                target_core_box=downsampled_target_core_box,
303            )
304
305        return CroppedPageTextRegion(
306            page_image=page_image,
307            page_char_mask=page_char_mask,
308            page_char_height_score_map=page_char_height_score_map,
309            page_char_gaussian_score_map=page_char_gaussian_score_map,
310            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
311            page_char_bounding_box_mask=page_char_bounding_box_mask,
312            target_core_box=cropper.target_core_box,
313            downsampled_label=downsampled_label,
314        )
315
316    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
317        page_cropping_step_output = input.page_cropping_step_output
318        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
319
320        page_text_region_step_output = input.page_text_region_step_output
321        page_image = page_text_region_step_output.page_image
322        shape_before_rotate = page_text_region_step_output.shape_before_rotate
323        rotate_angle = page_text_region_step_output.rotate_angle
324
325        page_text_region_label_step_output = input.page_text_region_label_step_output
326        page_char_mask = page_text_region_label_step_output.page_char_mask
327        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
328        page_char_gaussian_score_map = \
329            page_text_region_label_step_output.page_char_gaussian_score_map
330        page_char_regression_labels = \
331            page_text_region_label_step_output.page_char_regression_labels
332        page_char_bounding_box_mask = \
333            page_text_region_label_step_output.page_char_bounding_box_mask
334
335        centroid_page_char_regression_labels = [
336            label for label in page_char_regression_labels
337            if label.tag == PageCharRegressionLabelTag.CENTROID
338        ]
339        centroid_strtree = self.build_strtree_for_page_char_regression_labels(
340            centroid_page_char_regression_labels
341        )
342
343        deviate_page_char_regression_labels = [
344            label for label in page_char_regression_labels
345            if label.tag == PageCharRegressionLabelTag.DEVIATE
346        ]
347        deviate_strtree = self.build_strtree_for_page_char_regression_labels(
348            deviate_page_char_regression_labels
349        )
350
351        num_samples = round(
352            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
353        )
354
355        run_count_max = max(3, 2 * num_samples)
356        run_count = 0
357
358        cropped_page_text_regions: List[CroppedPageTextRegion] = []
359
360        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
361            cropped_page_text_region = self.sample_cropped_page_text_regions(
362                page_image=page_image,
363                shape_before_rotate=shape_before_rotate,
364                rotate_angle=rotate_angle,
365                page_char_mask=page_char_mask,
366                page_char_height_score_map=page_char_height_score_map,
367                page_char_gaussian_score_map=page_char_gaussian_score_map,
368                page_char_bounding_box_mask=page_char_bounding_box_mask,
369                centroid_strtree=centroid_strtree,
370                centroid_page_char_regression_labels=centroid_page_char_regression_labels,
371                deviate_strtree=deviate_strtree,
372                deviate_page_char_regression_labels=deviate_page_char_regression_labels,
373                rng=rng,
374            )
375            if cropped_page_text_region:
376                cropped_page_text_regions.append(cropped_page_text_region)
377            run_count += 1
378
379        return PageTextRegionCroppingStepOutput(
380            cropped_page_text_regions=cropped_page_text_regions,
381        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

@classmethod
def build_strtree_for_page_char_regression_labels( cls, labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel]):
 92    @classmethod
 93    def build_strtree_for_page_char_regression_labels(
 94        cls,
 95        labels: Sequence[PageCharRegressionLabel],
 96    ):
 97        shapely_points: List[ShapelyPoint] = []
 98
 99        for label in labels:
100            # Original resolution.
101            assert not label.is_downsampled
102            # As int.
103            xy_pair = (label.downsampled_label_point_x, label.downsampled_label_point_y)
104            shapely_points.append(ShapelyPoint(*xy_pair))
105
106        strtree = STRtree(shapely_points)
107        return strtree
def sample_cropped_page_text_regions( self, page_image: vkit.element.image.Image, shape_before_rotate: Tuple[int, int], rotate_angle: int, page_char_mask: vkit.element.mask.Mask, page_char_height_score_map: vkit.element.score_map.ScoreMap, page_char_gaussian_score_map: vkit.element.score_map.ScoreMap, page_char_bounding_box_mask: vkit.element.mask.Mask, centroid_strtree: shapely.strtree.STRtree, centroid_page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], deviate_strtree: shapely.strtree.STRtree, deviate_page_char_regression_labels: Sequence[vkit.pipeline.text_detection.page_text_region_label.PageCharRegressionLabel], rng: numpy.random._generator.Generator):
109    def sample_cropped_page_text_regions(
110        self,
111        page_image: Image,
112        shape_before_rotate: Tuple[int, int],
113        rotate_angle: int,
114        page_char_mask: Mask,
115        page_char_height_score_map: ScoreMap,
116        page_char_gaussian_score_map: ScoreMap,
117        page_char_bounding_box_mask: Mask,
118        centroid_strtree: STRtree,
119        centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
120        deviate_strtree: STRtree,
121        deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
122        rng: RandomGenerator,
123    ):
124        if rotate_angle != 0:
125            cropper_before_rotate = Cropper.create_from_random_proposal(
126                shape=shape_before_rotate,
127                core_size=self.config.core_size,
128                pad_size=self.config.pad_size,
129                pad_value=self.config.pad_value,
130                rng=rng,
131            )
132            original_box_before_rotate = cropper_before_rotate.cropper_state.original_box
133            center_point_before_rotate = original_box_before_rotate.get_center_point()
134
135            rotated_result = rotate.distort(
136                {'angle': rotate_angle},
137                shapable_or_shape=shape_before_rotate,
138                point=center_point_before_rotate,
139            )
140            assert rotated_result.shape == page_image.shape
141            center_point = rotated_result.point
142            assert center_point
143
144            cropper = Cropper.create_from_center_point(
145                shape=page_image.shape,
146                core_size=self.config.core_size,
147                pad_size=self.config.pad_size,
148                pad_value=self.config.pad_value,
149                center_point=center_point,
150            )
151
152        else:
153            cropper = Cropper.create_from_random_proposal(
154                shape=page_image.shape,
155                core_size=self.config.core_size,
156                pad_size=self.config.pad_size,
157                pad_value=self.config.pad_value,
158                rng=rng,
159            )
160
161        # Remove labels out of the original core box.
162        original_core_shapely_polygon = cropper.original_core_box.to_shapely_polygon()
163
164        centroid_labels: List[PageCharRegressionLabel] = []
165        for centroid_page_char_regression_label_idx in sorted(
166            centroid_strtree.query(
167                original_core_shapely_polygon,
168                predicate='intersects',
169            )
170        ):
171            centroid_label = \
172                centroid_page_char_regression_labels[centroid_page_char_regression_label_idx]
173            centroid_labels.append(centroid_label)
174
175        preserved_char_indices = set(centroid_label.char_idx for centroid_label in centroid_labels)
176        deviate_labels: List[PageCharRegressionLabel] = []
177        for deviate_page_char_regression_label_idx in sorted(
178            deviate_strtree.query(
179                original_core_shapely_polygon,
180                predicate='intersects',
181            )
182        ):
183            deviate_label = \
184                deviate_page_char_regression_labels[deviate_page_char_regression_label_idx]
185            if deviate_label.char_idx not in preserved_char_indices:
186                # If the centroid is not preserved, ignore this deviate label as well.
187                continue
188            deviate_labels.append(deviate_label)
189
190        if len(centroid_labels) < self.config.num_centroid_points_min \
191                or len(deviate_labels) < self.config.num_deviate_points_min:
192            return None
193
194        # Shift labels.
195        offset_y = cropper.target_box.up - cropper.original_box.up
196        offset_x = cropper.target_box.left - cropper.original_box.left
197        shifted_centroid_labels = [
198            centroid_label.to_shifted_page_char_regression_label(
199                offset_y=offset_y,
200                offset_x=offset_x,
201            ) for centroid_label in centroid_labels
202        ]
203        shifted_deviate_labels = [
204            deviate_label.to_shifted_page_char_regression_label(
205                offset_y=offset_y,
206                offset_x=offset_x,
207            ) for deviate_label in deviate_labels
208        ]
209
210        # Crop image and labeling maps.
211        page_image = cropper.crop_image(page_image)
212        page_char_mask = cropper.crop_mask(
213            page_char_mask,
214            core_only=True,
215        )
216        page_char_height_score_map = cropper.crop_score_map(
217            page_char_height_score_map,
218            core_only=True,
219        )
220        page_char_gaussian_score_map = cropper.crop_score_map(
221            page_char_gaussian_score_map,
222            core_only=True,
223        )
224        page_char_bounding_box_mask = cropper.crop_mask(
225            page_char_bounding_box_mask,
226            core_only=True,
227        )
228
229        downsampled_label: Optional[DownsampledLabel] = None
230        if self.config.enable_downsample_labeling:
231            downsample_labeling_factor = self.config.downsample_labeling_factor
232
233            assert cropper.crop_size % downsample_labeling_factor == 0
234            downsampled_size = cropper.crop_size // downsample_labeling_factor
235            downsampled_shape = (downsampled_size, downsampled_size)
236
237            assert self.config.pad_size % downsample_labeling_factor == 0
238            assert self.config.core_size % downsample_labeling_factor == 0
239            assert cropper.target_core_box.height \
240                == cropper.target_core_box.width \
241                == self.config.core_size
242
243            downsampled_pad_size = self.config.pad_size // downsample_labeling_factor
244            downsampled_core_size = self.config.core_size // downsample_labeling_factor
245
246            downsampled_target_core_begin = downsampled_pad_size
247            downsampled_target_core_end = downsampled_target_core_begin + downsampled_core_size - 1
248            downsampled_target_core_box = Box(
249                up=downsampled_target_core_begin,
250                down=downsampled_target_core_end,
251                left=downsampled_target_core_begin,
252                right=downsampled_target_core_end,
253            )
254
255            downsampled_page_char_mask = page_char_mask.to_box_detached()
256            downsampled_page_char_mask = \
257                downsampled_page_char_mask.to_resized_mask(
258                    resized_height=downsampled_core_size,
259                    resized_width=downsampled_core_size,
260                    cv_resize_interpolation=cv.INTER_AREA,
261                )
262
263            downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
264            downsampled_page_char_height_score_map = \
265                downsampled_page_char_height_score_map.to_resized_score_map(
266                    resized_height=downsampled_core_size,
267                    resized_width=downsampled_core_size,
268                    cv_resize_interpolation=cv.INTER_AREA,
269                )
270
271            downsampled_page_char_gaussian_score_map = \
272                page_char_gaussian_score_map.to_box_detached()
273            downsampled_page_char_gaussian_score_map = \
274                downsampled_page_char_gaussian_score_map.to_resized_score_map(
275                    resized_height=downsampled_core_size,
276                    resized_width=downsampled_core_size,
277                    cv_resize_interpolation=cv.INTER_AREA,
278                )
279
280            downsampled_page_char_bounding_box_mask = \
281                page_char_bounding_box_mask.to_box_detached()
282            downsampled_page_char_bounding_box_mask = \
283                downsampled_page_char_bounding_box_mask.to_resized_mask(
284                    resized_height=downsampled_core_size,
285                    resized_width=downsampled_core_size,
286                    cv_resize_interpolation=cv.INTER_AREA,
287                )
288
289            downsampled_page_char_regression_labels = [
290                label.to_downsampled_page_char_regression_label(
291                    self.config.downsample_labeling_factor
292                ) for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
293            ]
294
295            downsampled_label = DownsampledLabel(
296                shape=downsampled_shape,
297                page_char_mask=downsampled_page_char_mask,
298                page_char_height_score_map=downsampled_page_char_height_score_map,
299                page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
300                page_char_regression_labels=downsampled_page_char_regression_labels,
301                page_char_bounding_box_mask=downsampled_page_char_bounding_box_mask,
302                target_core_box=downsampled_target_core_box,
303            )
304
305        return CroppedPageTextRegion(
306            page_image=page_image,
307            page_char_mask=page_char_mask,
308            page_char_height_score_map=page_char_height_score_map,
309            page_char_gaussian_score_map=page_char_gaussian_score_map,
310            page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
311            page_char_bounding_box_mask=page_char_bounding_box_mask,
312            target_core_box=cropper.target_core_box,
313            downsampled_label=downsampled_label,
314        )
def run( self, input: vkit.pipeline.text_detection.page_text_region_cropping.PageTextRegionCroppingStepInput, rng: numpy.random._generator.Generator):
316    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
317        page_cropping_step_output = input.page_cropping_step_output
318        num_cropped_pages = len(page_cropping_step_output.cropped_pages)
319
320        page_text_region_step_output = input.page_text_region_step_output
321        page_image = page_text_region_step_output.page_image
322        shape_before_rotate = page_text_region_step_output.shape_before_rotate
323        rotate_angle = page_text_region_step_output.rotate_angle
324
325        page_text_region_label_step_output = input.page_text_region_label_step_output
326        page_char_mask = page_text_region_label_step_output.page_char_mask
327        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
328        page_char_gaussian_score_map = \
329            page_text_region_label_step_output.page_char_gaussian_score_map
330        page_char_regression_labels = \
331            page_text_region_label_step_output.page_char_regression_labels
332        page_char_bounding_box_mask = \
333            page_text_region_label_step_output.page_char_bounding_box_mask
334
335        centroid_page_char_regression_labels = [
336            label for label in page_char_regression_labels
337            if label.tag == PageCharRegressionLabelTag.CENTROID
338        ]
339        centroid_strtree = self.build_strtree_for_page_char_regression_labels(
340            centroid_page_char_regression_labels
341        )
342
343        deviate_page_char_regression_labels = [
344            label for label in page_char_regression_labels
345            if label.tag == PageCharRegressionLabelTag.DEVIATE
346        ]
347        deviate_strtree = self.build_strtree_for_page_char_regression_labels(
348            deviate_page_char_regression_labels
349        )
350
351        num_samples = round(
352            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
353        )
354
355        run_count_max = max(3, 2 * num_samples)
356        run_count = 0
357
358        cropped_page_text_regions: List[CroppedPageTextRegion] = []
359
360        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
361            cropped_page_text_region = self.sample_cropped_page_text_regions(
362                page_image=page_image,
363                shape_before_rotate=shape_before_rotate,
364                rotate_angle=rotate_angle,
365                page_char_mask=page_char_mask,
366                page_char_height_score_map=page_char_height_score_map,
367                page_char_gaussian_score_map=page_char_gaussian_score_map,
368                page_char_bounding_box_mask=page_char_bounding_box_mask,
369                centroid_strtree=centroid_strtree,
370                centroid_page_char_regression_labels=centroid_page_char_regression_labels,
371                deviate_strtree=deviate_strtree,
372                deviate_page_char_regression_labels=deviate_page_char_regression_labels,
373                rng=rng,
374            )
375            if cropped_page_text_region:
376                cropped_page_text_regions.append(cropped_page_text_region)
377            run_count += 1
378
379        return PageTextRegionCroppingStepOutput(
380            cropped_page_text_regions=cropped_page_text_regions,
381        )