vkit.pipeline.text_detection.page_text_region_cropping
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import Sequence, Tuple, List, Optional
import itertools

import attrs
from numpy.random import Generator as RandomGenerator
from shapely.strtree import STRtree
from shapely.geometry import Point as ShapelyPoint
import cv2 as cv

from vkit.element import Box, Mask, ScoreMap, Image
from vkit.mechanism.distortion import rotate
from vkit.mechanism.cropper import Cropper
from ..interface import PipelineStep, PipelineStepFactory
from .page_cropping import PageCroppingStepOutput
from .page_text_region import PageTextRegionStepOutput
from .page_text_region_label import (
    PageCharRegressionLabelTag,
    PageCharRegressionLabel,
    PageTextRegionLabelStepOutput,
)


@attrs.define
class PageTextRegionCroppingStepConfig:
    """Configuration of ``PageTextRegionCroppingStep``."""

    # Side length of the square core region of each crop.
    core_size: int
    # Padding size around the core region, passed to ``Cropper``.
    pad_size: int
    # Number of samples = round(this factor * number of cropped pages).
    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
    # A sample is rejected when fewer centroid labels fall inside its core box.
    num_centroid_points_min: int = 10
    # A sample is rejected when fewer (preserved) deviate labels fall inside its core box.
    num_deviate_points_min: int = 10
    # Passed to ``Cropper``; presumably the fill value of padded areas — confirm in Cropper.
    pad_value: int = 0
    # Also produce a ``DownsampledLabel`` for every crop.
    enable_downsample_labeling: bool = True
    # Divisor of crop/pad/core sizes used for downsampled labeling; must divide all of them.
    downsample_labeling_factor: int = 2


@attrs.define
class PageTextRegionCroppingStepInput:
    """Outputs of the upstream steps consumed by ``PageTextRegionCroppingStep``."""

    # Only the number of cropped pages is used (to size the number of samples).
    page_cropping_step_output: PageCroppingStepOutput
    # Supplies the page image, its pre-rotation shape and the rotation angle.
    page_text_region_step_output: PageTextRegionStepOutput
    # Supplies the labeling maps and the char regression labels.
    page_text_region_label_step_output: PageTextRegionLabelStepOutput


@attrs.define
class DownsampledLabel:
    """Labels of one cropped region at ``1 / downsample_labeling_factor`` resolution."""

    # Shape of the full downsampled crop: (crop_size // factor, crop_size // factor).
    shape: Tuple[int, int]
    # Core-only maps resized to core_size // factor.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Shifted regression labels, downsampled by the same factor.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    page_char_bounding_box_mask: Mask
    # Core box within the downsampled crop.
    target_core_box: Box


@attrs.define
class CroppedPageTextRegion:
    """One sampled crop of the page text region image together with its labels."""

    # Full crop (not restricted to the core, unlike the labeling maps below).
    page_image: Image
    # Labeling maps cropped to the core region only.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Centroid labels followed by deviate labels, shifted into the crop's target frame.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    page_char_bounding_box_mask: Mask
    # Core box within the crop.
    target_core_box: Box
    # ``None`` unless ``enable_downsample_labeling`` is set.
    downsampled_label: Optional[DownsampledLabel]


@attrs.define
class PageTextRegionCroppingStepOutput:
    """Result of ``PageTextRegionCroppingStep``."""

    # May hold fewer regions than requested when samples were rejected.
    cropped_page_text_regions: Sequence[CroppedPageTextRegion]


