vkit.pipeline.text_detection.page_distortion
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
"""Page distortion step of the text detection pipeline.

The assembled page image is randomly distorted together with every polygon and
height point attached to it, and the char / text-line labelings (masks, height
score maps, optional debug images) are regenerated on the distorted page.
"""
from typing import Optional, Union, Mapping, Any, List, Tuple, Sequence, TypeVar, Generic
import itertools

import attrs
from numpy.random import Generator as RandomGenerator
import numpy as np

from vkit.utility import PathType
from vkit.element import (
    Point,
    PointList,
    Polygon,
    Mask,
    ScoreMap,
    Image,
)
from vkit.mechanism.distortion_policy import (
    random_distortion_factory,
    RandomDistortionDebug,
)
from vkit.mechanism.painter import Painter
from ..interface import PipelineStep, PipelineStepFactory
from .page_layout import DisconnectedTextRegion, NonTextRegion
from .page_text_line_label import (
    PageCharPolygonCollection,
    PageTextLinePolygonCollection,
)
from .page_assembler import (
    PageAssemblerStepOutput,
    PageDisconnectedTextRegionCollection,
    PageNonTextRegionCollection,
)


@attrs.define
class PageDistortionStepConfig:
    """Configuration of :class:`PageDistortionStep`."""

    # Passed to ``random_distortion_factory.create``: an inline mapping or a path.
    random_distortion_factory_config: Optional[Union[Mapping[str, Any], PathType]] = attrs.field(
        factory=lambda: {
            # NOTE: defocus blur and zoom in blur introduce labeling noise.
            # TODO: enhance those blurring methods for page.
            'disabled_policy_names': [
                'defocus_blur',
                'zoom_in_blur',
            ],
        }
    )
    # Collect a RandomDistortionDebug record during distortion.
    enable_debug_random_distortion: bool = False
    # Char-level labelings.
    enable_distorted_char_mask: bool = True
    enable_distorted_char_height_score_map: bool = True
    enable_debug_distorted_char_heights: bool = False
    # Text-line-level labelings.
    enable_distorted_text_line_mask: bool = True
    enable_distorted_text_line_height_score_map: bool = True
    enable_debug_distorted_text_line_heights: bool = False


@attrs.define
class PageDistortionStepInput:
    """Input of :class:`PageDistortionStep`."""

    page_assembler_step_output: PageAssemblerStepOutput


@attrs.define
class PageDistortionStepOutput:
    """Output of :class:`PageDistortionStep`.

    Optional fields are ``None`` when the corresponding config flag is disabled.
    """

    page_image: Image
    page_random_distortion_debug: Optional[RandomDistortionDebug]
    page_active_mask: Mask
    page_char_polygon_collection: PageCharPolygonCollection
    page_char_mask: Optional[Mask]
    page_char_height_score_map: Optional[ScoreMap]
    page_char_heights: Optional[Sequence[float]]
    page_char_heights_debug_image: Optional[Image]
    page_text_line_polygon_collection: PageTextLinePolygonCollection
    page_text_line_mask: Optional[Mask]
    page_text_line_height_score_map: Optional[ScoreMap]
    page_text_line_heights: Optional[Sequence[float]]
    page_text_line_heights_debug_image: Optional[Image]
    page_disconnected_text_region_collection: PageDisconnectedTextRegionCollection
    page_non_text_region_collection: PageNonTextRegionCollection


_E = TypeVar('_E', Point, Polygon)


class ElementFlattener(Generic[_E]):
    """Concatenate groups of elements into one flat tuple, and split them back.

    Group sizes are recorded at construction, so :meth:`unflatten` can restore the
    grouping of any sequence with the same total length (e.g. the transformed
    elements returned by a distortion).
    """

    def __init__(self, grouped_elements: Sequence[Sequence[_E]]):
        self.grouped_elements = grouped_elements
        self.group_sizes = [len(elements) for elements in grouped_elements]

    def flatten(self) -> Tuple[_E, ...]:
        """Return all elements of all groups, in group order, as one tuple."""
        return tuple(itertools.chain.from_iterable(self.grouped_elements))

    def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
        """Split ``flattened_elements`` back into groups of the recorded sizes."""
        assert len(flattened_elements) == sum(self.group_sizes)
        grouped_elements: List[Sequence[_E]] = []
        begin = 0
        for group_size in self.group_sizes:
            end = begin + group_size
            grouped_elements.append(flattened_elements[begin:end])
            begin = end
        return grouped_elements


