vkit.pipeline.text_detection.page_distortion

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Optional, Union, Mapping, Any, List, Tuple, Sequence, TypeVar, Generic
 15import itertools
 16
 17import attrs
 18from numpy.random import Generator as RandomGenerator
 19import numpy as np
 20
 21from vkit.utility import PathType
 22from vkit.element import (
 23    Point,
 24    PointList,
 25    Polygon,
 26    Mask,
 27    ScoreMap,
 28    Image,
 29)
 30from vkit.mechanism.distortion_policy import (
 31    random_distortion_factory,
 32    RandomDistortionDebug,
 33)
 34from vkit.mechanism.painter import Painter
 35from ..interface import PipelineStep, PipelineStepFactory
 36from .page_layout import DisconnectedTextRegion, NonTextRegion
 37from .page_text_line_label import (
 38    PageCharPolygonCollection,
 39    PageTextLinePolygonCollection,
 40)
 41from .page_assembler import (
 42    PageAssemblerStepOutput,
 43    PageDisconnectedTextRegionCollection,
 44    PageNonTextRegionCollection,
 45)
 46
 47
 48@attrs.define
 49class PageDistortionStepConfig:
 50    random_distortion_factory_config: Optional[Union[Mapping[str, Any], PathType]] = attrs.field(
 51        factory=lambda: {
 52            # NOTE: defocus blur and zoom in blur introduce labeling noise.
 53            # TODO: enhance those blurring methods for page.
 54            'disabled_policy_names': [
 55                'defocus_blur',
 56                'zoom_in_blur',
 57            ],
 58        }
 59    )
 60    enable_debug_random_distortion: bool = False
 61    enable_distorted_char_mask: bool = True
 62    enable_distorted_char_height_score_map: bool = True
 63    enable_debug_distorted_char_heights: bool = False
 64    enable_distorted_text_line_mask: bool = True
 65    enable_distorted_text_line_height_score_map: bool = True
 66    enable_debug_distorted_text_line_heights: bool = False
 67
 68
@attrs.define
class PageDistortionStepInput:
    """Input of ``PageDistortionStep.run``: the result of the page assembler step."""

    # Its ``page`` (image plus label collections) is what gets distorted.
    page_assembler_step_output: PageAssemblerStepOutput
 72
 73
@attrs.define
class PageDistortionStepOutput:
    """Output of ``PageDistortionStep.run``.

    All geometry is expressed in the frame of the distorted ``page_image``.
    ``Optional`` entries are ``None`` when the step's config disables them.
    """

    # Distorted page; its inactive (out-of-mask) area is filled from the page bottom layer.
    page_image: Image
    # Only populated when ``enable_debug_random_distortion`` is set.
    page_random_distortion_debug: Optional[RandomDistortionDebug]
    # Distorted active-region mask.
    page_active_mask: Mask
    # Char-level labels.
    page_char_polygon_collection: PageCharPolygonCollection
    page_char_mask: Optional[Mask]
    page_char_height_score_map: Optional[ScoreMap]
    page_char_heights: Optional[Sequence[float]]
    page_char_heights_debug_image: Optional[Image]
    # Text-line-level labels.
    page_text_line_polygon_collection: PageTextLinePolygonCollection
    page_text_line_mask: Optional[Mask]
    page_text_line_height_score_map: Optional[ScoreMap]
    page_text_line_heights: Optional[Sequence[float]]
    page_text_line_heights_debug_image: Optional[Image]
    # Region-level geometry, distorted alongside the page.
    page_disconnected_text_region_collection: PageDisconnectedTextRegionCollection
    page_non_text_region_collection: PageNonTextRegionCollection
 91
 92
# Element type handled by ElementFlattener: constrained to Point or Polygon.
_E = TypeVar('_E', Point, Polygon)
 94
 95
 96class ElementFlattener(Generic[_E]):
 97
 98    def __init__(self, grouped_elements: Sequence[Sequence[_E]]):
 99        self.grouped_elements = grouped_elements