class PageTextRegionCroppingStep(
    PipelineStep[
        PageTextRegionCroppingStepConfig,
        PageTextRegionCroppingStepInput,
        PageTextRegionCroppingStepOutput,
    ]
):  # yapf: disable
    """Randomly crop square text regions, with their labels, out of the page text region image.

    A crop is kept only when enough centroid and deviate regression labels fall inside its
    core box; see ``num_centroid_points_min`` and ``num_deviate_points_min`` in the config.
    """

    @classmethod
    def build_strtree_for_page_char_regression_labels(
        cls,
        labels: Sequence[PageCharRegressionLabel],
    ) -> STRtree:
        """Build an STRtree holding one point per label, in the same order as ``labels``.

        Because the order is preserved, indices returned by ``STRtree.query`` index directly
        into ``labels``.
        """
        shapely_points: List[ShapelyPoint] = []
        for label in labels:
            # Original resolution.
            assert not label.is_downsampled
            # As int.
            # NOTE(review): reads ``downsampled_label_point_*`` from a non-downsampled label;
            # confirm these fields hold the original-resolution int coordinates.
            shapely_points.append(
                ShapelyPoint(label.downsampled_label_point_x, label.downsampled_label_point_y)
            )
        return STRtree(shapely_points)

    @staticmethod
    def _query_labels_in_polygon(
        strtree: STRtree,
        labels: Sequence[PageCharRegressionLabel],
        shapely_polygon,
    ) -> List[PageCharRegressionLabel]:
        """Return the ``labels`` whose tree points intersect ``shapely_polygon``, in index order.

        ``strtree`` must have been built from ``labels`` with
        ``build_strtree_for_page_char_regression_labels``.
        """
        indices = strtree.query(shapely_polygon, predicate='intersects')
        return [labels[idx] for idx in sorted(indices)]

    def _create_cropper(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        rng: RandomGenerator,
    ) -> Cropper:
        """Create a randomly placed cropper over ``page_image``.

        For a rotated page, the random proposal is made on ``shape_before_rotate``. Its center
        point is mapped into the rotated image with ``rotate.distort``, and the returned cropper
        is centered on the mapped point.
        """
        if rotate_angle == 0:
            return Cropper.create_from_random_proposal(
                shape=page_image.shape,
                core_size=self.config.core_size,
                pad_size=self.config.pad_size,
                pad_value=self.config.pad_value,
                rng=rng,
            )

        cropper_before_rotate = Cropper.create_from_random_proposal(
            shape=shape_before_rotate,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            rng=rng,
        )
        original_box_before_rotate = cropper_before_rotate.cropper_state.original_box
        center_point_before_rotate = original_box_before_rotate.get_center_point()

        rotated_result = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=center_point_before_rotate,
        )
        assert rotated_result.shape == page_image.shape
        center_point = rotated_result.point
        assert center_point

        return Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            center_point=center_point,
        )

    def _build_downsampled_label(
        self,
        cropper: Cropper,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_bounding_box_mask: Mask,
        shifted_page_char_regression_labels: Sequence[PageCharRegressionLabel],
    ) -> DownsampledLabel:
        """Downsample the core-only cropped maps and the shifted regression labels.

        Maps are resized to ``core_size // downsample_labeling_factor`` with
        ``cv.INTER_AREA``. The core box is re-expressed in downsampled crop coordinates.
        """
        factor = self.config.downsample_labeling_factor

        assert cropper.crop_size % factor == 0
        downsampled_size = cropper.crop_size // factor

        assert self.config.pad_size % factor == 0
        assert self.config.core_size % factor == 0
        assert cropper.target_core_box.height \
            == cropper.target_core_box.width \
            == self.config.core_size

        downsampled_pad_size = self.config.pad_size // factor
        downsampled_core_size = self.config.core_size // factor

        # The core starts right after the padding; end = begin + size - 1.
        downsampled_target_core_begin = downsampled_pad_size
        downsampled_target_core_end = downsampled_target_core_begin + downsampled_core_size - 1
        downsampled_target_core_box = Box(
            up=downsampled_target_core_begin,
            down=downsampled_target_core_end,
            left=downsampled_target_core_begin,
            right=downsampled_target_core_end,
        )

        downsampled_page_char_mask = page_char_mask.to_box_detached().to_resized_mask(
            resized_height=downsampled_core_size,
            resized_width=downsampled_core_size,
            cv_resize_interpolation=cv.INTER_AREA,
        )
        downsampled_page_char_height_score_map = \
            page_char_height_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_gaussian_score_map = \
            page_char_gaussian_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_bounding_box_mask = \
            page_char_bounding_box_mask.to_box_detached().to_resized_mask(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_regression_labels = [
            label.to_downsampled_page_char_regression_label(factor)
            for label in shifted_page_char_regression_labels
        ]

        return DownsampledLabel(
            shape=(downsampled_size, downsampled_size),
            page_char_mask=downsampled_page_char_mask,
            page_char_height_score_map=downsampled_page_char_height_score_map,
            page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
            page_char_regression_labels=downsampled_page_char_regression_labels,
            page_char_bounding_box_mask=downsampled_page_char_bounding_box_mask,
            target_core_box=downsampled_target_core_box,
        )

    def sample_cropped_page_text_regions(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_bounding_box_mask: Mask,
        centroid_strtree: STRtree,
        centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
        deviate_strtree: STRtree,
        deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
        rng: RandomGenerator,
    ) -> Optional[CroppedPageTextRegion]:
        """Sample one cropped region, or return ``None`` if the sample is rejected.

        Only labels whose points intersect the crop's original core box are kept. A deviate
        label is additionally dropped when the centroid label of the same ``char_idx`` was not
        kept. The sample is rejected when fewer than ``num_centroid_points_min`` centroid
        labels or ``num_deviate_points_min`` deviate labels remain.

        Returns:
            The cropped image, the core-only cropped labeling maps, and the labels shifted into
            the crop's target frame. If enabled, also a downsampled copy of the labels.
        """
        cropper = self._create_cropper(page_image, shape_before_rotate, rotate_angle, rng)

        # Remove labels out of the original core box.
        original_core_shapely_polygon = cropper.original_core_box.to_shapely_polygon()

        centroid_labels = self._query_labels_in_polygon(
            centroid_strtree,
            centroid_page_char_regression_labels,
            original_core_shapely_polygon,
        )
        preserved_char_indices = {label.char_idx for label in centroid_labels}
        # If the centroid is not preserved, ignore the deviate label as well.
        deviate_labels = [
            label for label in self._query_labels_in_polygon(
                deviate_strtree,
                deviate_page_char_regression_labels,
                original_core_shapely_polygon,
            ) if label.char_idx in preserved_char_indices
        ]

        if (
            len(centroid_labels) < self.config.num_centroid_points_min
            or len(deviate_labels) < self.config.num_deviate_points_min
        ):
            return None

        # Shift labels from the original frame into the cropper's target frame.
        offset_y = cropper.target_box.up - cropper.original_box.up
        offset_x = cropper.target_box.left - cropper.original_box.left
        shifted_labels = [
            label.to_shifted_page_char_regression_label(offset_y=offset_y, offset_x=offset_x)
            for label in itertools.chain(centroid_labels, deviate_labels)
        ]

        # Crop the image (full crop) and the labeling maps (core only).
        cropped_page_image = cropper.crop_image(page_image)
        cropped_page_char_mask = cropper.crop_mask(page_char_mask, core_only=True)
        cropped_page_char_height_score_map = cropper.crop_score_map(
            page_char_height_score_map,
            core_only=True,
        )
        cropped_page_char_gaussian_score_map = cropper.crop_score_map(
            page_char_gaussian_score_map,
            core_only=True,
        )
        cropped_page_char_bounding_box_mask = cropper.crop_mask(
            page_char_bounding_box_mask,
            core_only=True,
        )

        downsampled_label: Optional[DownsampledLabel] = None
        if self.config.enable_downsample_labeling:
            downsampled_label = self._build_downsampled_label(
                cropper=cropper,
                page_char_mask=cropped_page_char_mask,
                page_char_height_score_map=cropped_page_char_height_score_map,
                page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
                page_char_bounding_box_mask=cropped_page_char_bounding_box_mask,
                shifted_page_char_regression_labels=shifted_labels,
            )

        return CroppedPageTextRegion(
            page_image=cropped_page_image,
            page_char_mask=cropped_page_char_mask,
            page_char_height_score_map=cropped_page_char_height_score_map,
            page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
            page_char_regression_labels=shifted_labels,
            page_char_bounding_box_mask=cropped_page_char_bounding_box_mask,
            target_core_box=cropper.target_core_box,
            downsampled_label=downsampled_label,
        )

    def run(
        self,
        input: PageTextRegionCroppingStepInput,
        rng: RandomGenerator,
    ) -> PageTextRegionCroppingStepOutput:
        """Sample cropped text regions from one page.

        The target count is ``round(num_samples_factor_relative_to_num_cropped_pages *
        num_cropped_pages)``. Rejected samples still count as attempts. Sampling stops after
        ``max(3, 2 * num_samples)`` attempts, so fewer regions may be returned.
        """
        page_cropping_step_output = input.page_cropping_step_output
        num_cropped_pages = len(page_cropping_step_output.cropped_pages)

        page_text_region_step_output = input.page_text_region_step_output
        page_image = page_text_region_step_output.page_image
        shape_before_rotate = page_text_region_step_output.shape_before_rotate
        rotate_angle = page_text_region_step_output.rotate_angle

        page_text_region_label_step_output = input.page_text_region_label_step_output
        page_char_mask = page_text_region_label_step_output.page_char_mask
        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
        page_char_gaussian_score_map = \
            page_text_region_label_step_output.page_char_gaussian_score_map
        page_char_regression_labels = \
            page_text_region_label_step_output.page_char_regression_labels
        page_char_bounding_box_mask = \
            page_text_region_label_step_output.page_char_bounding_box_mask

        # Centroid and deviate labels get separate trees so they can be filtered independently.
        centroid_page_char_regression_labels = [
            label for label in page_char_regression_labels
            if label.tag == PageCharRegressionLabelTag.CENTROID
        ]
        centroid_strtree = self.build_strtree_for_page_char_regression_labels(
            centroid_page_char_regression_labels
        )

        deviate_page_char_regression_labels = [
            label for label in page_char_regression_labels
            if label.tag == PageCharRegressionLabelTag.DEVIATE
        ]
        deviate_strtree = self.build_strtree_for_page_char_regression_labels(
            deviate_page_char_regression_labels
        )

        num_samples = round(
            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
        )

        run_count_max = max(3, 2 * num_samples)
        run_count = 0

        cropped_page_text_regions: List[CroppedPageTextRegion] = []

        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
            cropped_page_text_region = self.sample_cropped_page_text_regions(
                page_image=page_image,
                shape_before_rotate=shape_before_rotate,
                rotate_angle=rotate_angle,
                page_char_mask=page_char_mask,
                page_char_height_score_map=page_char_height_score_map,
                page_char_gaussian_score_map=page_char_gaussian_score_map,
                page_char_bounding_box_mask=page_char_bounding_box_mask,
                centroid_strtree=centroid_strtree,
                centroid_page_char_regression_labels=centroid_page_char_regression_labels,
                deviate_strtree=deviate_strtree,
                deviate_page_char_regression_labels=deviate_page_char_regression_labels,
                rng=rng,
            )
            # ``None`` means the sample was rejected.
            if cropped_page_text_region is not None:
                cropped_page_text_regions.append(cropped_page_text_region)
            run_count += 1

        return PageTextRegionCroppingStepOutput(
            cropped_page_text_regions=cropped_page_text_regions,
        )


page_text_region_cropping_step_factory = PipelineStepFactory(PageTextRegionCroppingStep)
class PageTextRegionCroppingStepConfig:
    """Configuration of ``PageTextRegionCroppingStep``.

    This listing starts at the class line; in the module the class is decorated with
    ``@attrs.define``.
    """

    # Side length of the square core region of each crop.
    core_size: int
    # Padding size around the core region, passed to ``Cropper``.
    pad_size: int
    # Number of samples = round(this factor * number of cropped pages).
    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
    # A sample is rejected when fewer centroid labels fall inside its core box.
    num_centroid_points_min: int = 10
    # A sample is rejected when fewer (preserved) deviate labels fall inside its core box.
    num_deviate_points_min: int = 10
    # Passed to ``Cropper``; presumably the fill value of padded areas — confirm in Cropper.
    pad_value: int = 0
    # Also produce a downsampled label for every crop.
    enable_downsample_labeling: bool = True
    # Divisor of crop/pad/core sizes used for downsampled labeling.
    downsample_labeling_factor: int = 2
def __init__(
    self,
    core_size,
    pad_size,
    num_samples_factor_relative_to_num_cropped_pages=attr_dict[
        'num_samples_factor_relative_to_num_cropped_pages'].default,
    num_centroid_points_min=attr_dict['num_centroid_points_min'].default,
    num_deviate_points_min=attr_dict['num_deviate_points_min'].default,
    pad_value=attr_dict['pad_value'].default,
    enable_downsample_labeling=attr_dict['enable_downsample_labeling'].default,
    downsample_labeling_factor=attr_dict['downsample_labeling_factor'].default,
):
    """Store every config field on ``self`` (initializer generated by attrs).

    NOTE(review): ``attr_dict`` is not defined in this listing; presumably it is supplied by
    attrs when it generates the method and maps field names to their attributes.
    """
    self.core_size = core_size
    self.pad_size = pad_size
    self.num_samples_factor_relative_to_num_cropped_pages = \
        num_samples_factor_relative_to_num_cropped_pages
    self.num_centroid_points_min = num_centroid_points_min
    self.num_deviate_points_min = num_deviate_points_min
    self.pad_value = pad_value
    self.enable_downsample_labeling = enable_downsample_labeling
    self.downsample_labeling_factor = downsample_labeling_factor
Method generated by attrs for class PageTextRegionCroppingStepConfig.
class PageTextRegionCroppingStepInput:
    """Outputs of the upstream steps consumed by ``PageTextRegionCroppingStep``.

    This listing starts at the class line; in the module the class is decorated with
    ``@attrs.define``.
    """

    # Only the number of cropped pages is used (to size the number of samples).
    page_cropping_step_output: PageCroppingStepOutput
    # Supplies the page image, its pre-rotation shape and the rotation angle.
    page_text_region_step_output: PageTextRegionStepOutput
    # Supplies the labeling maps and the char regression labels.
    page_text_region_label_step_output: PageTextRegionLabelStepOutput
def __init__(
    self,
    page_cropping_step_output,
    page_text_region_step_output,
    page_text_region_label_step_output,
):
    """Store the three upstream step outputs on ``self`` (initializer generated by attrs)."""
    self.page_cropping_step_output = page_cropping_step_output
    self.page_text_region_step_output = page_text_region_step_output
    self.page_text_region_label_step_output = page_text_region_label_step_output
Method generated by attrs for class PageTextRegionCroppingStepInput.
class DownsampledLabel:
    """Labels of one cropped region at ``1 / downsample_labeling_factor`` resolution.

    This listing starts at the class line; in the module the class is decorated with
    ``@attrs.define``.
    """

    # Shape of the full downsampled crop: (crop_size // factor, crop_size // factor).
    shape: Tuple[int, int]
    # Core-only maps resized to core_size // factor.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Shifted regression labels, downsampled by the same factor.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    page_char_bounding_box_mask: Mask
    # Core box within the downsampled crop.
    target_core_box: Box
def __init__(
    self,
    shape,
    page_char_mask,
    page_char_height_score_map,
    page_char_gaussian_score_map,
    page_char_regression_labels,
    page_char_bounding_box_mask,
    target_core_box,
):
    """Store every ``DownsampledLabel`` field on ``self`` (initializer generated by attrs)."""
    self.shape = shape
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_gaussian_score_map = page_char_gaussian_score_map
    self.page_char_regression_labels = page_char_regression_labels
    self.page_char_bounding_box_mask = page_char_bounding_box_mask
    self.target_core_box = target_core_box
Method generated by attrs for class DownsampledLabel.
class CroppedPageTextRegion:
    """One sampled crop of the page text region image together with its labels.

    This listing starts at the class line; in the module the class is decorated with
    ``@attrs.define``.
    """

    # Full crop (not restricted to the core, unlike the labeling maps below).
    page_image: Image
    # Labeling maps cropped to the core region only.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Centroid labels followed by deviate labels, shifted into the crop's target frame.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    page_char_bounding_box_mask: Mask
    # Core box within the crop.
    target_core_box: Box
    # ``None`` unless downsampled labeling is enabled.
    downsampled_label: Optional[DownsampledLabel]
def __init__(
    self,
    page_image,
    page_char_mask,
    page_char_height_score_map,
    page_char_gaussian_score_map,
    page_char_regression_labels,
    page_char_bounding_box_mask,
    target_core_box,
    downsampled_label,
):
    """Store every ``CroppedPageTextRegion`` field on ``self`` (initializer generated by attrs)."""
    self.page_image = page_image
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_gaussian_score_map = page_char_gaussian_score_map
    self.page_char_regression_labels = page_char_regression_labels
    self.page_char_bounding_box_mask = page_char_bounding_box_mask
    self.target_core_box = target_core_box
    self.downsampled_label = downsampled_label
Method generated by attrs for class CroppedPageTextRegion.
class PageTextRegionCroppingStepOutput:
    """Result of ``PageTextRegionCroppingStep``.

    This listing starts at the class line; in the module the class is decorated with
    ``@attrs.define``.
    """

    # May hold fewer regions than requested when samples were rejected.
    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
def __init__(self, cropped_page_text_regions):
    """Store the sampled regions on ``self`` (initializer generated by attrs)."""
    self.cropped_page_text_regions = cropped_page_text_regions
Method generated by attrs for class PageTextRegionCroppingStepOutput.
class PageTextRegionCroppingStep(
    PipelineStep[
        PageTextRegionCroppingStepConfig,
        PageTextRegionCroppingStepInput,
        PageTextRegionCroppingStepOutput,
    ]
):  # yapf: disable
    """Randomly crop square text regions, with their labels, out of the page text region image.

    A crop is kept only when enough centroid and deviate regression labels fall inside its
    core box; see ``num_centroid_points_min`` and ``num_deviate_points_min`` in the config.
    """

    @classmethod
    def build_strtree_for_page_char_regression_labels(
        cls,
        labels: Sequence[PageCharRegressionLabel],
    ) -> STRtree:
        """Build an STRtree holding one point per label, in the same order as ``labels``.

        Because the order is preserved, indices returned by ``STRtree.query`` index directly
        into ``labels``.
        """
        shapely_points: List[ShapelyPoint] = []
        for label in labels:
            # Original resolution.
            assert not label.is_downsampled
            # As int.
            # NOTE(review): reads ``downsampled_label_point_*`` from a non-downsampled label;
            # confirm these fields hold the original-resolution int coordinates.
            shapely_points.append(
                ShapelyPoint(label.downsampled_label_point_x, label.downsampled_label_point_y)
            )
        return STRtree(shapely_points)

    @staticmethod
    def _query_labels_in_polygon(
        strtree: STRtree,
        labels: Sequence[PageCharRegressionLabel],
        shapely_polygon,
    ) -> List[PageCharRegressionLabel]:
        """Return the ``labels`` whose tree points intersect ``shapely_polygon``, in index order.

        ``strtree`` must have been built from ``labels`` with
        ``build_strtree_for_page_char_regression_labels``.
        """
        indices = strtree.query(shapely_polygon, predicate='intersects')
        return [labels[idx] for idx in sorted(indices)]

    def _create_cropper(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        rng: RandomGenerator,
    ) -> Cropper:
        """Create a randomly placed cropper over ``page_image``.

        For a rotated page, the random proposal is made on ``shape_before_rotate``. Its center
        point is mapped into the rotated image with ``rotate.distort``, and the returned cropper
        is centered on the mapped point.
        """
        if rotate_angle == 0:
            return Cropper.create_from_random_proposal(
                shape=page_image.shape,
                core_size=self.config.core_size,
                pad_size=self.config.pad_size,
                pad_value=self.config.pad_value,
                rng=rng,
            )

        cropper_before_rotate = Cropper.create_from_random_proposal(
            shape=shape_before_rotate,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            rng=rng,
        )
        original_box_before_rotate = cropper_before_rotate.cropper_state.original_box
        center_point_before_rotate = original_box_before_rotate.get_center_point()

        rotated_result = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=center_point_before_rotate,
        )
        assert rotated_result.shape == page_image.shape
        center_point = rotated_result.point
        assert center_point

        return Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            center_point=center_point,
        )

    def _build_downsampled_label(
        self,
        cropper: Cropper,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_bounding_box_mask: Mask,
        shifted_page_char_regression_labels: Sequence[PageCharRegressionLabel],
    ) -> DownsampledLabel:
        """Downsample the core-only cropped maps and the shifted regression labels.

        Maps are resized to ``core_size // downsample_labeling_factor`` with
        ``cv.INTER_AREA``. The core box is re-expressed in downsampled crop coordinates.
        """
        factor = self.config.downsample_labeling_factor

        assert cropper.crop_size % factor == 0
        downsampled_size = cropper.crop_size // factor

        assert self.config.pad_size % factor == 0
        assert self.config.core_size % factor == 0
        assert cropper.target_core_box.height \
            == cropper.target_core_box.width \
            == self.config.core_size

        downsampled_pad_size = self.config.pad_size // factor
        downsampled_core_size = self.config.core_size // factor

        # The core starts right after the padding; end = begin + size - 1.
        downsampled_target_core_begin = downsampled_pad_size
        downsampled_target_core_end = downsampled_target_core_begin + downsampled_core_size - 1
        downsampled_target_core_box = Box(
            up=downsampled_target_core_begin,
            down=downsampled_target_core_end,
            left=downsampled_target_core_begin,
            right=downsampled_target_core_end,
        )

        downsampled_page_char_mask = page_char_mask.to_box_detached().to_resized_mask(
            resized_height=downsampled_core_size,
            resized_width=downsampled_core_size,
            cv_resize_interpolation=cv.INTER_AREA,
        )
        downsampled_page_char_height_score_map = \
            page_char_height_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_gaussian_score_map = \
            page_char_gaussian_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_bounding_box_mask = \
            page_char_bounding_box_mask.to_box_detached().to_resized_mask(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_regression_labels = [
            label.to_downsampled_page_char_regression_label(factor)
            for label in shifted_page_char_regression_labels
        ]

        return DownsampledLabel(
            shape=(downsampled_size, downsampled_size),
            page_char_mask=downsampled_page_char_mask,
            page_char_height_score_map=downsampled_page_char_height_score_map,
            page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
            page_char_regression_labels=downsampled_page_char_regression_labels,
            page_char_bounding_box_mask=downsampled_page_char_bounding_box_mask,
            target_core_box=downsampled_target_core_box,
        )

    def sample_cropped_page_text_regions(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_bounding_box_mask: Mask,
        centroid_strtree: STRtree,
        centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
        deviate_strtree: STRtree,
        deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
        rng: RandomGenerator,
    ) -> Optional[CroppedPageTextRegion]:
        """Sample one cropped region, or return ``None`` if the sample is rejected.

        Only labels whose points intersect the crop's original core box are kept. A deviate
        label is additionally dropped when the centroid label of the same ``char_idx`` was not
        kept. The sample is rejected when fewer than ``num_centroid_points_min`` centroid
        labels or ``num_deviate_points_min`` deviate labels remain.

        Returns:
            The cropped image, the core-only cropped labeling maps, and the labels shifted into
            the crop's target frame. If enabled, also a downsampled copy of the labels.
        """
        cropper = self._create_cropper(page_image, shape_before_rotate, rotate_angle, rng)

        # Remove labels out of the original core box.
        original_core_shapely_polygon = cropper.original_core_box.to_shapely_polygon()

        centroid_labels = self._query_labels_in_polygon(
            centroid_strtree,
            centroid_page_char_regression_labels,
            original_core_shapely_polygon,
        )
        preserved_char_indices = {label.char_idx for label in centroid_labels}
        # If the centroid is not preserved, ignore the deviate label as well.
        deviate_labels = [
            label for label in self._query_labels_in_polygon(
                deviate_strtree,
                deviate_page_char_regression_labels,
                original_core_shapely_polygon,
            ) if label.char_idx in preserved_char_indices
        ]

        if (
            len(centroid_labels) < self.config.num_centroid_points_min
            or len(deviate_labels) < self.config.num_deviate_points_min
        ):
            return None

        # Shift labels from the original frame into the cropper's target frame.
        offset_y = cropper.target_box.up - cropper.original_box.up
        offset_x = cropper.target_box.left - cropper.original_box.left
        shifted_labels = [
            label.to_shifted_page_char_regression_label(offset_y=offset_y, offset_x=offset_x)
            for label in itertools.chain(centroid_labels, deviate_labels)
        ]

        # Crop the image (full crop) and the labeling maps (core only).
        cropped_page_image = cropper.crop_image(page_image)
        cropped_page_char_mask = cropper.crop_mask(page_char_mask, core_only=True)
        cropped_page_char_height_score_map = cropper.crop_score_map(
            page_char_height_score_map,
            core_only=True,
        )
        cropped_page_char_gaussian_score_map = cropper.crop_score_map(
            page_char_gaussian_score_map,
            core_only=True,
        )
        cropped_page_char_bounding_box_mask = cropper.crop_mask(
            page_char_bounding_box_mask,
            core_only=True,
        )

        downsampled_label: Optional[DownsampledLabel] = None
        if self.config.enable_downsample_labeling:
            downsampled_label = self._build_downsampled_label(
                cropper=cropper,
                page_char_mask=cropped_page_char_mask,
                page_char_height_score_map=cropped_page_char_height_score_map,
                page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
                page_char_bounding_box_mask=cropped_page_char_bounding_box_mask,
                shifted_page_char_regression_labels=shifted_labels,
            )

        return CroppedPageTextRegion(
            page_image=cropped_page_image,
            page_char_mask=cropped_page_char_mask,
            page_char_height_score_map=cropped_page_char_height_score_map,
            page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
            page_char_regression_labels=shifted_labels,
            page_char_bounding_box_mask=cropped_page_char_bounding_box_mask,
            target_core_box=cropper.target_core_box,
            downsampled_label=downsampled_label,
        )

    def run(
        self,
        input: PageTextRegionCroppingStepInput,
        rng: RandomGenerator,
    ) -> PageTextRegionCroppingStepOutput:
        """Sample cropped text regions from one page.

        The target count is ``round(num_samples_factor_relative_to_num_cropped_pages *
        num_cropped_pages)``. Rejected samples still count as attempts. Sampling stops after
        ``max(3, 2 * num_samples)`` attempts, so fewer regions may be returned.
        """
        page_cropping_step_output = input.page_cropping_step_output
        num_cropped_pages = len(page_cropping_step_output.cropped_pages)

        page_text_region_step_output = input.page_text_region_step_output
        page_image = page_text_region_step_output.page_image
        shape_before_rotate = page_text_region_step_output.shape_before_rotate
        rotate_angle = page_text_region_step_output.rotate_angle

        page_text_region_label_step_output = input.page_text_region_label_step_output
        page_char_mask = page_text_region_label_step_output.page_char_mask
        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
        page_char_gaussian_score_map = \
            page_text_region_label_step_output.page_char_gaussian_score_map
        page_char_regression_labels = \
            page_text_region_label_step_output.page_char_regression_labels
        page_char_bounding_box_mask = \
            page_text_region_label_step_output.page_char_bounding_box_mask

        # Centroid and deviate labels get separate trees so they can be filtered independently.
        centroid_page_char_regression_labels = [
            label for label in page_char_regression_labels
            if label.tag == PageCharRegressionLabelTag.CENTROID
        ]
        centroid_strtree = self.build_strtree_for_page_char_regression_labels(
            centroid_page_char_regression_labels
        )

        deviate_page_char_regression_labels = [
            label for label in page_char_regression_labels
            if label.tag == PageCharRegressionLabelTag.DEVIATE
        ]
        deviate_strtree = self.build_strtree_for_page_char_regression_labels(
            deviate_page_char_regression_labels
        )

        num_samples = round(
            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
        )

        run_count_max = max(3, 2 * num_samples)
        run_count = 0

        cropped_page_text_regions: List[CroppedPageTextRegion] = []

        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
            cropped_page_text_region = self.sample_cropped_page_text_regions(
                page_image=page_image,
                shape_before_rotate=shape_before_rotate,
                rotate_angle=rotate_angle,
                page_char_mask=page_char_mask,
                page_char_height_score_map=page_char_height_score_map,
                page_char_gaussian_score_map=page_char_gaussian_score_map,
                page_char_bounding_box_mask=page_char_bounding_box_mask,
                centroid_strtree=centroid_strtree,
                centroid_page_char_regression_labels=centroid_page_char_regression_labels,
                deviate_strtree=deviate_strtree,
                deviate_page_char_regression_labels=deviate_page_char_regression_labels,
                rng=rng,
            )
            # ``None`` means the sample was rejected.
            if cropped_page_text_region is not None:
                cropped_page_text_regions.append(cropped_page_text_region)
            run_count += 1

        return PageTextRegionCroppingStepOutput(
            cropped_page_text_regions=cropped_page_text_regions,
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
@classmethod
def build_strtree_for_page_char_regression_labels(
    cls,
    labels: Sequence[PageCharRegressionLabel],
):
    """Build a shapely STRtree over the label points of ``labels``.

    Every label must still be at the original (non-downsampled) resolution.
    The geometry inserted for ``labels[i]`` is a point built from that
    label's integer ``downsampled_label_point_x`` / ``downsampled_label_point_y``
    fields, at position ``i``. The tree's query results are later used as
    indices back into the same label sequence, so the order matters.
    """

    def _label_to_point(label: PageCharRegressionLabel) -> ShapelyPoint:
        # Only labels at the original resolution may be indexed.
        assert not label.is_downsampled
        # The integer-valued variant of the label point is used as the coordinate.
        return ShapelyPoint(
            label.downsampled_label_point_x,
            label.downsampled_label_point_y,
        )

    points = [_label_to_point(label) for label in labels]
    return STRtree(points)
def sample_cropped_page_text_regions(
    self,
    page_image: Image,
    shape_before_rotate: Tuple[int, int],
    rotate_angle: int,
    page_char_mask: Mask,
    page_char_height_score_map: ScoreMap,
    page_char_gaussian_score_map: ScoreMap,
    page_char_bounding_box_mask: Mask,
    centroid_strtree: STRtree,
    centroid_page_char_regression_labels: Sequence[PageCharRegressionLabel],
    deviate_strtree: STRtree,
    deviate_page_char_regression_labels: Sequence[PageCharRegressionLabel],
    rng: RandomGenerator,
):
    """Sample one random crop of the page together with its labels.

    Steps:
    1. Choose a crop window. With ``rotate_angle == 0`` it is a random proposal
       on ``page_image.shape``. Otherwise the proposal is made on
       ``shape_before_rotate``, its center is mapped through ``rotate.distort``,
       and the window is re-centered on the rotated page.
    2. Keep the centroid labels whose point intersects the window's original
       core box. Keep the deviate labels that intersect it AND whose ``char_idx``
       has a kept centroid.
    3. Reject the sample (return ``None``) if fewer than
       ``config.num_centroid_points_min`` centroids or
       ``config.num_deviate_points_min`` deviate labels were kept.
    4. Shift the kept labels into crop coordinates. Crop the image with padding,
       and crop the label maps to the core only.
    5. If ``config.enable_downsample_labeling`` is set, also build a
       ``DownsampledLabel`` scaled by ``config.downsample_labeling_factor``.

    Returns a ``CroppedPageTextRegion``, or ``None`` when the sample is rejected.
    """
    config = self.config

    # --- 1. Crop window. ---
    if rotate_angle == 0:
        cropper = Cropper.create_from_random_proposal(
            shape=page_image.shape,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            rng=rng,
        )
    else:
        # Propose on the un-rotated page, then carry the proposal's center
        # through the same rotation so the window lands on the rotated page.
        proposal = Cropper.create_from_random_proposal(
            shape=shape_before_rotate,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            rng=rng,
        )
        proposal_center = proposal.cropper_state.original_box.get_center_point()

        rotation = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=proposal_center,
        )
        assert rotation.shape == page_image.shape
        rotated_center = rotation.point
        assert rotated_center

        cropper = Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            center_point=rotated_center,
        )

    # --- 2. Labels inside the original core box. ---
    core_polygon = cropper.original_core_box.to_shapely_polygon()

    def _labels_hitting_core(strtree, labels):
        # Query results are positions into `labels`; sort them to keep label order.
        hit_indices = strtree.query(core_polygon, predicate='intersects')
        return [labels[hit_idx] for hit_idx in sorted(hit_indices)]

    centroid_labels = _labels_hitting_core(
        centroid_strtree,
        centroid_page_char_regression_labels,
    )
    kept_char_indices = {label.char_idx for label in centroid_labels}
    # A deviate label is dropped when its char's centroid was not kept.
    deviate_labels = [
        label for label in _labels_hitting_core(
            deviate_strtree,
            deviate_page_char_regression_labels,
        ) if label.char_idx in kept_char_indices
    ]

    # --- 3. Reject sparse samples. ---
    too_few_centroids = len(centroid_labels) < config.num_centroid_points_min
    too_few_deviates = len(deviate_labels) < config.num_deviate_points_min
    if too_few_centroids or too_few_deviates:
        return None

    # --- 4. Shift labels and crop image / label maps. ---
    shift_y = cropper.target_box.up - cropper.original_box.up
    shift_x = cropper.target_box.left - cropper.original_box.left
    shifted_centroid_labels = [
        label.to_shifted_page_char_regression_label(offset_y=shift_y, offset_x=shift_x)
        for label in centroid_labels
    ]
    shifted_deviate_labels = [
        label.to_shifted_page_char_regression_label(offset_y=shift_y, offset_x=shift_x)
        for label in deviate_labels
    ]

    cropped_image = cropper.crop_image(page_image)
    cropped_char_mask = cropper.crop_mask(page_char_mask, core_only=True)
    cropped_height_map = cropper.crop_score_map(page_char_height_score_map, core_only=True)
    cropped_gaussian_map = cropper.crop_score_map(page_char_gaussian_score_map, core_only=True)
    cropped_bbox_mask = cropper.crop_mask(page_char_bounding_box_mask, core_only=True)

    # --- 5. Optional downsampled labels. ---
    downsampled_label: Optional[DownsampledLabel] = None
    if config.enable_downsample_labeling:
        factor = config.downsample_labeling_factor

        assert cropper.crop_size % factor == 0
        small_crop_size = cropper.crop_size // factor

        assert config.pad_size % factor == 0
        assert config.core_size % factor == 0
        assert cropper.target_core_box.height \
            == cropper.target_core_box.width \
            == config.core_size

        small_pad_size = config.pad_size // factor
        small_core_size = config.core_size // factor

        # Box bounds are inclusive, hence the trailing "- 1".
        small_core_first = small_pad_size
        small_core_last = small_core_first + small_core_size - 1
        small_core_box = Box(
            up=small_core_first,
            down=small_core_last,
            left=small_core_first,
            right=small_core_last,
        )

        def _shrink_mask(mask):
            return mask.to_box_detached().to_resized_mask(
                resized_height=small_core_size,
                resized_width=small_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        def _shrink_score_map(score_map):
            return score_map.to_box_detached().to_resized_score_map(
                resized_height=small_core_size,
                resized_width=small_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        small_char_mask = _shrink_mask(cropped_char_mask)
        small_height_map = _shrink_score_map(cropped_height_map)
        small_gaussian_map = _shrink_score_map(cropped_gaussian_map)
        small_bbox_mask = _shrink_mask(cropped_bbox_mask)
        small_regression_labels = [
            label.to_downsampled_page_char_regression_label(factor)
            for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
        ]

        downsampled_label = DownsampledLabel(
            shape=(small_crop_size, small_crop_size),
            page_char_mask=small_char_mask,
            page_char_height_score_map=small_height_map,
            page_char_gaussian_score_map=small_gaussian_map,
            page_char_regression_labels=small_regression_labels,
            page_char_bounding_box_mask=small_bbox_mask,
            target_core_box=small_core_box,
        )

    return CroppedPageTextRegion(
        page_image=cropped_image,
        page_char_mask=cropped_char_mask,
        page_char_height_score_map=cropped_height_map,
        page_char_gaussian_score_map=cropped_gaussian_map,
        page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
        page_char_bounding_box_mask=cropped_bbox_mask,
        target_core_box=cropper.target_core_box,
        downsampled_label=downsampled_label,
    )
def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
    """Sample cropped text regions from one page.

    The target number of samples is
    ``round(config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages)``,
    where ``num_cropped_pages`` is the number of pages produced by the page
    cropping step. Each attempt calls ``sample_cropped_page_text_regions``, and
    attempts that return ``None`` (rejected samples) are skipped. At most
    ``max(3, 2 * num_samples)`` attempts are made, so the output may contain
    fewer regions than the target.

    Returns a ``PageTextRegionCroppingStepOutput`` holding the accepted regions.
    """
    page_cropping_step_output = input.page_cropping_step_output
    num_cropped_pages = len(page_cropping_step_output.cropped_pages)

    page_text_region_step_output = input.page_text_region_step_output
    page_image = page_text_region_step_output.page_image
    shape_before_rotate = page_text_region_step_output.shape_before_rotate
    rotate_angle = page_text_region_step_output.rotate_angle

    page_text_region_label_step_output = input.page_text_region_label_step_output
    page_char_mask = page_text_region_label_step_output.page_char_mask
    page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
    page_char_gaussian_score_map = \
        page_text_region_label_step_output.page_char_gaussian_score_map
    page_char_regression_labels = \
        page_text_region_label_step_output.page_char_regression_labels
    page_char_bounding_box_mask = \
        page_text_region_label_step_output.page_char_bounding_box_mask

    # Centroid and deviate labels are indexed separately, so each crop can
    # query them independently. The index order in each tree matches its list.
    centroid_page_char_regression_labels = [
        label for label in page_char_regression_labels
        if label.tag == PageCharRegressionLabelTag.CENTROID
    ]
    centroid_strtree = self.build_strtree_for_page_char_regression_labels(
        centroid_page_char_regression_labels
    )

    deviate_page_char_regression_labels = [
        label for label in page_char_regression_labels
        if label.tag == PageCharRegressionLabelTag.DEVIATE
    ]
    deviate_strtree = self.build_strtree_for_page_char_regression_labels(
        deviate_page_char_regression_labels
    )

    num_samples = round(
        self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
    )

    # Bound the number of attempts, because sampling can be rejected repeatedly.
    run_count_max = max(3, 2 * num_samples)
    run_count = 0

    cropped_page_text_regions: List[CroppedPageTextRegion] = []

    while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
        cropped_page_text_region = self.sample_cropped_page_text_regions(
            page_image=page_image,
            shape_before_rotate=shape_before_rotate,
            rotate_angle=rotate_angle,
            page_char_mask=page_char_mask,
            page_char_height_score_map=page_char_height_score_map,
            page_char_gaussian_score_map=page_char_gaussian_score_map,
            page_char_bounding_box_mask=page_char_bounding_box_mask,
            centroid_strtree=centroid_strtree,
            centroid_page_char_regression_labels=centroid_page_char_regression_labels,
            deviate_strtree=deviate_strtree,
            deviate_page_char_regression_labels=deviate_page_char_regression_labels,
            rng=rng,
        )
        # `None` means the sample was rejected. Test for it explicitly rather
        # than relying on the truthiness of the result object.
        if cropped_page_text_region is not None:
            cropped_page_text_regions.append(cropped_page_text_region)
        run_count += 1

    return PageTextRegionCroppingStepOutput(
        cropped_page_text_regions=cropped_page_text_regions,
    )