class PageDistortionStep(
    PipelineStep[
        PageDistortionStepConfig,
        PageDistortionStepInput,
        PageDistortionStepOutput,
    ]
):  # yapf: disable
    """Randomly distort an assembled page and regenerate its labelings.

    The page image, its active mask and all polygons / height points carried by
    the page are passed to a single ``distort`` call, so the distorted geometry
    stays aligned with the distorted image.
    """

    def __init__(self, config: PageDistortionStepConfig):
        super().__init__(config)

        self.random_distortion = random_distortion_factory.create(
            self.config.random_distortion_factory_config
        )

    @classmethod
    def fill_page_inactive_region(
        cls,
        page_image: Image,
        page_active_mask: Mask,
        page_bottom_layer_image: Image,
    ):
        """Fill, in place, the pixels of ``page_image`` outside ``page_active_mask``.

        Args:
            page_image: Image to patch; modified in place.
            page_active_mask: Mask with the same shape as ``page_image``.
            page_bottom_layer_image: Source of the filling pixels; resized to the
                height/width of ``page_image`` when the shapes differ.
        """
        assert page_image.shape == page_active_mask.shape

        if page_bottom_layer_image.shape != page_image.shape:
            page_bottom_layer_image = page_bottom_layer_image.to_resized_image(
                resized_height=page_image.height,
                resized_width=page_image.width,
            )

        page_active_mask.to_inverted_mask().fill_image(page_image, page_bottom_layer_image)

    def generate_text_line_labelings(
        self,
        distorted_image: Image,
        text_line_polygons: Sequence[Polygon],
        text_line_height_points_up: PointList,
        text_line_height_points_down: PointList,
        text_line_height_points_group_sizes: Sequence[int],
    ):
        """Build text-line-level labelings on the distorted page.

        Text line ``i`` owns the next ``text_line_height_points_group_sizes[i]``
        (up, down) height point pairs; its height is the mean distance of those
        pairs, plus one.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``; each item
            is ``None`` when disabled by the config.
        """
        text_line_mask: Optional[Mask] = None
        if self.config.enable_distorted_text_line_mask:
            text_line_mask = Mask.from_shapable(distorted_image)
            for polygon in text_line_polygons:
                polygon.fill_mask(text_line_mask)

        text_line_height_score_map: Optional[ScoreMap] = None
        text_line_heights: Optional[List[float]] = None
        text_line_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_text_line_height_score_map:
            np_height_points_up = text_line_height_points_up.to_smooth_np_array()
            np_height_points_down = text_line_height_points_down.to_smooth_np_array()
            np_heights: np.ndarray = np.linalg.norm(
                np_height_points_down - np_height_points_up,
                axis=1,
            )
            # Add one to compensate (presumably for counting both end pixels).
            np_heights += 1
            assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

            text_line_heights = []
            text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
            begin = 0
            for polygon, group_size in zip(text_line_polygons, text_line_height_points_group_sizes):
                end = begin + group_size
                text_line_height = float(np_heights[begin:end].mean())
                text_line_heights.append(text_line_height)
                polygon.fill_score_map(
                    score_map=text_line_height_score_map,
                    value=text_line_height,
                )
                begin = end

            if self.config.enable_debug_distorted_text_line_heights:
                painter = Painter.create(distorted_image)
                painter.paint_polygons(text_line_polygons)

                texts: List[str] = []
                points = PointList()
                for polygon, height in zip(text_line_polygons, text_line_heights):
                    texts.append(f'{height:.1f}')
                    points.append(polygon.get_center_point())
                painter.paint_texts(texts, points, alpha=1.0)

                text_line_heights_debug_image = painter.image

        return (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        )

    def generate_char_labelings(
        self,
        distorted_image: Image,
        char_polygons: Sequence[Polygon],
        char_height_points_up: PointList,
        char_height_points_down: PointList,
    ):
        """Build char-level labelings on the distorted page.

        Char ``i`` has exactly one (up, down) height point pair; its height is
        that pair's distance, plus one.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``; each item
            is ``None`` when disabled by the config.
        """
        char_mask: Optional[Mask] = None
        # Gate on the char-level flag (previously gated on the text-line flag,
        # which made `enable_distorted_char_mask` a no-op).
        if self.config.enable_distorted_char_mask:
            char_mask = Mask.from_shapable(distorted_image)
            for polygon in char_polygons:
                polygon.fill_mask(char_mask)

        char_height_score_map: Optional[ScoreMap] = None
        char_heights: Optional[List[float]] = None
        char_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_char_height_score_map:
            np_height_points_up = char_height_points_up.to_smooth_np_array()
            np_height_points_down = char_height_points_down.to_smooth_np_array()
            np_heights: np.ndarray = np.linalg.norm(
                np_height_points_down - np_height_points_up,
                axis=1,
            )
            # Add one to compensate (presumably for counting both end pixels).
            np_heights += 1

            # Fill from large height to small height,
            # in order to preserve small height labeling when two char boxes overlapped.
            sorted_char_polygon_indices: Tuple[int, ...] = tuple(reversed(np_heights.argsort()))

            char_heights = [0.0] * len(char_polygons)
            char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

            for idx in sorted_char_polygon_indices:
                polygon = char_polygons[idx]
                char_height = float(np_heights[idx])
                char_heights[idx] = char_height
                polygon.fill_score_map(
                    score_map=char_height_score_map,
                    value=char_height,
                )

            if self.config.enable_debug_distorted_char_heights:
                painter = Painter.create(distorted_image)
                painter.paint_polygons(char_polygons)

                texts: List[str] = []
                points = PointList()
                for polygon, height in zip(char_polygons, char_heights):
                    texts.append(f'{height:.1f}')
                    points.append(polygon.get_center_point())
                painter.paint_texts(texts, points, alpha=1.0)

                char_heights_debug_image = painter.image

        return (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        )

    def run(self, input: PageDistortionStepInput, rng: RandomGenerator):
        """Distort the assembled page and regenerate all labelings.

        Args:
            input: Holds the page assembler output to distort.
            rng: Random generator driving the random distortion.

        Returns:
            A :class:`PageDistortionStepOutput`.
        """
        page_assembler_step_output = input.page_assembler_step_output
        page = page_assembler_step_output.page
        page_bottom_layer_image = page.page_bottom_layer_image
        page_char_polygon_collection = page.page_char_polygon_collection
        page_text_line_polygon_collection = page.page_text_line_polygon_collection
        page_disconnected_text_region_collection = page.page_disconnected_text_region_collection
        page_non_text_region_collection = page.page_non_text_region_collection

        # Flatten, so that everything goes through a single distort call.
        polygon_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.polygons,
            # Text line level.
            page_text_line_polygon_collection.polygons,
            # For char-level polygon regression.
            tuple(page_disconnected_text_region_collection.to_polygons()),
            # For sampling negative text region area.
            tuple(page_non_text_region_collection.to_polygons()),
        ])
        point_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.height_points_up,
            page_char_polygon_collection.height_points_down,
            # Text line level.
            page_text_line_polygon_collection.height_points_up,
            page_text_line_polygon_collection.height_points_down,
        ])

        # Distort.
        page_random_distortion_debug = None
        if self.config.enable_debug_random_distortion:
            page_random_distortion_debug = RandomDistortionDebug()

        page_active_mask = Mask.from_shapable(page.image, value=1)
        # To mitigate a bug in cv.remap, in which the border interpolation is wrong.
        # This mitigation does remove a 1-pixel-wide border, but it should be fine.
        with page_active_mask.writable_context:
            page_active_mask.mat[0] = 0
            page_active_mask.mat[-1] = 0
            page_active_mask.mat[:, 0] = 0
            page_active_mask.mat[:, -1] = 0

        result = self.random_distortion.distort(
            image=page.image,
            mask=page_active_mask,
            polygons=polygon_flattener.flatten(),
            points=PointList(point_flattener.flatten()),
            rng=rng,
            debug=page_random_distortion_debug,
        )
        # Check for presence, not truthiness: an empty polygon/point sequence is valid.
        assert result.image is not None
        assert result.mask is not None
        assert result.polygons is not None
        assert result.points is not None

        # Fill inplace the inactive (black) region with page_bottom_layer_image.
        self.fill_page_inactive_region(
            page_image=result.image,
            page_active_mask=result.mask,
            page_bottom_layer_image=page_bottom_layer_image,
        )

        # Unflatten.
        (
            # Char level.
            char_polygons,
            # Text line level.
            text_line_polygons,
            # For char-level polygon regression.
            disconnected_text_region_polygons,
            # For sampling negative text region area.
            non_text_region_polygons,
        ) = polygon_flattener.unflatten(result.polygons)

        (
            # Char level.
            char_height_points_up,
            char_height_points_down,
            # Text line level.
            text_line_height_points_up,
            text_line_height_points_down,
        ) = map(PointList, point_flattener.unflatten(result.points))

        text_line_height_points_group_sizes = \
            page_text_line_polygon_collection.height_points_group_sizes
        assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
        assert len(text_line_height_points_up) == len(text_line_height_points_down)

        # Labelings.
        (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        ) = self.generate_text_line_labelings(
            distorted_image=result.image,
            text_line_polygons=text_line_polygons,
            text_line_height_points_up=text_line_height_points_up,
            text_line_height_points_down=text_line_height_points_down,
            text_line_height_points_group_sizes=text_line_height_points_group_sizes,
        )
        (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        ) = self.generate_char_labelings(
            distorted_image=result.image,
            char_polygons=char_polygons,
            char_height_points_up=char_height_points_up,
            char_height_points_down=char_height_points_down,
        )

        return PageDistortionStepOutput(
            page_image=result.image,
            page_random_distortion_debug=page_random_distortion_debug,
            page_active_mask=result.mask,
            page_char_polygon_collection=PageCharPolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=char_polygons,
                height_points_up=char_height_points_up,
                height_points_down=char_height_points_down,
            ),
            page_char_mask=char_mask,
            page_char_height_score_map=char_height_score_map,
            page_char_heights=char_heights,
            page_char_heights_debug_image=char_heights_debug_image,
            page_text_line_polygon_collection=PageTextLinePolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=text_line_polygons,
                height_points_group_sizes=text_line_height_points_group_sizes,
                height_points_up=text_line_height_points_up,
                height_points_down=text_line_height_points_down,
            ),
            page_text_line_mask=text_line_mask,
            page_text_line_height_score_map=text_line_height_score_map,
            page_text_line_heights=text_line_heights,
            page_text_line_heights_debug_image=text_line_heights_debug_image,
            page_disconnected_text_region_collection=PageDisconnectedTextRegionCollection(
                disconnected_text_regions=[
                    DisconnectedTextRegion(disconnected_text_region_polygon)
                    for disconnected_text_region_polygon in disconnected_text_region_polygons
                ],
            ),
            page_non_text_region_collection=PageNonTextRegionCollection(
                non_text_regions=[
                    NonTextRegion(non_text_region_polygon)
                    for non_text_region_polygon in non_text_region_polygons
                ],
            )
        )