100        self.group_sizes = [len(elements) for elements in grouped_elements]
101
102    def flatten(self):
103        return tuple(itertools.chain.from_iterable(self.grouped_elements))
104
105    def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
106        assert len(flattened_elements) == sum(self.group_sizes)
107        grouped_elements: List[Sequence[_E]] = []
108        begin = 0
109        for group_size in self.group_sizes:
110            end = begin + group_size
111            grouped_elements.append(flattened_elements[begin:end])
112            begin = end
113        return grouped_elements
114
115
class PageDistortionStep(
    PipelineStep[
        PageDistortionStepConfig,
        PageDistortionStepInput,
        PageDistortionStepOutput,
    ]
):  # yapf: disable
    """Randomly distort an assembled page together with all of its labels.

    A single call to the configured random distortion moves the page image, its
    active mask and every label geometry together, so the labels stay aligned
    with the distorted image. The geometry covers char polygons, text line
    polygons, disconnected text region polygons, non-text region polygons and
    the char/text line height points. Masks, height score maps and optional
    debug images are then rebuilt from the distorted geometry, according to
    ``PageDistortionStepConfig``.
    """

    def __init__(self, config: PageDistortionStepConfig):
        super().__init__(config)

        self.random_distortion = random_distortion_factory.create(
            self.config.random_distortion_factory_config
        )

    @classmethod
    def fill_page_inactive_region(
        cls,
        page_image: Image,
        page_active_mask: Mask,
        page_bottom_layer_image: Image,
    ):
        """Fill, in place, the pixels of ``page_image`` outside ``page_active_mask``.

        The replacement pixels come from ``page_bottom_layer_image``. That image
        is first resized to the page's height and width when its shape differs.
        """
        assert page_image.shape == page_active_mask.shape

        fill_source = page_bottom_layer_image
        if fill_source.shape != page_image.shape:
            fill_source = fill_source.to_resized_image(
                resized_height=page_image.height,
                resized_width=page_image.width,
            )

        page_active_mask.to_inverted_mask().fill_image(page_image, fill_source)

    @staticmethod
    def _compute_heights(
        height_points_up: PointList,
        height_points_down: PointList,
    ) -> np.ndarray:
        """Return, per up/down point pair, their euclidean distance plus one."""
        np_height_points_up = height_points_up.to_smooth_np_array()
        np_height_points_down = height_points_down.to_smooth_np_array()
        np_heights: np.ndarray = np.linalg.norm(
            np_height_points_down - np_height_points_up,
            axis=1,
        )
        # Add one to compensate (presumably for inclusive pixel counting).
        np_heights += 1
        return np_heights

    @staticmethod
    def _paint_heights_debug_image(
        distorted_image: Image,
        polygons: Sequence[Polygon],
        heights: Sequence[float],
    ) -> Image:
        """Paint ``polygons`` and write each height (1 decimal) at its center."""
        painter = Painter.create(distorted_image)
        painter.paint_polygons(polygons)

        texts: List[str] = []
        points = PointList()
        for polygon, height in zip(polygons, heights):
            texts.append(f'{height:.1f}')
            points.append(polygon.get_center_point())
        painter.paint_texts(texts, points, alpha=1.0)

        return painter.image

    def generate_text_line_labelings(
        self,
        distorted_image: Image,
        text_line_polygons: Sequence[Polygon],
        text_line_height_points_up: PointList,
        text_line_height_points_down: PointList,
        text_line_height_points_group_sizes: Sequence[int],
    ) -> Tuple[Optional[Mask], Optional[ScoreMap], Optional[List[float]], Optional[Image]]:
        """Build text-line-level labels from distorted geometry.

        Text line ``i`` owns the next ``text_line_height_points_group_sizes[i]``
        consecutive up/down height point pairs. Its height is the mean height of
        those pairs.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``. Each
            entry is ``None`` when the matching config flag disables it.
        """
        text_line_mask: Optional[Mask] = None
        if self.config.enable_distorted_text_line_mask:
            text_line_mask = Mask.from_shapable(distorted_image)
            for polygon in text_line_polygons:
                polygon.fill_mask(text_line_mask)

        text_line_height_score_map: Optional[ScoreMap] = None
        text_line_heights: Optional[List[float]] = None
        text_line_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_text_line_height_score_map:
            np_heights = self._compute_heights(
                text_line_height_points_up,
                text_line_height_points_down,
            )
            assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

            text_line_heights = []
            text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
            begin = 0
            for polygon, group_size in zip(text_line_polygons, text_line_height_points_group_sizes):
                end = begin + group_size
                text_line_height = float(np_heights[begin:end].mean())
                text_line_heights.append(text_line_height)
                polygon.fill_score_map(
                    score_map=text_line_height_score_map,
                    value=text_line_height,
                )
                begin = end

            if self.config.enable_debug_distorted_text_line_heights:
                text_line_heights_debug_image = self._paint_heights_debug_image(
                    distorted_image,
                    text_line_polygons,
                    text_line_heights,
                )

        return (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        )

    def generate_char_labelings(
        self,
        distorted_image: Image,
        char_polygons: Sequence[Polygon],
        char_height_points_up: PointList,
        char_height_points_down: PointList,
    ) -> Tuple[Optional[Mask], Optional[ScoreMap], Optional[List[float]], Optional[Image]]:
        """Build char-level labels from distorted geometry.

        Char ``i`` is paired with the ``i``-th up/down height point pair.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``. Each
            entry is ``None`` when the matching config flag disables it.
        """
        char_mask: Optional[Mask] = None
        # Gated by the char-level flag (not the text line one).
        if self.config.enable_distorted_char_mask:
            char_mask = Mask.from_shapable(distorted_image)
            for polygon in char_polygons:
                polygon.fill_mask(char_mask)

        char_height_score_map: Optional[ScoreMap] = None
        char_heights: Optional[List[float]] = None
        char_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_char_height_score_map:
            np_heights = self._compute_heights(char_height_points_up, char_height_points_down)
            # One height per char polygon.
            assert np_heights.shape[0] == len(char_polygons)

            # Fill from large height to small height,
            # in order to preserve small height labeling when two char boxes overlapped.
            sorted_char_polygon_indices = np_heights.argsort()[::-1]

            char_heights = [0.0] * len(char_polygons)
            char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

            for idx in sorted_char_polygon_indices:
                char_height = float(np_heights[idx])
                char_heights[idx] = char_height
                char_polygons[idx].fill_score_map(
                    score_map=char_height_score_map,
                    value=char_height,
                )

            if self.config.enable_debug_distorted_char_heights:
                char_heights_debug_image = self._paint_heights_debug_image(
                    distorted_image,
                    char_polygons,
                    char_heights,
                )

        return (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        )

    def run(self, input: PageDistortionStepInput, rng: RandomGenerator) -> PageDistortionStepOutput:
        """Distort the assembled page and regenerate its labels.

        The steps are:

        1. Flatten all label geometry.
        2. Distort the image, active mask and geometry in one call.
        3. Fill the inactive area from the page bottom layer image.
        4. Regroup the distorted geometry.
        5. Build the text-line and char labelings.
        """
        page_assembler_step_output = input.page_assembler_step_output
        page = page_assembler_step_output.page
        page_bottom_layer_image = page.page_bottom_layer_image
        page_char_polygon_collection = page.page_char_polygon_collection
        page_text_line_polygon_collection = page.page_text_line_polygon_collection
        page_disconnected_text_region_collection = page.page_disconnected_text_region_collection
        page_non_text_region_collection = page.page_non_text_region_collection

        # Flatten, so that every geometry goes through the same distortion call.
        polygon_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.polygons,
            # Text line level.
            page_text_line_polygon_collection.polygons,
            # For char-level polygon regression.
            tuple(page_disconnected_text_region_collection.to_polygons()),
            # For sampling negative text region area.
            tuple(page_non_text_region_collection.to_polygons()),
        ])
        point_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.height_points_up,
            page_char_polygon_collection.height_points_down,
            # Text line level.
            page_text_line_polygon_collection.height_points_up,
            page_text_line_polygon_collection.height_points_down,
        ])

        # Distort.
        page_random_distortion_debug: Optional[RandomDistortionDebug] = None
        if self.config.enable_debug_random_distortion:
            page_random_distortion_debug = RandomDistortionDebug()

        page_active_mask = Mask.from_shapable(page.image, value=1)
        # To mitigate a bug in cv.remap, in which the border interpolation is wrong.
        # This mitigation DOES remove a 1-pixel wide border, but it should be fine.
        with page_active_mask.writable_context:
            page_active_mask.mat[0] = 0
            page_active_mask.mat[-1] = 0
            page_active_mask.mat[:, 0] = 0
            page_active_mask.mat[:, -1] = 0

        result = self.random_distortion.distort(
            image=page.image,
            mask=page_active_mask,
            polygons=polygon_flattener.flatten(),
            points=PointList(point_flattener.flatten()),
            rng=rng,
            debug=page_random_distortion_debug,
        )
        # Check against None rather than truthiness. An empty polygon/point
        # sequence (e.g. a page without any text) is falsy, but it is still a
        # valid result.
        assert result.image is not None
        assert result.mask is not None
        assert result.polygons is not None
        assert result.points is not None

        # Fill inplace the inactive (black) region with page_bottom_layer_image.
        self.fill_page_inactive_region(
            page_image=result.image,
            page_active_mask=result.mask,
            page_bottom_layer_image=page_bottom_layer_image,
        )

        # Unflatten.
        (
            # Char level.
            char_polygons,
            # Text line level.
            text_line_polygons,
            # For char-level polygon regression.
            disconnected_text_region_polygons,
            # For sampling negative text region area.
            non_text_region_polygons,
        ) = polygon_flattener.unflatten(result.polygons)

        (
            # Char level.
            char_height_points_up,
            char_height_points_down,
            # Text line level.
            text_line_height_points_up,
            text_line_height_points_down,
        ) = map(PointList, point_flattener.unflatten(result.points))

        text_line_height_points_group_sizes = \
            page_text_line_polygon_collection.height_points_group_sizes
        assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
        assert len(text_line_height_points_up) == len(text_line_height_points_down)

        # Labelings.
        (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        ) = self.generate_text_line_labelings(
            distorted_image=result.image,
            text_line_polygons=text_line_polygons,
            text_line_height_points_up=text_line_height_points_up,
            text_line_height_points_down=text_line_height_points_down,
            text_line_height_points_group_sizes=text_line_height_points_group_sizes,
        )
        (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        ) = self.generate_char_labelings(
            distorted_image=result.image,
            char_polygons=char_polygons,
            char_height_points_up=char_height_points_up,
            char_height_points_down=char_height_points_down,
        )

        return PageDistortionStepOutput(
            page_image=result.image,
            page_random_distortion_debug=page_random_distortion_debug,
            page_active_mask=result.mask,
            page_char_polygon_collection=PageCharPolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=char_polygons,
                height_points_up=char_height_points_up,
                height_points_down=char_height_points_down,
            ),
            page_char_mask=char_mask,
            page_char_height_score_map=char_height_score_map,
            page_char_heights=char_heights,
            page_char_heights_debug_image=char_heights_debug_image,
            page_text_line_polygon_collection=PageTextLinePolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=text_line_polygons,
                height_points_group_sizes=text_line_height_points_group_sizes,
                height_points_up=text_line_height_points_up,
                height_points_down=text_line_height_points_down,
            ),
            page_text_line_mask=text_line_mask,
            page_text_line_height_score_map=text_line_height_score_map,
            page_text_line_heights=text_line_heights,
            page_text_line_heights_debug_image=text_line_heights_debug_image,
            page_disconnected_text_region_collection=PageDisconnectedTextRegionCollection(
                disconnected_text_regions=[
                    DisconnectedTextRegion(disconnected_text_region_polygon)
                    for disconnected_text_region_polygon in disconnected_text_region_polygons
                ],
            ),
            page_non_text_region_collection=PageNonTextRegionCollection(
                non_text_regions=[
                    NonTextRegion(non_text_region_polygon)
                    for non_text_region_polygon in non_text_region_polygons
                ],
            )
        )