page_distortion_step_factory = PipelineStepFactory(PageDistortionStep)
class PageDistortionStepConfig:
    """Configuration of PageDistortionStep.

    NOTE(review): field order defines the positional order of the
    attrs-generated ``__init__``; do not reorder fields.
    """
    # Passed to ``random_distortion_factory.create``: an inline mapping or a path
    # (per the Union[Mapping[str, Any], PathType] annotation).
    random_distortion_factory_config: Optional[Union[Mapping[str, Any], PathType]] = attrs.field(
        factory=lambda: {
            # NOTE: defocus blur and zoom in blur introduce labeling noise.
            # TODO: enhance those blurring methods for page.
            'disabled_policy_names': [
                'defocus_blur',
                'zoom_in_blur',
            ],
        }
    )
    # Collect a RandomDistortionDebug record while distorting the page.
    enable_debug_random_distortion: bool = False
    # Intended to gate the char-level mask of the distorted page.
    enable_distorted_char_mask: bool = True
    # Build the char height score map (and the per-char heights).
    enable_distorted_char_height_score_map: bool = True
    # Render a debug image annotating each char polygon with its height.
    enable_debug_distorted_char_heights: bool = False
    # Build the text-line-level mask of the distorted page.
    enable_distorted_text_line_mask: bool = True
    # Build the text line height score map (and the per-text-line heights).
    enable_distorted_text_line_height_score_map: bool = True
    # Render a debug image annotating each text line polygon with its height.
    enable_debug_distorted_text_line_heights: bool = False
def __init__(
    self,
    random_distortion_factory_config=NOTHING,
    enable_debug_random_distortion=attr_dict['enable_debug_random_distortion'].default,
    enable_distorted_char_mask=attr_dict['enable_distorted_char_mask'].default,
    enable_distorted_char_height_score_map=attr_dict['enable_distorted_char_height_score_map'].default,
    enable_debug_distorted_char_heights=attr_dict['enable_debug_distorted_char_heights'].default,
    enable_distorted_text_line_mask=attr_dict['enable_distorted_text_line_mask'].default,
    enable_distorted_text_line_height_score_map=attr_dict['enable_distorted_text_line_height_score_map'].default,
    enable_debug_distorted_text_line_heights=attr_dict['enable_debug_distorted_text_line_heights'].default,
):
    # attrs-generated initializer of PageDistortionStepConfig (shown for reference).
    # The sentinel NOTHING means "argument omitted": the default factory is then
    # invoked, so each instance gets its own freshly built default mapping.
    if random_distortion_factory_config is not NOTHING:
        self.random_distortion_factory_config = random_distortion_factory_config
    else:
        self.random_distortion_factory_config = __attr_factory_random_distortion_factory_config()
    # Remaining fields have plain defaults and are assigned as given.
    self.enable_debug_random_distortion = enable_debug_random_distortion
    self.enable_distorted_char_mask = enable_distorted_char_mask
    self.enable_distorted_char_height_score_map = enable_distorted_char_height_score_map
    self.enable_debug_distorted_char_heights = enable_debug_distorted_char_heights
    self.enable_distorted_text_line_mask = enable_distorted_text_line_mask
    self.enable_distorted_text_line_height_score_map = enable_distorted_text_line_height_score_map
    self.enable_debug_distorted_text_line_heights = enable_debug_distorted_text_line_heights
Method generated by attrs for class PageDistortionStepConfig.
def __init__(self, page_assembler_step_output):
    # attrs-generated initializer of PageDistortionStepInput (shown for reference):
    # stores the single required field.
    self.page_assembler_step_output = page_assembler_step_output
Method generated by attrs for class PageDistortionStepInput.
class PageDistortionStepOutput:
    """Output of PageDistortionStep.

    Optional fields are None when the corresponding config flag is disabled.
    NOTE(review): field order defines the positional order of the
    attrs-generated ``__init__``; do not reorder fields.
    """
    # Distorted page image, with its inactive region filled from the page bottom layer.
    page_image: Image
    # Debug record of the random distortion; None unless enable_debug_random_distortion.
    page_random_distortion_debug: Optional[RandomDistortionDebug]
    # Distorted active-region mask of the page.
    page_active_mask: Mask
    # Distorted char polygons and their up/down height points.
    page_char_polygon_collection: PageCharPolygonCollection
    # Char-level mask on the distorted page (None when disabled by config).
    page_char_mask: Optional[Mask]
    # Per-pixel char height (None unless enable_distorted_char_height_score_map).
    page_char_height_score_map: Optional[ScoreMap]
    # Per-char height, in the same order as the char polygons.
    page_char_heights: Optional[Sequence[float]]
    # Debug rendering of char heights (None unless enable_debug_distorted_char_heights).
    page_char_heights_debug_image: Optional[Image]
    # Distorted text line polygons, height point group sizes and height points.
    page_text_line_polygon_collection: PageTextLinePolygonCollection
    # Text-line-level mask (None unless enable_distorted_text_line_mask).
    page_text_line_mask: Optional[Mask]
    # Per-pixel text line height (None unless enable_distorted_text_line_height_score_map).
    page_text_line_height_score_map: Optional[ScoreMap]
    # Per-text-line mean height, in the same order as the text line polygons.
    page_text_line_heights: Optional[Sequence[float]]
    # Debug rendering of text line heights (None unless enable_debug_distorted_text_line_heights).
    page_text_line_heights_debug_image: Optional[Image]
    # Distorted disconnected text regions (used for char-level polygon regression).
    page_disconnected_text_region_collection: PageDisconnectedTextRegionCollection
    # Distorted non-text regions (used for sampling negative text region area).
    page_non_text_region_collection: PageNonTextRegionCollection
def __init__(
    self,
    page_image,
    page_random_distortion_debug,
    page_active_mask,
    page_char_polygon_collection,
    page_char_mask,
    page_char_height_score_map,
    page_char_heights,
    page_char_heights_debug_image,
    page_text_line_polygon_collection,
    page_text_line_mask,
    page_text_line_height_score_map,
    page_text_line_heights,
    page_text_line_heights_debug_image,
    page_disconnected_text_region_collection,
    page_non_text_region_collection,
):
    # attrs-generated initializer of PageDistortionStepOutput (shown for reference):
    # every field is required and is stored as given, in declaration order.
    self.page_image = page_image
    self.page_random_distortion_debug = page_random_distortion_debug
    self.page_active_mask = page_active_mask
    self.page_char_polygon_collection = page_char_polygon_collection
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_heights = page_char_heights
    self.page_char_heights_debug_image = page_char_heights_debug_image
    self.page_text_line_polygon_collection = page_text_line_polygon_collection
    self.page_text_line_mask = page_text_line_mask
    self.page_text_line_height_score_map = page_text_line_height_score_map
    self.page_text_line_heights = page_text_line_heights
    self.page_text_line_heights_debug_image = page_text_line_heights_debug_image
    self.page_disconnected_text_region_collection = page_disconnected_text_region_collection
    self.page_non_text_region_collection = page_non_text_region_collection