427
428
# Module-level PipelineStepFactory wrapping PageDistortionStep.
page_distortion_step_factory = PipelineStepFactory(PageDistortionStep)
class PageDistortionStepConfig:
class PageDistortionStepConfig:
    """Configuration of ``PageDistortionStep``.

    ``random_distortion_factory_config`` is handed to
    ``random_distortion_factory.create``; the boolean flags choose which labels
    and debug artifacts the step produces after distortion.
    """

    # By default the 'defocus_blur' and 'zoom_in_blur' policies are disabled.
    # NOTE: defocus blur and zoom in blur introduce labeling noise.
    # TODO: enhance those blurring methods for page.
    random_distortion_factory_config: Optional[Union[Mapping[str, Any], PathType]] = attrs.field(
        factory=lambda: dict(disabled_policy_names=['defocus_blur', 'zoom_in_blur'])
    )

    # Collect a RandomDistortionDebug while distorting.
    enable_debug_random_distortion: bool = False

    # Char level: intended toggles for mask, height score map and height debug image.
    enable_distorted_char_mask: bool = True
    enable_distorted_char_height_score_map: bool = True
    enable_debug_distorted_char_heights: bool = False

    # Text line level: same three toggles, for text lines.
    enable_distorted_text_line_mask: bool = True
    enable_distorted_text_line_height_score_map: bool = True
    enable_debug_distorted_text_line_heights: bool = False
PageDistortionStepConfig( random_distortion_factory_config: Union[Mapping[str, Any], str, os.PathLike, NoneType] = NOTHING, enable_debug_random_distortion: bool = False, enable_distorted_char_mask: bool = True, enable_distorted_char_height_score_map: bool = True, enable_debug_distorted_char_heights: bool = False, enable_distorted_text_line_mask: bool = True, enable_distorted_text_line_height_score_map: bool = True, enable_debug_distorted_text_line_heights: bool = False)
def __init__(self, random_distortion_factory_config=NOTHING, enable_debug_random_distortion=attr_dict['enable_debug_random_distortion'].default, enable_distorted_char_mask=attr_dict['enable_distorted_char_mask'].default, enable_distorted_char_height_score_map=attr_dict['enable_distorted_char_height_score_map'].default, enable_debug_distorted_char_heights=attr_dict['enable_debug_distorted_char_heights'].default, enable_distorted_text_line_mask=attr_dict['enable_distorted_text_line_mask'].default, enable_distorted_text_line_height_score_map=attr_dict['enable_distorted_text_line_height_score_map'].default, enable_debug_distorted_text_line_heights=attr_dict['enable_debug_distorted_text_line_heights'].default):
    """attrs-generated initializer of PageDistortionStepConfig.

    The boolean flags default to the defaults declared on the class fields.
    ``random_distortion_factory_config`` is built by its field factory when the
    caller omits it, which the ``NOTHING`` sentinel signals.
    """
    if random_distortion_factory_config is not NOTHING:
        self.random_distortion_factory_config = random_distortion_factory_config
    else:
        self.random_distortion_factory_config = __attr_factory_random_distortion_factory_config()
    self.enable_debug_random_distortion = enable_debug_random_distortion
    self.enable_distorted_char_mask = enable_distorted_char_mask
    self.enable_distorted_char_height_score_map = enable_distorted_char_height_score_map
    self.enable_debug_distorted_char_heights = enable_debug_distorted_char_heights
    self.enable_distorted_text_line_mask = enable_distorted_text_line_mask
    self.enable_distorted_text_line_height_score_map = enable_distorted_text_line_height_score_map
    self.enable_debug_distorted_text_line_heights = enable_debug_distorted_text_line_heights