Method generated by attrs for class PageDistortionStepOutput.
class ElementFlattener(Generic[_E]):
    """Round-trip helper: groups of elements -> one flat tuple -> the same groups.

    The size of every group is captured up front, so any sequence with the same
    total length can later be cut back into the original grouping.
    """

    def __init__(self, grouped_elements: Sequence[Sequence[_E]]):
        self.grouped_elements = grouped_elements
        self.group_sizes = list(map(len, grouped_elements))

    def flatten(self):
        # Concatenate the groups in order.
        flat: List[_E] = []
        for group in self.grouped_elements:
            flat.extend(group)
        return tuple(flat)

    def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
        # Walk an offset through the flat sequence, cutting one slice per group.
        assert len(flattened_elements) == sum(self.group_sizes)
        groups: List[Sequence[_E]] = []
        offset = 0
        for size in self.group_sizes:
            groups.append(flattened_elements[offset:offset + size])
            offset += size
        return groups
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
def unflatten(self, flattened_elements: Sequence[_E]) -> Sequence[Sequence[_E]]:
    """Cut ``flattened_elements`` into consecutive slices sized by ``self.group_sizes``."""
    assert len(flattened_elements) == sum(self.group_sizes)
    # Cumulative boundaries: [0, s0, s0 + s1, ...].
    boundaries = [0]
    for size in self.group_sizes:
        boundaries.append(boundaries[-1] + size)
    return [
        flattened_elements[start:stop]
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]
class PageDistortionStep(
    PipelineStep[
        PageDistortionStepConfig,
        PageDistortionStepInput,
        PageDistortionStepOutput,
    ]
):  # yapf: disable
    """Randomly distort an assembled page and regenerate its labelings.

    The page image, its active mask and all polygons / height points carried by
    the page are passed to a single ``distort`` call, so the distorted geometry
    stays aligned with the distorted image.
    """

    def __init__(self, config: PageDistortionStepConfig):
        super().__init__(config)

        self.random_distortion = random_distortion_factory.create(
            self.config.random_distortion_factory_config
        )

    @classmethod
    def fill_page_inactive_region(
        cls,
        page_image: Image,
        page_active_mask: Mask,
        page_bottom_layer_image: Image,
    ):
        """Fill, in place, the pixels of ``page_image`` outside ``page_active_mask``.

        Args:
            page_image: Image to patch; modified in place.
            page_active_mask: Mask with the same shape as ``page_image``.
            page_bottom_layer_image: Source of the filling pixels; resized to the
                height/width of ``page_image`` when the shapes differ.
        """
        assert page_image.shape == page_active_mask.shape

        if page_bottom_layer_image.shape != page_image.shape:
            page_bottom_layer_image = page_bottom_layer_image.to_resized_image(
                resized_height=page_image.height,
                resized_width=page_image.width,
            )

        page_active_mask.to_inverted_mask().fill_image(page_image, page_bottom_layer_image)

    def generate_text_line_labelings(
        self,
        distorted_image: Image,
        text_line_polygons: Sequence[Polygon],
        text_line_height_points_up: PointList,
        text_line_height_points_down: PointList,
        text_line_height_points_group_sizes: Sequence[int],
    ):
        """Build text-line-level labelings on the distorted page.

        Text line ``i`` owns the next ``text_line_height_points_group_sizes[i]``
        (up, down) height point pairs; its height is the mean distance of those
        pairs, plus one.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``; each item
            is ``None`` when disabled by the config.
        """
        text_line_mask: Optional[Mask] = None
        if self.config.enable_distorted_text_line_mask:
            text_line_mask = Mask.from_shapable(distorted_image)
            for polygon in text_line_polygons:
                polygon.fill_mask(text_line_mask)

        text_line_height_score_map: Optional[ScoreMap] = None
        text_line_heights: Optional[List[float]] = None
        text_line_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_text_line_height_score_map:
            np_height_points_up = text_line_height_points_up.to_smooth_np_array()
            np_height_points_down = text_line_height_points_down.to_smooth_np_array()
            np_heights: np.ndarray = np.linalg.norm(
                np_height_points_down - np_height_points_up,
                axis=1,
            )
            # Add one to compensate (presumably for counting both end pixels).
            np_heights += 1
            assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

            text_line_heights = []
            text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
            begin = 0
            for polygon, group_size in zip(text_line_polygons, text_line_height_points_group_sizes):
                end = begin + group_size
                text_line_height = float(np_heights[begin:end].mean())
                text_line_heights.append(text_line_height)
                polygon.fill_score_map(
                    score_map=text_line_height_score_map,
                    value=text_line_height,
                )
                begin = end

            if self.config.enable_debug_distorted_text_line_heights:
                painter = Painter.create(distorted_image)
                painter.paint_polygons(text_line_polygons)

                texts: List[str] = []
                points = PointList()
                for polygon, height in zip(text_line_polygons, text_line_heights):
                    texts.append(f'{height:.1f}')
                    points.append(polygon.get_center_point())
                painter.paint_texts(texts, points, alpha=1.0)

                text_line_heights_debug_image = painter.image

        return (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        )

    def generate_char_labelings(
        self,
        distorted_image: Image,
        char_polygons: Sequence[Polygon],
        char_height_points_up: PointList,
        char_height_points_down: PointList,
    ):
        """Build char-level labelings on the distorted page.

        Char ``i`` has exactly one (up, down) height point pair; its height is
        that pair's distance, plus one.

        Returns:
            ``(mask, height_score_map, heights, heights_debug_image)``; each item
            is ``None`` when disabled by the config.
        """
        char_mask: Optional[Mask] = None
        # Gate on the char-level flag (previously gated on the text-line flag,
        # which made `enable_distorted_char_mask` a no-op).
        if self.config.enable_distorted_char_mask:
            char_mask = Mask.from_shapable(distorted_image)
            for polygon in char_polygons:
                polygon.fill_mask(char_mask)

        char_height_score_map: Optional[ScoreMap] = None
        char_heights: Optional[List[float]] = None
        char_heights_debug_image: Optional[Image] = None

        if self.config.enable_distorted_char_height_score_map:
            np_height_points_up = char_height_points_up.to_smooth_np_array()
            np_height_points_down = char_height_points_down.to_smooth_np_array()
            np_heights: np.ndarray = np.linalg.norm(
                np_height_points_down - np_height_points_up,
                axis=1,
            )
            # Add one to compensate (presumably for counting both end pixels).
            np_heights += 1

            # Fill from large height to small height,
            # in order to preserve small height labeling when two char boxes overlapped.
            sorted_char_polygon_indices: Tuple[int, ...] = tuple(reversed(np_heights.argsort()))

            char_heights = [0.0] * len(char_polygons)
            char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

            for idx in sorted_char_polygon_indices:
                polygon = char_polygons[idx]
                char_height = float(np_heights[idx])
                char_heights[idx] = char_height
                polygon.fill_score_map(
                    score_map=char_height_score_map,
                    value=char_height,
                )

            if self.config.enable_debug_distorted_char_heights:
                painter = Painter.create(distorted_image)
                painter.paint_polygons(char_polygons)

                texts: List[str] = []
                points = PointList()
                for polygon, height in zip(char_polygons, char_heights):
                    texts.append(f'{height:.1f}')
                    points.append(polygon.get_center_point())
                painter.paint_texts(texts, points, alpha=1.0)

                char_heights_debug_image = painter.image

        return (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        )

    def run(self, input: PageDistortionStepInput, rng: RandomGenerator):
        """Distort the assembled page and regenerate all labelings.

        Args:
            input: Holds the page assembler output to distort.
            rng: Random generator driving the random distortion.

        Returns:
            A PageDistortionStepOutput.
        """
        page_assembler_step_output = input.page_assembler_step_output
        page = page_assembler_step_output.page
        page_bottom_layer_image = page.page_bottom_layer_image
        page_char_polygon_collection = page.page_char_polygon_collection
        page_text_line_polygon_collection = page.page_text_line_polygon_collection
        page_disconnected_text_region_collection = page.page_disconnected_text_region_collection
        page_non_text_region_collection = page.page_non_text_region_collection

        # Flatten, so that everything goes through a single distort call.
        polygon_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.polygons,
            # Text line level.
            page_text_line_polygon_collection.polygons,
            # For char-level polygon regression.
            tuple(page_disconnected_text_region_collection.to_polygons()),
            # For sampling negative text region area.
            tuple(page_non_text_region_collection.to_polygons()),
        ])
        point_flattener = ElementFlattener([
            # Char level.
            page_char_polygon_collection.height_points_up,
            page_char_polygon_collection.height_points_down,
            # Text line level.
            page_text_line_polygon_collection.height_points_up,
            page_text_line_polygon_collection.height_points_down,
        ])

        # Distort.
        page_random_distortion_debug = None
        if self.config.enable_debug_random_distortion:
            page_random_distortion_debug = RandomDistortionDebug()

        page_active_mask = Mask.from_shapable(page.image, value=1)
        # To mitigate a bug in cv.remap, in which the border interpolation is wrong.
        # This mitigation does remove a 1-pixel-wide border, but it should be fine.
        with page_active_mask.writable_context:
            page_active_mask.mat[0] = 0
            page_active_mask.mat[-1] = 0
            page_active_mask.mat[:, 0] = 0
            page_active_mask.mat[:, -1] = 0

        result = self.random_distortion.distort(
            image=page.image,
            mask=page_active_mask,
            polygons=polygon_flattener.flatten(),
            points=PointList(point_flattener.flatten()),
            rng=rng,
            debug=page_random_distortion_debug,
        )
        # Check for presence, not truthiness: an empty polygon/point sequence is valid.
        assert result.image is not None
        assert result.mask is not None
        assert result.polygons is not None
        assert result.points is not None

        # Fill inplace the inactive (black) region with page_bottom_layer_image.
        self.fill_page_inactive_region(
            page_image=result.image,
            page_active_mask=result.mask,
            page_bottom_layer_image=page_bottom_layer_image,
        )

        # Unflatten.
        (
            # Char level.
            char_polygons,
            # Text line level.
            text_line_polygons,
            # For char-level polygon regression.
            disconnected_text_region_polygons,
            # For sampling negative text region area.
            non_text_region_polygons,
        ) = polygon_flattener.unflatten(result.polygons)

        (
            # Char level.
            char_height_points_up,
            char_height_points_down,
            # Text line level.
            text_line_height_points_up,
            text_line_height_points_down,
        ) = map(PointList, point_flattener.unflatten(result.points))

        text_line_height_points_group_sizes = \
            page_text_line_polygon_collection.height_points_group_sizes
        assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
        assert len(text_line_height_points_up) == len(text_line_height_points_down)

        # Labelings.
        (
            text_line_mask,
            text_line_height_score_map,
            text_line_heights,
            text_line_heights_debug_image,
        ) = self.generate_text_line_labelings(
            distorted_image=result.image,
            text_line_polygons=text_line_polygons,
            text_line_height_points_up=text_line_height_points_up,
            text_line_height_points_down=text_line_height_points_down,
            text_line_height_points_group_sizes=text_line_height_points_group_sizes,
        )
        (
            char_mask,
            char_height_score_map,
            char_heights,
            char_heights_debug_image,
        ) = self.generate_char_labelings(
            distorted_image=result.image,
            char_polygons=char_polygons,
            char_height_points_up=char_height_points_up,
            char_height_points_down=char_height_points_down,
        )

        return PageDistortionStepOutput(
            page_image=result.image,
            page_random_distortion_debug=page_random_distortion_debug,
            page_active_mask=result.mask,
            page_char_polygon_collection=PageCharPolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=char_polygons,
                height_points_up=char_height_points_up,
                height_points_down=char_height_points_down,
            ),
            page_char_mask=char_mask,
            page_char_height_score_map=char_height_score_map,
            page_char_heights=char_heights,
            page_char_heights_debug_image=char_heights_debug_image,
            page_text_line_polygon_collection=PageTextLinePolygonCollection(
                height=result.image.height,
                width=result.image.width,
                polygons=text_line_polygons,
                height_points_group_sizes=text_line_height_points_group_sizes,
                height_points_up=text_line_height_points_up,
                height_points_down=text_line_height_points_down,
            ),
            page_text_line_mask=text_line_mask,
            page_text_line_height_score_map=text_line_height_score_map,
            page_text_line_heights=text_line_heights,
            page_text_line_heights_debug_image=text_line_heights_debug_image,
            page_disconnected_text_region_collection=PageDisconnectedTextRegionCollection(
                disconnected_text_regions=[
                    DisconnectedTextRegion(disconnected_text_region_polygon)
                    for disconnected_text_region_polygon in disconnected_text_region_polygons
                ],
            ),
            page_non_text_region_collection=PageNonTextRegionCollection(
                non_text_regions=[
                    NonTextRegion(non_text_region_polygon)
                    for non_text_region_polygon in non_text_region_polygons
                ],
            )
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
@classmethod
def fill_page_inactive_region(
    cls,
    page_image: Image,
    page_active_mask: Mask,
    page_bottom_layer_image: Image,
):
    """Paint the bottom layer image into the inactive part of ``page_image``, in place.

    ``page_image`` is modified in place: wherever ``page_active_mask`` is NOT set
    (its inverted mask), pixels are taken from ``page_bottom_layer_image``.

    Args:
        page_image: Page image to patch in place. Must share ``page_active_mask``'s shape.
        page_active_mask: Mask of the region that must be kept untouched.
        page_bottom_layer_image: Background source. It is resized to the page size
            first when its shape differs from ``page_image``'s.
    """
    assert page_image.shape == page_active_mask.shape

    background_image = page_bottom_layer_image
    if background_image.shape != page_image.shape:
        # Stretch the background so that it covers the whole page.
        background_image = background_image.to_resized_image(
            resized_height=page_image.height,
            resized_width=page_image.width,
        )

    inactive_mask = page_active_mask.to_inverted_mask()
    inactive_mask.fill_image(page_image, background_image)
def generate_text_line_labelings(
    self,
    distorted_image: Image,
    text_line_polygons: Sequence[Polygon],
    text_line_height_points_up: PointList,
    text_line_height_points_down: PointList,
    text_line_height_points_group_sizes: Sequence[int],
):
    """Build text-line-level labels on the distorted page.

    Every text line owns a contiguous run of (up, down) height point pairs. The run
    lengths are given by ``text_line_height_points_group_sizes``, in the same order
    as ``text_line_polygons``. A text line's height is the mean up/down distance of
    its run, plus one.

    Args:
        distorted_image: Distorted page. Used as the shape reference for the mask
            and score map, and as the canvas of the debug image.
        text_line_polygons: Distorted text line polygons.
        text_line_height_points_up: Distorted "up" height points of all text lines.
        text_line_height_points_down: Distorted "down" height points, paired
            index-by-index with ``text_line_height_points_up``.
        text_line_height_points_group_sizes: Number of height point pairs per text line.

    Returns:
        ``(text_line_mask, text_line_height_score_map, text_line_heights,
        text_line_heights_debug_image)``.
        - The mask is ``None`` unless ``enable_distorted_text_line_mask`` is set.
        - The other three are ``None`` unless
          ``enable_distorted_text_line_height_score_map`` is set.
        - The debug image additionally requires
          ``enable_debug_distorted_text_line_heights``.
    """
    config = self.config

    text_line_mask: Optional[Mask] = None
    if config.enable_distorted_text_line_mask:
        text_line_mask = Mask.from_shapable(distorted_image)
        for text_line_polygon in text_line_polygons:
            text_line_polygon.fill_mask(text_line_mask)

    text_line_height_score_map: Optional[ScoreMap] = None
    text_line_heights: Optional[List[float]] = None
    text_line_heights_debug_image: Optional[Image] = None

    if config.enable_distorted_text_line_height_score_map:
        np_points_up = text_line_height_points_up.to_smooth_np_array()
        np_points_down = text_line_height_points_down.to_smooth_np_array()
        # One height per (up, down) pair.
        # The "+ 1" compensation is kept as-is; presumably it accounts for
        # inclusive pixel extent. TODO confirm.
        np_heights: np.ndarray = np.linalg.norm(np_points_down - np_points_up, axis=1) + 1
        assert sum(text_line_height_points_group_sizes) == np_heights.shape[0]

        text_line_heights = []
        text_line_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)
        offset = 0
        for text_line_polygon, group_size in zip(
            text_line_polygons,
            text_line_height_points_group_sizes,
        ):
            group_heights = np_heights[offset:offset + group_size]
            offset += group_size

            text_line_height = float(group_heights.mean())
            text_line_heights.append(text_line_height)
            text_line_polygon.fill_score_map(
                score_map=text_line_height_score_map,
                value=text_line_height,
            )

        if config.enable_debug_distorted_text_line_heights:
            painter = Painter.create(distorted_image)
            painter.paint_polygons(text_line_polygons)

            # Annotate every text line with its height at the polygon center.
            labeled_pairs = list(zip(text_line_polygons, text_line_heights))
            label_texts: List[str] = [f'{height:.1f}' for _, height in labeled_pairs]
            label_points = PointList(
                [text_line_polygon.get_center_point() for text_line_polygon, _ in labeled_pairs]
            )
            painter.paint_texts(label_texts, label_points, alpha=1.0)

            text_line_heights_debug_image = painter.image

    return (
        text_line_mask,
        text_line_height_score_map,
        text_line_heights,
        text_line_heights_debug_image,
    )