Method generated by attrs for class PageDistortionStepConfig.

class PageDistortionStepInput:
class PageDistortionStepInput:
    """Input of ``PageDistortionStep.run``: the result of the page assembler step."""

    # Its ``page`` (image plus label collections) is what gets distorted.
    page_assembler_step_output: PageAssemblerStepOutput
PageDistortionStepInput( page_assembler_step_output: vkit.pipeline.text_detection.page_assembler.PageAssemblerStepOutput)
def __init__(self, page_assembler_step_output):
    """attrs-generated initializer: store the single field as given."""
    self.page_assembler_step_output = page_assembler_step_output

Method generated by attrs for class PageDistortionStepInput.

class PageDistortionStepOutput:
class PageDistortionStepOutput:
    """Output of ``PageDistortionStep.run``.

    All geometry is expressed in the frame of the distorted ``page_image``.
    ``Optional`` entries are ``None`` when the step's config disables them.
    """

    # Distorted page image and its active-region mask.
    page_image: Image
    page_random_distortion_debug: Optional[RandomDistortionDebug]
    page_active_mask: Mask
    # Char-level labels.
    page_char_polygon_collection: PageCharPolygonCollection
    page_char_mask: Optional[Mask]
    page_char_height_score_map: Optional[ScoreMap]
    page_char_heights: Optional[Sequence[float]]
    page_char_heights_debug_image: Optional[Image]
    # Text-line-level labels.
    page_text_line_polygon_collection: PageTextLinePolygonCollection
    page_text_line_mask: Optional[Mask]
    page_text_line_height_score_map: Optional[ScoreMap]
    page_text_line_heights: Optional[Sequence[float]]
    page_text_line_heights_debug_image: Optional[Image]
    # Region-level geometry, distorted alongside the page.
    page_disconnected_text_region_collection: PageDisconnectedTextRegionCollection
    page_non_text_region_collection: PageNonTextRegionCollection
PageDistortionStepOutput( page_image: vkit.element.image.Image, page_random_distortion_debug: Union[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionDebug, NoneType], page_active_mask: vkit.element.mask.Mask, page_char_polygon_collection: vkit.pipeline.text_detection.page_text_line_label.PageCharPolygonCollection, page_char_mask: Union[vkit.element.mask.Mask, NoneType], page_char_height_score_map: Union[vkit.element.score_map.ScoreMap, NoneType], page_char_heights: Union[Sequence[float], NoneType], page_char_heights_debug_image: Union[vkit.element.image.Image, NoneType], page_text_line_polygon_collection: vkit.pipeline.text_detection.page_text_line_label.PageTextLinePolygonCollection, page_text_line_mask: Union[vkit.element.mask.Mask, NoneType], page_text_line_height_score_map: Union[vkit.element.score_map.ScoreMap, NoneType], page_text_line_heights: Union[Sequence[float], NoneType], page_text_line_heights_debug_image: Union[vkit.element.image.Image, NoneType], page_disconnected_text_region_collection: vkit.pipeline.text_detection.page_assembler.PageDisconnectedTextRegionCollection, page_non_text_region_collection: vkit.pipeline.text_detection.page_assembler.PageNonTextRegionCollection)
def __init__(self, page_image, page_random_distortion_debug, page_active_mask, page_char_polygon_collection, page_char_mask, page_char_height_score_map, page_char_heights, page_char_heights_debug_image, page_text_line_polygon_collection, page_text_line_mask, page_text_line_height_score_map, page_text_line_heights, page_text_line_heights_debug_image, page_disconnected_text_region_collection, page_non_text_region_collection):
    """attrs-generated initializer: every field is required and stored as given."""
    self.page_image = page_image
    self.page_random_distortion_debug = page_random_distortion_debug
    self.page_active_mask = page_active_mask
    self.page_char_polygon_collection = page_char_polygon_collection
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_heights = page_char_heights
    self.page_char_heights_debug_image = page_char_heights_debug_image
    self.page_text_line_polygon_collection = page_text_line_polygon_collection
    self.page_text_line_mask = page_text_line_mask
    self.page_text_line_height_score_map = page_text_line_height_score_map
    self.page_text_line_heights = page_text_line_heights
    self.page_text_line_heights_debug_image = page_text_line_heights_debug_image
    self.page_disconnected_text_region_collection = page_disconnected_text_region_collection
    self.page_non_text_region_collection = page_non_text_region_collection

Method generated by attrs for class PageDistortionStepOutput.

class ElementFlattener(typing.Generic[~_E]):
 97class ElementFlattener(Generic[_E]):
 98
 99    def __init__(self, grouped_elements: Sequence[Sequence[_E]]):
100        self.grouped_elements = grouped_elements
101        self.group_sizes = [len(elements) for elements in grouped_elements]
102
103    def flatten(self):
104        return tuple(itertools.chain.from_iterable(self.grouped_elements))
105
106    def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
107        assert len(flattened_elements) == sum(self.group_sizes)
108        grouped_elements: List[Sequence[_E]] = []
109        begin = 0
110        for group_size in self.group_sizes:
111            end = begin + group_size
112            grouped_elements.append(flattened_elements[begin:end])
113            begin = end
114        return grouped_elements

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

ElementFlattener(grouped_elements: Sequence[Sequence[~_E]])
 99    def __init__(self, grouped_elements: Sequence[Sequence[_E]]):
100        self.grouped_elements = grouped_elements
101        self.group_sizes = [len(elements) for elements in grouped_elements]
def flatten(self):
103    def flatten(self):
104        return tuple(itertools.chain.from_iterable(self.grouped_elements))
def unflatten(self, flattened_elements: Sequence[~_E]) -> Sequence[Sequence[~_E]]:
def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
    """Cut ``flattened_elements`` into consecutive slices sized like the original groups.

    The total length must match the sum of ``self.group_sizes``; each returned
    group is a slice of ``flattened_elements``.
    """
    assert len(flattened_elements) == sum(self.group_sizes)
    grouped_elements: List[Sequence[_E]] = []
    offset = 0
    for size in self.group_sizes:
        grouped_elements.append(flattened_elements[offset:offset + size])
        offset += size
    return grouped_elements
class PageDistortionStep(
    PipelineStep[
        PageDistortionStepConfig,
        PageDistortionStepInput,
        PageDistortionStepOutput,
    ]
):  # yapf: disable
    """Randomly distort an assembled page and rebuild its labels on the distorted page.

    The page image, an "active" mask, and all char / text line / region polygons and
    height points go through a single ``self.random_distortion.distort`` call, so the
    geometry stays aligned with the pixels. Pixels outside the distorted active mask
    are then filled from the page bottom layer image. Finally, char / text line masks,
    height score maps and (optionally) debug images are generated.
    """

    def __init__(self, config: PageDistortionStepConfig):
        super().__init__(config)

        # `random_distortion_factory_config` is a mapping or a path
        # (see `PageDistortionStepConfig`).
        self.random_distortion = random_distortion_factory.create(
            self.config.random_distortion_factory_config
        )

    @classmethod
    def fill_page_inactive_region(
        cls,
        page_image: Image,
        page_active_mask: Mask,
        page_bottom_layer_image: Image,
    ):
        """Fill, in place, the pixels of ``page_image`` that lie outside ``page_active_mask``.

        The replacement pixels come from ``page_bottom_layer_image``. That image is
        resized to the shape of ``page_image`` first when the shapes differ.
        """
        assert page_image.shape == page_active_mask.shape

        if page_bottom_layer_image.shape != page_image.shape:
            page_bottom_layer_image = page_bottom_layer_image.to_resized_image(
                resized_height=page_image.height,
                resized_width=page_image.width,
            )

        # The inverted active mask selects exactly the inactive pixels.
        page_active_mask.to_inverted_mask().fill_image(page_image, page_bottom_layer_image)

    @classmethod
    def _build_page_active_mask(cls, page_image: Image) -> Mask:
        """Create an all-ones mask shaped like ``page_image`` with a zeroed 1-pixel border."""
        page_active_mask = Mask.from_shapable(page_image, value=1)
        # To mitigate a bug in cv.remap, in which the border interpolation is wrong.
        # This mitigation DOES remove a 1-pixel-wide border, but it should be fine.
        with page_active_mask.writable_context:
            page_active_mask.mat[0] = 0
            page_active_mask.mat[-1] = 0
            page_active_mask.mat[:, 0] = 0
            page_active_mask.mat[:, -1] = 0
        return page_active_mask

    @classmethod
    def _compute_heights(
        cls,
        height_points_up: PointList,
        height_points_down: PointList,
    ) -> np.ndarray:
        """Return, per up/down point pair, the Euclidean distance between them plus one."""
        np_height_points_up = height_points_up.to_smooth_np_array()
        np_height_points_down = height_points_down.to_smooth_np_array()
        np_heights: np.ndarray = np.linalg.norm(
            np_height_points_down - np_height_points_up,
            axis=1,
        )
        # Add one to compensate.
        np_heights += 1
        return np_heights

    @classmethod
    def _paint_heights_debug_image(
        cls,
        distorted_image: Image,
        polygons: Sequence[Polygon],
        heights: Sequence[float],
    ) -> Image:
        """Paint ``polygons`` and label each one at its center with its height (1 decimal)."""
        painter = Painter.create(distorted_image)
        painter.paint_polygons(polygons)

        texts: List[str] = []
        points = PointList()
        for polygon, height in zip(polygons, heights):
            texts.append(f'{height:.1f}')
            points.append(polygon.get_center_point())
        painter.paint_texts(texts, points, alpha=1.0)

        return painter.image

    def generate_text_line_labelings(
        self,
        distorted_image: Image,
        text_line_polygons: Sequence[Polygon],
        text_line_height_points_up: PointList,
        text_line_height_points_down: PointList,
        text_line_height_points_group_sizes: Sequence[int],
    ):
        """Build text-line-level labels on the distorted page.

        ``text_line_height_points_group_sizes[i]`` is the number of consecutive height
        points belonging to ``text_line_polygons[i]``.

        Returns ``(text_line_mask, text_line_height_score_map, text_line_heights,
        text_line_heights_debug_image)``. The mask is ``None`` unless
        ``enable_distorted_text_line_mask`` is set. The other three are ``None`` unless
        ``enable_distorted_text_line_height_score_map`` is set. The debug image also
        requires ``enable_debug_distorted_text_line_heights``.
        """
        text_line_mask: Optional[Mask] = None
        if self.config.enable_distorted_text_line_mask:
            text_line_mask = Mask.from_shapable(distorted_image)
            for polygon in text_line_polygons:
                polygon.fill_mask(text_line_mask)

        text_line_height_score_map: Optional[ScoreMap] = None
        text_line_heights: Optional[List[float]] = None
        text_line_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_text_line_height_score_map:
            np_heights = self._compute_heights(
                text_line_height_points_up,
                text_line_height_points_down,
            )
            assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

            text_line_heights = []
            text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
            begin = 0
            for polygon, group_size in zip(text_line_polygons, text_line_height_points_group_sizes):
                end = begin + group_size
                # A text line's height is the mean over its own group of height points.
                text_line_height = float(np_heights[begin:end].mean())
                text_line_heights.append(text_line_height)
                polygon.fill_score_map(
                    score_map=text_line_height_score_map,
                    value=text_line_height,
                )
                begin = end

            if self.config.enable_debug_distorted_text_line_heights:
                text_line_heights_debug_image = self._paint_heights_debug_image(
                    distorted_image=distorted_image,
                    polygons=text_line_polygons,
                    heights=text_line_heights,
                )

        return (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        )

    def generate_char_labelings(
        self,
        distorted_image: Image,
        char_polygons: Sequence[Polygon],
        char_height_points_up: PointList,
        char_height_points_down: PointList,
    ):
        """Build char-level labels on the distorted page.

        Returns ``(char_mask, char_height_score_map, char_heights,
        char_heights_debug_image)``. The heights, score map and debug image are ``None``
        unless ``enable_distorted_char_height_score_map`` is set. The debug image also
        requires ``enable_debug_distorted_char_heights``.
        """
        # The other char outputs are gated by char-specific flags, so honour a dedicated
        # `enable_distorted_char_mask` flag when the config defines one. Otherwise fall
        # back to `enable_distorted_text_line_mask`, which was used here before. That
        # looks like a copy-paste slip; the fallback keeps such configs unchanged.
        enable_char_mask = getattr(
            self.config,
            'enable_distorted_char_mask',
            self.config.enable_distorted_text_line_mask,
        )

        char_mask: Optional[Mask] = None
        if enable_char_mask:
            char_mask = Mask.from_shapable(distorted_image)
            for polygon in char_polygons:
                polygon.fill_mask(char_mask)

        char_height_score_map: Optional[ScoreMap] = None
        char_heights: Optional[List[float]] = None
        char_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_char_height_score_map:
            np_heights = self._compute_heights(char_height_points_up, char_height_points_down)

            # Fill from large height to small height,
            # in order to preserve small height labeling when two char boxes overlapped.
            sorted_char_polygon_indices: Tuple[int, ...] = tuple(reversed(np_heights.argsort()))

            char_heights = [0.0] * len(char_polygons)
            char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

            for idx in sorted_char_polygon_indices:
                char_height = float(np_heights[idx])
                char_heights[idx] = char_height
                char_polygons[idx].fill_score_map(
                    score_map=char_height_score_map,
                    value=char_height,
                )

            if self.config.enable_debug_distorted_char_heights:
                char_heights_debug_image = self._paint_heights_debug_image(
                    distorted_image=distorted_image,
                    polygons=char_polygons,
                    heights=char_heights,
                )

        return (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        )

    def run(self, input: PageDistortionStepInput, rng: RandomGenerator):
        """Distort the assembled page with ``rng`` and regenerate its labels."""
        page_assembler_step_output = input.page_assembler_step_output
        page = page_assembler_step_output.page
        page_bottom_layer_image = page.page_bottom_layer_image
        page_char_polygon_collection = page.page_char_polygon_collection
        page_text_line_polygon_collection = page.page_text_line_polygon_collection
        page_disconnected_text_region_collection = page.page_disconnected_text_region_collection
        page_non_text_region_collection = page.page_non_text_region_collection

        # Flatten. The group order here must match the unpacking order below.
        polygon_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.polygons,
            # Text line level.
            page_text_line_polygon_collection.polygons,
            # For char-level polygon regression.
            tuple(page_disconnected_text_region_collection.to_polygons()),
            # For sampling negative text region area.
            tuple(page_non_text_region_collection.to_polygons()),
        ])
        point_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.height_points_up,
            page_char_polygon_collection.height_points_down,
            # Text line level.
            page_text_line_polygon_collection.height_points_up,
            page_text_line_polygon_collection.height_points_down,
        ])

        # Distort.
        page_random_distortion_debug: Optional[RandomDistortionDebug] = None
        if self.config.enable_debug_random_distortion:
            page_random_distortion_debug = RandomDistortionDebug()

        page_active_mask = self._build_page_active_mask(page.image)

        result = self.random_distortion.distort(
            image=page.image,
            mask=page_active_mask,
            polygons=polygon_flattener.flatten(),
            points=PointList(point_flattener.flatten()),
            rng=rng,
            debug=page_random_distortion_debug,
        )
        assert result.image
        assert result.mask
        assert result.polygons
        assert result.points

        # Fill inplace the inactive (black) region with page_bottom_layer_image.
        self.fill_page_inactive_region(
            page_image=result.image,
            page_active_mask=result.mask,
            page_bottom_layer_image=page_bottom_layer_image,
        )

        # Unflatten.
        (
            # Char level.
            char_polygons,
            # Text line level.
            text_line_polygons,
            # For char-level polygon regression.
            disconnected_text_region_polygons,
            # For sampling negative text region area.
            non_text_region_polygons,
        ) = polygon_flattener.unflatten(result.polygons)

        (
            # Char level.
            char_height_points_up,
            char_height_points_down,
            # Text line level.
            text_line_height_points_up,
            text_line_height_points_down,
        ) = map(PointList, point_flattener.unflatten(result.points))

        text_line_height_points_group_sizes = \
            page_text_line_polygon_collection.height_points_group_sizes
        assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
        assert len(text_line_height_points_up) == len(text_line_height_points_down)

        # Labelings.
        (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        ) = self.generate_text_line_labelings(
            distorted_image=result.image,
            text_line_polygons=text_line_polygons,
            text_line_height_points_up=text_line_height_points_up,
            text_line_height_points_down=text_line_height_points_down,
            text_line_height_points_group_sizes=text_line_height_points_group_sizes,
        )
        (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        ) = self.generate_char_labelings(
            distorted_image=result.image,
            char_polygons=char_polygons,
            char_height_points_up=char_height_points_up,
            char_height_points_down=char_height_points_down,
        )

        return PageDistortionStepOutput(
            page_image=result.image,
            page_random_distortion_debug=page_random_distortion_debug,
            page_active_mask=result.mask,
            page_char_polygon_collection=PageCharPolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=char_polygons,
                height_points_up=char_height_points_up,
                height_points_down=char_height_points_down,
            ),
            page_char_mask=char_mask,
            page_char_height_score_map=char_height_score_map,
            page_char_heights=char_heights,
            page_char_heights_debug_image=char_heights_debug_image,
            page_text_line_polygon_collection=PageTextLinePolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=text_line_polygons,
                height_points_group_sizes=text_line_height_points_group_sizes,
                height_points_up=text_line_height_points_up,
                height_points_down=text_line_height_points_down,
            ),
            page_text_line_mask=text_line_mask,
            page_text_line_height_score_map=text_line_height_score_map,
            page_text_line_heights=text_line_heights,
            page_text_line_heights_debug_image=text_line_heights_debug_image,
            page_disconnected_text_region_collection=PageDisconnectedTextRegionCollection(
                disconnected_text_regions=[
                    DisconnectedTextRegion(disconnected_text_region_polygon)
                    for disconnected_text_region_polygon in disconnected_text_region_polygons
                ],
            ),
            page_non_text_region_collection=PageNonTextRegionCollection(
                non_text_regions=[
                    NonTextRegion(non_text_region_polygon)
                    for non_text_region_polygon in non_text_region_polygons
                ],
            )
        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