def generate_char_labelings(
    self,
    distorted_image: Image,
    char_polygons: Sequence[Polygon],
    char_height_points_up: PointList,
    char_height_points_down: PointList,
):
    """Build char-level labels on the distorted page.

    ``char_height_points_up[i]`` and ``char_height_points_down[i]`` are the height
    points of ``char_polygons[i]``. A char's height is their distance plus one.

    Args:
        distorted_image: Distorted page. Used as the shape reference for the mask
            and score map, and as the canvas of the debug image.
        char_polygons: Distorted char polygons.
        char_height_points_up: Distorted "up" height points, one per char polygon.
        char_height_points_down: Distorted "down" height points, one per char polygon.

    Returns:
        ``(char_mask, char_height_score_map, char_heights, char_heights_debug_image)``.
        - The mask is ``None`` unless ``enable_distorted_char_mask`` is set.
        - The other three are ``None`` unless
          ``enable_distorted_char_height_score_map`` is set.
        - The debug image additionally requires
          ``enable_debug_distorted_char_heights``.
    """
    char_mask: Optional[Mask] = None
    # Gated by the char-level flag, independently of the text-line mask flag.
    if self.config.enable_distorted_char_mask:
        char_mask = Mask.from_shapable(distorted_image)
        for polygon in char_polygons:
            polygon.fill_mask(char_mask)

    char_height_score_map: Optional[ScoreMap] = None
    char_heights: Optional[List[float]] = None
    char_heights_debug_image: Optional[Image] = None

    if self.config.enable_distorted_char_height_score_map:
        np_height_points_up = char_height_points_up.to_smooth_np_array()
        np_height_points_down = char_height_points_down.to_smooth_np_array()
        np_heights: np.ndarray = np.linalg.norm(
            np_height_points_down - np_height_points_up,
            axis=1,
        )
        # Add one to compensate.
        np_heights += 1

        # Fill from large height to small height, so that the smaller height
        # wins (is written last) where two char boxes overlap.
        sorted_char_polygon_indices: Tuple[int, ...] = tuple(reversed(np_heights.argsort()))

        char_heights = [0.0] * len(char_polygons)
        char_height_score_map = ScoreMap.from_shapable(distorted_image, is_prob=False)

        for idx in sorted_char_polygon_indices:
            polygon = char_polygons[idx]
            char_height = float(np_heights[idx])
            char_heights[idx] = char_height
            polygon.fill_score_map(
                score_map=char_height_score_map,
                value=char_height,
            )

        if self.config.enable_debug_distorted_char_heights:
            painter = Painter.create(distorted_image)
            painter.paint_polygons(char_polygons)

            # Annotate every char with its height at the polygon center.
            texts: List[str] = []
            points = PointList()
            for polygon, height in zip(char_polygons, char_heights):
                texts.append(f'{height:.1f}')
                points.append(polygon.get_center_point())
            painter.paint_texts(texts, points, alpha=1.0)

            char_heights_debug_image = painter.image

    return (
        char_mask,
        char_height_score_map,
        char_heights,
        char_heights_debug_image,
    )