def __init__(self, config: PageDistortionStepConfig):
    """Initialize the step and create its random distortion from the configured factory config."""
    super().__init__(config)
    factory_config = self.config.random_distortion_factory_config
    self.random_distortion = random_distortion_factory.create(factory_config)
@classmethod
def fill_page_inactive_region( cls, page_image: vkit.element.image.Image, page_active_mask: vkit.element.mask.Mask, page_bottom_layer_image: vkit.element.image.Image):
@classmethod
def fill_page_inactive_region(
    cls,
    page_image: Image,
    page_active_mask: Mask,
    page_bottom_layer_image: Image,
):
    """Overwrite, in place, the pixels of ``page_image`` outside ``page_active_mask``.

    The pixels are taken from ``page_bottom_layer_image``, which is first resized to
    ``page_image``'s height/width when the shapes differ.
    """
    assert page_image.shape == page_active_mask.shape

    bottom_layer = page_bottom_layer_image
    if bottom_layer.shape != page_image.shape:
        bottom_layer = bottom_layer.to_resized_image(
            resized_height=page_image.height,
            resized_width=page_image.width,
        )

    # Inverting the active mask selects exactly the inactive pixels.
    inactive_mask = page_active_mask.to_inverted_mask()
    inactive_mask.fill_image(page_image, bottom_layer)
def generate_text_line_labelings( self, distorted_image: vkit.element.image.Image, text_line_polygons: Sequence[vkit.element.polygon.Polygon], text_line_height_points_up: vkit.element.point.PointList, text_line_height_points_down: vkit.element.point.PointList, text_line_height_points_group_sizes: Sequence[int]):
def generate_text_line_labelings(
    self,
    distorted_image: Image,
    text_line_polygons: Sequence[Polygon],
    text_line_height_points_up: PointList,
    text_line_height_points_down: PointList,
    text_line_height_points_group_sizes: Sequence[int],
):
    """Build text-line-level labels on the distorted page.

    ``text_line_height_points_group_sizes[i]`` is the number of consecutive height points
    that belong to ``text_line_polygons[i]``.

    Returns ``(text_line_mask, text_line_height_score_map, text_line_heights,
    text_line_heights_debug_image)``. The mask is ``None`` unless
    ``enable_distorted_text_line_mask`` is set. The other three are ``None`` unless
    ``enable_distorted_text_line_height_score_map`` is set, and the debug image also
    requires ``enable_debug_distorted_text_line_heights``.
    """
    config = self.config

    text_line_mask: Optional[Mask] = None
    if config.enable_distorted_text_line_mask:
        text_line_mask = Mask.from_shapable(distorted_image)
        for polygon in text_line_polygons:
            polygon.fill_mask(text_line_mask)

    if not config.enable_distorted_text_line_height_score_map:
        return (text_line_mask, None, None, None)

    # Per height-point-pair distance, plus one to compensate.
    up_array = text_line_height_points_up.to_smooth_np_array()
    down_array = text_line_height_points_down.to_smooth_np_array()
    np_heights: np.ndarray = np.linalg.norm(down_array - up_array, axis=1)
    np_heights += 1
    assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

    text_line_heights: List[float] = []
    text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
    cursor = 0
    for polygon, group_size in zip(text_line_polygons, text_line_height_points_group_sizes):
        next_cursor = cursor + group_size
        # The line height is the mean over this line's own group of height points.
        line_height = float(np_heights[cursor:next_cursor].mean())
        text_line_heights.append(line_height)
        polygon.fill_score_map(score_map=text_line_height_score_map, value=line_height)
        cursor = next_cursor

    text_line_heights_debug_image: Optional[Image] = None
    if config.enable_debug_distorted_text_line_heights:
        painter = Painter.create(distorted_image)
        painter.paint_polygons(text_line_polygons)
        labels = [f'{height:.1f}' for height in text_line_heights]
        centers = PointList()
        for polygon in text_line_polygons[:len(labels)]:
            centers.append(polygon.get_center_point())
        painter.paint_texts(labels, centers, alpha=1.0)
        text_line_heights_debug_image = painter.image

    return (
        text_line_mask,
        text_line_height_score_map,
        text_line_heights,
        text_line_heights_debug_image,
    )
def generate_char_labelings( self, distorted_image: vkit.element.image.Image, char_polygons: Sequence[vkit.element.polygon.Polygon], char_height_points_up: vkit.element.point.PointList, char_height_points_down: vkit.element.point.PointList):
def generate_char_labelings(
    self,
    distorted_image: Image,
    char_polygons: Sequence[Polygon],
    char_height_points_up: PointList,
    char_height_points_down: PointList,
):
    """Build char-level labels on the distorted page.

    Returns ``(char_mask, char_height_score_map, char_heights, char_heights_debug_image)``.
    The heights, score map and debug image are ``None`` unless
    ``enable_distorted_char_height_score_map`` is set. The debug image also requires
    ``enable_debug_distorted_char_heights``.
    """
    # The other char outputs are gated by char-specific flags, so honour a dedicated
    # `enable_distorted_char_mask` flag when the config defines one. Otherwise fall
    # back to `enable_distorted_text_line_mask`, which was used here before. That
    # looks like a copy-paste slip; the fallback keeps such configs unchanged.
    enable_char_mask = getattr(
        self.config,
        'enable_distorted_char_mask',
        self.config.enable_distorted_text_line_mask,
    )

    char_mask: Optional[Mask] = None
    if enable_char_mask:
        char_mask = Mask.from_shapable(distorted_image)
        for polygon in char_polygons:
            polygon.fill_mask(char_mask)

    char_height_score_map: Optional[ScoreMap] = None
    char_heights: Optional[List[float]] = None
    char_heights_debug_image: Optional[Image] = None

    if self.config.enable_distorted_char_height_score_map:
        np_height_points_up = char_height_points_up.to_smooth_np_array()
        np_height_points_down = char_height_points_down.to_smooth_np_array()
        np_heights: np.ndarray = np.linalg.norm(
            np_height_points_down - np_height_points_up,
            axis=1,
        )
        # Add one to compensate.
        np_heights += 1

        # Fill from large height to small height,
        # in order to preserve small height labeling when two char boxes overlapped.
        sorted_char_polygon_indices: Tuple[int, ...] = tuple(reversed(np_heights.argsort()))

        char_heights = [0.0] * len(char_polygons)
        char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

        for idx in sorted_char_polygon_indices:
            char_height = float(np_heights[idx])
            char_heights[idx] = char_height
            char_polygons[idx].fill_score_map(
                score_map=char_height_score_map,
                value=char_height,
            )

        if self.config.enable_debug_distorted_char_heights:
            painter = Painter.create(distorted_image)
            painter.paint_polygons(char_polygons)

            texts: List[str] = []
            points = PointList()
            for polygon, height in zip(char_polygons, char_heights):
                texts.append(f'{height:.1f}')
                points.append(polygon.get_center_point())
            painter.paint_texts(texts, points, alpha=1.0)

            char_heights_debug_image = painter.image

    return (
        char_mask,
        char_height_score_map,
        char_heights,
        char_heights_debug_image,
    )