def run(self, input: PageDistortionStepInput, rng: RandomGenerator):
    """Randomly distort an assembled page and re-derive its labels.

    All char, text line and region polygons, and all height points, go through the
    same ``self.random_distortion.distort`` call as the page image. This keeps them
    aligned with the distorted pixels.

    After distortion:
    - The inactive area of the distorted page is filled in place with the page
      bottom layer image.
    - The mask and height score map labels are regenerated from the distorted
      geometry.

    Args:
        input: Wraps the output of the page assembler step.
        rng: Random generator driving the distortion.

    Returns:
        A ``PageDistortionStepOutput`` describing the distorted page and its labels.
    """
    page_assembler_step_output = input.page_assembler_step_output
    page = page_assembler_step_output.page
    page_bottom_layer_image = page.page_bottom_layer_image
    page_char_polygon_collection = page.page_char_polygon_collection
    page_text_line_polygon_collection = page.page_text_line_polygon_collection
    page_disconnected_text_region_collection = page.page_disconnected_text_region_collection
    page_non_text_region_collection = page.page_non_text_region_collection

    # Flatten.
    # NOTE: the group order here must match the unpacking order in "Unflatten" below.
    polygon_flattener = ElementFlattener([
        # Char level.
        page_char_polygon_collection.polygons,
        # Text line level.
        page_text_line_polygon_collection.polygons,
        # For char-level polygon regression.
        tuple(page_disconnected_text_region_collection.to_polygons()),
        # For sampling negative text region area.
        tuple(page_non_text_region_collection.to_polygons()),
    ])
    point_flattener = ElementFlattener([
        # Char level.
        page_char_polygon_collection.height_points_up,
        page_char_polygon_collection.height_points_down,
        # Text line level.
        page_text_line_polygon_collection.height_points_up,
        page_text_line_polygon_collection.height_points_down,
    ])

    # Distort.
    page_random_distortion_debug = None
    if self.config.enable_debug_random_distortion:
        page_random_distortion_debug = RandomDistortionDebug()

    page_active_mask = Mask.from_shapable(page.image, value=1)
    # Mitigates a bug in cv.remap, whose border interpolation is wrong.
    # This mitigation DOES remove a 1-pixel wide border, which should be fine.
    with page_active_mask.writable_context:
        page_active_mask.mat[0] = 0
        page_active_mask.mat[-1] = 0
        page_active_mask.mat[:, 0] = 0
        page_active_mask.mat[:, -1] = 0

    result = self.random_distortion.distort(
        image=page.image,
        mask=page_active_mask,
        polygons=polygon_flattener.flatten(),
        points=PointList(point_flattener.flatten()),
        rng=rng,
        debug=page_random_distortion_debug,
    )
    # Compare against None instead of relying on truthiness. The flattened
    # polygon / point sequences are legitimately empty for a page without any
    # text, and an empty sequence is falsy.
    assert result.image is not None
    assert result.mask is not None
    assert result.polygons is not None
    assert result.points is not None

    # Fill inplace the inactive (black) region with page_bottom_layer_image.
    self.fill_page_inactive_region(
        page_image=result.image,
        page_active_mask=result.mask,
        page_bottom_layer_image=page_bottom_layer_image,
    )

    # Unflatten.
    (
        # Char level.
        char_polygons,
        # Text line level.
        text_line_polygons,
        # For char-level polygon regression.
        disconnected_text_region_polygons,
        # For sampling negative text region area.
        non_text_region_polygons,
    ) = polygon_flattener.unflatten(result.polygons)

    (
        # Char level.
        char_height_points_up,
        char_height_points_down,
        # Text line level.
        text_line_height_points_up,
        text_line_height_points_down,
    ) = map(PointList, point_flattener.unflatten(result.points))

    # The distortion maps points 1:1, so the per-text-line grouping of height
    # points is unchanged and can be reused from the undistorted collection.
    text_line_height_points_group_sizes = \
        page_text_line_polygon_collection.height_points_group_sizes
    assert len(text_line_polygons) == len(text_line_height_points_group_sizes)
    assert len(text_line_height_points_up) == len(text_line_height_points_down)

    # Labelings.
    (
        text_line_mask,
        text_line_height_score_map,
        text_line_heights,
        text_line_heights_debug_image,
    ) = self.generate_text_line_labelings(
        distorted_image=result.image,
        text_line_polygons=text_line_polygons,
        text_line_height_points_up=text_line_height_points_up,
        text_line_height_points_down=text_line_height_points_down,
        text_line_height_points_group_sizes=text_line_height_points_group_sizes,
    )
    (
        char_mask,
        char_height_score_map,
        char_heights,
        char_heights_debug_image,
    ) = self.generate_char_labelings(
        distorted_image=result.image,
        char_polygons=char_polygons,
        char_height_points_up=char_height_points_up,
        char_height_points_down=char_height_points_down,
    )

    return PageDistortionStepOutput(
        page_image=result.image,
        page_random_distortion_debug=page_random_distortion_debug,
        page_active_mask=result.mask,
        page_char_polygon_collection=PageCharPolygonCollection(
            height=result.image.height,
            width=result.image.width,
            polygons=char_polygons,
            height_points_up=char_height_points_up,
            height_points_down=char_height_points_down,
        ),
        page_char_mask=char_mask,
        page_char_height_score_map=char_height_score_map,
        page_char_heights=char_heights,
        page_char_heights_debug_image=char_heights_debug_image,
        page_text_line_polygon_collection=PageTextLinePolygonCollection(
            height=result.image.height,
            width=result.image.width,
            polygons=text_line_polygons,
            height_points_group_sizes=text_line_height_points_group_sizes,
            height_points_up=text_line_height_points_up,
            height_points_down=text_line_height_points_down,
        ),
        page_text_line_mask=text_line_mask,
        page_text_line_height_score_map=text_line_height_score_map,
        page_text_line_heights=text_line_heights,
        page_text_line_heights_debug_image=text_line_heights_debug_image,
        page_disconnected_text_region_collection=PageDisconnectedTextRegionCollection(
            disconnected_text_regions=[
                DisconnectedTextRegion(disconnected_text_region_polygon)
                for disconnected_text_region_polygon in disconnected_text_region_polygons
            ],
        ),
        page_non_text_region_collection=PageNonTextRegionCollection(
            non_text_regions=[
                NonTextRegion(non_text_region_polygon)
                for non_text_region_polygon in non_text_region_polygons
            ],
        )
    )