def run( self, input: vkit.pipeline.text_detection.page_distortion.PageDistortionStepInput, rng: numpy.random._generator.Generator):
def run(self, input: PageDistortionStepInput, rng: RandomGenerator):
    """Distort the assembled page with ``rng`` and regenerate its labels.

    Every polygon and height point of the page goes through one call to
    ``self.random_distortion.distort`` together with the page image and an active mask,
    so geometry and pixels stay aligned. The inactive region of the distorted image is
    then filled from the page bottom layer image, and text line / char labels are built
    on the distorted page.
    """
    page = input.page_assembler_step_output.page
    bottom_layer_image = page.page_bottom_layer_image
    char_collection = page.page_char_polygon_collection
    text_line_collection = page.page_text_line_polygon_collection
    disconnected_region_collection = page.page_disconnected_text_region_collection
    non_text_region_collection = page.page_non_text_region_collection

    # The group order here must match the unpacking order after distortion.
    polygon_flattener = ElementFlattener([
        # Char level.
        char_collection.polygons,
        # Text line level.
        text_line_collection.polygons,
        # For char-level polygon regression.
        tuple(disconnected_region_collection.to_polygons()),
        # For sampling negative text region area.
        tuple(non_text_region_collection.to_polygons()),
    ])
    point_flattener = ElementFlattener([
        char_collection.height_points_up,
        char_collection.height_points_down,
        text_line_collection.height_points_up,
        text_line_collection.height_points_down,
    ])

    if self.config.enable_debug_random_distortion:
        page_random_distortion_debug = RandomDistortionDebug()
    else:
        page_random_distortion_debug = None

    page_active_mask = Mask.from_shapable(page.image, value=1)
    # To mitigate a bug in cv.remap, in which the border interpolation is wrong.
    # This mitigation DOES remove a 1-pixel-wide border, but it should be fine.
    with page_active_mask.writable_context:
        page_active_mask.mat[0] = 0
        page_active_mask.mat[-1] = 0
        page_active_mask.mat[:, 0] = 0
        page_active_mask.mat[:, -1] = 0

    result = self.random_distortion.distort(
        image=page.image,
        mask=page_active_mask,
        polygons=polygon_flattener.flatten(),
        points=PointList(point_flattener.flatten()),
        rng=rng,
        debug=page_random_distortion_debug,
    )
    assert result.image
    assert result.mask
    assert result.polygons
    assert result.points
    distorted_image = result.image
    distorted_active_mask = result.mask

    # Fill, in place, the inactive (black) region with the page bottom layer image.
    self.fill_page_inactive_region(
        page_image=distorted_image,
        page_active_mask=distorted_active_mask,
        page_bottom_layer_image=bottom_layer_image,
    )

    (
        char_polygons,
        text_line_polygons,
        disconnected_text_region_polygons,
        non_text_region_polygons,
    ) = polygon_flattener.unflatten(result.polygons)
    (
        char_height_points_up,
        char_height_points_down,
        text_line_height_points_up,
        text_line_height_points_down,
    ) = (PointList(points) for points in point_flattener.unflatten(result.points))

    text_line_height_points_group_sizes = text_line_collection.height_points_group_sizes
    assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
    assert len(text_line_height_points_up) == len(text_line_height_points_down)

    text_line_labelings = self.generate_text_line_labelings(
        distorted_image=distorted_image,
        text_line_polygons=text_line_polygons,
        text_line_height_points_up=text_line_height_points_up,
        text_line_height_points_down=text_line_height_points_down,
        text_line_height_points_group_sizes=text_line_height_points_group_sizes,
    )
    (
        text_line_mask,
        text_line_height_score_map,
        text_line_heights,
        text_line_heights_debug_image,
    ) = text_line_labelings

    char_labelings = self.generate_char_labelings(
        distorted_image=distorted_image,
        char_polygons=char_polygons,
        char_height_points_up=char_height_points_up,
        char_height_points_down=char_height_points_down,
    )
    (
        char_mask,
        char_height_score_map,
        char_heights,
        char_heights_debug_image,
    ) = char_labelings

    distorted_char_polygon_collection = PageCharPolygonCollection(
        height=distorted_image.height,
        width=distorted_image.width,
        polygons=char_polygons,
        height_points_up=char_height_points_up,
        height_points_down=char_height_points_down,
    )
    distorted_text_line_polygon_collection = PageTextLinePolygonCollection(
        height=distorted_image.height,
        width=distorted_image.width,
        polygons=text_line_polygons,
        height_points_group_sizes=text_line_height_points_group_sizes,
        height_points_up=text_line_height_points_up,
        height_points_down=text_line_height_points_down,
    )
    distorted_disconnected_region_collection = PageDisconnectedTextRegionCollection(
        disconnected_text_regions=[
            DisconnectedTextRegion(region_polygon)
            for region_polygon in disconnected_text_region_polygons
        ],
    )
    distorted_non_text_region_collection = PageNonTextRegionCollection(
        non_text_regions=[
            NonTextRegion(region_polygon) for region_polygon in non_text_region_polygons
        ],
    )

    return PageDistortionStepOutput(
        page_image=distorted_image,
        page_random_distortion_debug=page_random_distortion_debug,
        page_active_mask=distorted_active_mask,
        page_char_polygon_collection=distorted_char_polygon_collection,
        page_char_mask=char_mask,
        page_char_height_score_map=char_height_score_map,
        page_char_heights=char_heights,
        page_char_heights_debug_image=char_heights_debug_image,
        page_text_line_polygon_collection=distorted_text_line_polygon_collection,
        page_text_line_mask=text_line_mask,
        page_text_line_height_score_map=text_line_height_score_map,
        page_text_line_heights=text_line_heights,
        page_text_line_heights_debug_image=text_line_heights_debug_image,
        page_disconnected_text_region_collection=distorted_disconnected_region_collection,
        page_non_text_region_collection=distorted_non_text_region_collection,
    )