vkit.pipeline.text_detection.page_text_region_cropping
1# Copyright 2022 vkit-x Administrator. All Rights Reserved. 2# 3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses. 4# 5# The commercial license gives you the full rights to create and distribute software 6# on your own terms without any SSPL license obligations. For more information, 7# please see the "LICENSE_COMMERCIAL.txt" file. 8# 9# This project is also available under Server Side Public License (SSPL). 10# The SSPL licensing is ideal for use cases such as open source projects with 11# SSPL distribution, student/academic purposes, hobby projects, internal research 12# projects without external distribution, or other projects where all SSPL 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file. 14from typing import Sequence, Mapping, Tuple, DefaultDict, List, Optional 15from collections import defaultdict 16import itertools 17import warnings 18 19import attrs 20from numpy.random import Generator as RandomGenerator 21from shapely.errors import ShapelyDeprecationWarning 22from shapely.strtree import STRtree 23from shapely.geometry import Point as ShapelyPoint 24import cv2 as cv 25 26from vkit.element import Box, Mask, ScoreMap, Image 27from vkit.mechanism.distortion import rotate 28from vkit.mechanism.cropper import Cropper 29from ..interface import PipelineStep, PipelineStepFactory 30from .page_cropping import PageCroppingStepOutput 31from .page_text_region import PageTextRegionStepOutput 32from .page_text_region_label import ( 33 PageCharRegressionLabelTag, 34 PageCharRegressionLabel, 35 PageTextRegionLabelStepOutput, 36) 37 38# Shapely version has been explicitly locked under 2.0, hence ignore this warning. 
warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)


@attrs.define
class PageTextRegionCroppingStepConfig:
    """Configuration of ``PageTextRegionCroppingStep``."""

    # Side length of the square crop core; sampling asserts the core box is
    # exactly core_size x core_size.
    core_size: int
    # Padding size passed to Cropper together with core_size.
    pad_size: int
    # Requested number of crops = round(factor * number of cropped pages).
    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
    # A sampled crop is rejected if its core holds fewer centroid / deviate labels.
    num_centroid_points_min: int = 10
    num_deviate_points_min: int = 10
    # Padding fill value passed to Cropper.
    pad_value: int = 0
    # Whether to also build a DownsampledLabel for every crop.
    enable_downsample_labeling: bool = True
    # Integer factor; must evenly divide the crop size, pad_size and core_size.
    downsample_labeling_factor: int = 2


@attrs.define
class PageTextRegionCroppingStepInput:
    """Outputs of upstream steps consumed by ``PageTextRegionCroppingStep``."""

    # Only the number of cropped pages is read; it scales the number of samples.
    page_cropping_step_output: PageCroppingStepOutput
    # Supplies page_image, shape_before_rotate and rotate_angle.
    page_text_region_step_output: PageTextRegionStepOutput
    # Supplies the char mask, the score maps and the regression labels.
    page_text_region_label_step_output: PageTextRegionLabelStepOutput


@attrs.define
class DownsampledLabel:
    """Labels of one crop at ``1 / downsample_labeling_factor`` resolution."""

    # (crop_size // factor, crop_size // factor).
    shape: Tuple[int, int]
    # Core-only maps resized to core_size // factor on each side.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Crop-frame labels converted with the downsample factor.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # Location of the core inside the downsampled crop.
    core_box: Box


@attrs.define
class CroppedPageTextRegion:
    """One sampled crop of the page together with its labels."""

    # Cropped image (cropped without core_only, unlike the maps below).
    page_image: Image
    # Maps cropped with core_only=True.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Centroid labels followed by deviate labels, shifted into the crop frame.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # Cropper.core_box of this crop.
    core_box: Box
    # None when enable_downsample_labeling is False.
    downsampled_label: Optional[DownsampledLabel]


@attrs.define
class PageTextRegionCroppingStepOutput:
    """Output of ``PageTextRegionCroppingStep``."""

    # Successfully sampled crops; may be fewer than requested when attempts are
    # rejected and the attempt budget runs out.
    cropped_page_text_regions: Sequence[CroppedPageTextRegion]


class PageTextRegionCroppingStep(
    PipelineStep[
        PageTextRegionCroppingStepConfig,
        PageTextRegionCroppingStepInput,
        PageTextRegionCroppingStepOutput,
    ]
):  # yapf: disable
    """Sample crops of the page text region together with the labels inside them.

    Crops are created by ``Cropper`` from ``core_size`` / ``pad_size`` /
    ``pad_value``. Regression labels are selected by whether their label point
    lies inside the crop core (page frame), then shifted into the crop frame.
    """

    @classmethod
    def build_strtree_for_page_char_regression_labels(
        cls,
        labels: Sequence[PageCharRegressionLabel],
    ):
        """Build a spatial index over the label points of ``labels``.

        Returns ``(strtree, xy_pair_to_labels)``. ``strtree`` holds ONE shapely
        point per distinct ``(x, y)`` label point. ``xy_pair_to_labels`` maps
        that pair to every label located there.

        Inserting one point per distinct pair (rather than one per label) keeps a
        tree query from yielding coincident labels k * k times when k labels
        share a point.
        """
        xy_pair_to_labels: DefaultDict[
            Tuple[int, int],
            List[PageCharRegressionLabel],
        ] = defaultdict(list)  # yapf: disable

        for label in labels:
            # Integer coordinates keep the (x, y) pair recovered from a shapely
            # point identical to the dict key.
            assert isinstance(label.label_point_x, int)
            assert isinstance(label.label_point_y, int)
            xy_pair_to_labels[(label.label_point_x, label.label_point_y)].append(label)

        shapely_points: List[ShapelyPoint] = [
            ShapelyPoint(*xy_pair) for xy_pair in xy_pair_to_labels
        ]
        strtree = STRtree(shapely_points)
        return strtree, xy_pair_to_labels

    @classmethod
    def _pick_labels_in_polygon(
        cls,
        shapely_polygon,
        strtree: STRtree,
        xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
    ) -> List[PageCharRegressionLabel]:
        """Return the labels whose label point intersects ``shapely_polygon``.

        ``strtree`` and ``xy_pair_to_labels`` are expected to come from
        ``build_strtree_for_page_char_regression_labels``.
        """
        picked_labels: List[PageCharRegressionLabel] = []
        # The tree query yields envelope-level candidates; confirm each exactly.
        for shapely_point in strtree.query(shapely_polygon):
            if not shapely_polygon.intersects(shapely_point):
                continue
            assert isinstance(shapely_point, ShapelyPoint)
            xy_pair = (int(shapely_point.x), int(shapely_point.y))
            picked_labels.extend(xy_pair_to_labels[xy_pair])
        return picked_labels

    def _create_cropper(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        rng: RandomGenerator,
    ) -> Cropper:
        """Create a randomly positioned cropper over ``page_image``.

        For a rotated page, the crop is first sampled within
        ``shape_before_rotate``. Its center is then mapped through the same
        rotation, and the cropper is built around that point on the rotated page.
        """
        if rotate_angle == 0:
            return Cropper.create(
                shape=page_image.shape,
                core_size=self.config.core_size,
                pad_size=self.config.pad_size,
                pad_value=self.config.pad_value,
                rng=rng,
            )

        cropper_before_rotate = Cropper.create(
            shape=shape_before_rotate,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            rng=rng,
        )
        origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
        center_point_before_rotate = origin_box_before_rotate.get_center_point()

        rotated_result = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=center_point_before_rotate,
        )
        assert rotated_result.shape == page_image.shape
        center_point = rotated_result.point
        assert center_point

        return Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            center_point=center_point,
        )

    def _build_downsampled_label(
        self,
        cropper: Cropper,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_regression_labels: Sequence[PageCharRegressionLabel],
    ) -> DownsampledLabel:
        """Downsample the core-only maps and crop-frame labels of one crop.

        ``page_char_mask`` and both score maps must already be cropped with
        ``core_only=True``. ``page_char_regression_labels`` must already be
        shifted into the crop frame.
        """
        factor = self.config.downsample_labeling_factor

        assert cropper.crop_size % factor == 0
        downsampled_size = cropper.crop_size // factor
        downsampled_shape = (downsampled_size, downsampled_size)

        assert self.config.pad_size % factor == 0
        assert self.config.core_size % factor == 0
        assert cropper.core_box.height == cropper.core_box.width == self.config.core_size

        downsampled_pad_size = self.config.pad_size // factor
        downsampled_core_size = self.config.core_size // factor

        # The end coordinate is inclusive (begin + size - 1).
        downsampled_core_begin = downsampled_pad_size
        downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
        downsampled_core_box = Box(
            up=downsampled_core_begin,
            down=downsampled_core_end,
            left=downsampled_core_begin,
            right=downsampled_core_end,
        )

        # INTER_AREA averages source pixels, the usual choice when shrinking.
        downsampled_page_char_mask = page_char_mask.to_box_detached().to_resized_mask(
            resized_height=downsampled_core_size,
            resized_width=downsampled_core_size,
            cv_resize_interpolation=cv.INTER_AREA,
        )
        downsampled_page_char_height_score_map = \
            page_char_height_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_gaussian_score_map = \
            page_char_gaussian_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_regression_labels = [
            label.to_downsampled_page_char_regression_label(factor)
            for label in page_char_regression_labels
        ]

        return DownsampledLabel(
            shape=downsampled_shape,
            page_char_mask=downsampled_page_char_mask,
            page_char_height_score_map=downsampled_page_char_height_score_map,
            page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
            page_char_regression_labels=downsampled_page_char_regression_labels,
            core_box=downsampled_core_box,
        )

    def sample_cropped_page_text_regions(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        centroid_strtree: STRtree,
        centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
        deviate_strtree: STRtree,
        deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
        rng: RandomGenerator,
    ) -> Optional[CroppedPageTextRegion]:
        """Sample one crop of the page together with its labels.

        Returns ``None`` when the crop core contains fewer centroid labels than
        ``num_centroid_points_min`` or fewer deviate labels than
        ``num_deviate_points_min``.
        """
        cropper = self._create_cropper(
            page_image=page_image,
            shape_before_rotate=shape_before_rotate,
            rotate_angle=rotate_angle,
            rng=rng,
        )

        # Pick labels whose label point lies inside the crop core (page frame).
        origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()
        centroid_labels = self._pick_labels_in_polygon(
            origin_core_shapely_polygon,
            centroid_strtree,
            centroid_xy_pair_to_labels,
        )
        deviate_labels = self._pick_labels_in_polygon(
            origin_core_shapely_polygon,
            deviate_strtree,
            deviate_xy_pair_to_labels,
        )

        if len(centroid_labels) < self.config.num_centroid_points_min \
                or len(deviate_labels) < self.config.num_deviate_points_min:
            return None

        # Shift labels from the origin (page) frame into the target (crop) frame.
        offset_y = cropper.target_box.up - cropper.origin_box.up
        offset_x = cropper.target_box.left - cropper.origin_box.left
        shifted_labels = [
            label.to_shifted_page_char_regression_label(
                offset_y=offset_y,
                offset_x=offset_x,
            ) for label in itertools.chain(centroid_labels, deviate_labels)
        ]

        # Crop the image as a whole, and the label maps to the core only.
        cropped_page_image = cropper.crop_image(page_image)
        cropped_page_char_mask = cropper.crop_mask(
            page_char_mask,
            core_only=True,
        )
        cropped_page_char_height_score_map = cropper.crop_score_map(
            page_char_height_score_map,
            core_only=True,
        )
        cropped_page_char_gaussian_score_map = cropper.crop_score_map(
            page_char_gaussian_score_map,
            core_only=True,
        )

        downsampled_label: Optional[DownsampledLabel] = None
        if self.config.enable_downsample_labeling:
            downsampled_label = self._build_downsampled_label(
                cropper=cropper,
                page_char_mask=cropped_page_char_mask,
                page_char_height_score_map=cropped_page_char_height_score_map,
                page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
                page_char_regression_labels=shifted_labels,
            )

        return CroppedPageTextRegion(
            page_image=cropped_page_image,
            page_char_mask=cropped_page_char_mask,
            page_char_height_score_map=cropped_page_char_height_score_map,
            page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
            page_char_regression_labels=shifted_labels,
            core_box=cropper.core_box,
            downsampled_label=downsampled_label,
        )

    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
        """Sample up to ``round(factor * num_cropped_pages)`` cropped regions.

        Rejected attempts (too few labels in the core) are retried. The total
        number of attempts is capped at ``max(3, 2 * num_samples)``, so fewer
        regions than requested may be returned.
        """
        page_cropping_step_output = input.page_cropping_step_output
        num_cropped_pages = len(page_cropping_step_output.cropped_pages)

        page_text_region_step_output = input.page_text_region_step_output
        page_image = page_text_region_step_output.page_image
        shape_before_rotate = page_text_region_step_output.shape_before_rotate
        rotate_angle = page_text_region_step_output.rotate_angle

        page_text_region_label_step_output = input.page_text_region_label_step_output
        page_char_mask = page_text_region_label_step_output.page_char_mask
        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
        page_char_gaussian_score_map = \
            page_text_region_label_step_output.page_char_gaussian_score_map
        page_char_regression_labels = \
            page_text_region_label_step_output.page_char_regression_labels

        centroid_strtree, centroid_xy_pair_to_labels = \
            self.build_strtree_for_page_char_regression_labels([
                label for label in page_char_regression_labels
                if label.tag == PageCharRegressionLabelTag.CENTROID
            ])
        deviate_strtree, deviate_xy_pair_to_labels = \
            self.build_strtree_for_page_char_regression_labels([
                label for label in page_char_regression_labels
                if label.tag == PageCharRegressionLabelTag.DEVIATE
            ])

        num_samples = round(
            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
        )
        # Attempts may be rejected, so allow extra tries but keep the loop bounded.
        run_count_max = max(3, 2 * num_samples)

        cropped_page_text_regions: List[CroppedPageTextRegion] = []
        run_count = 0
        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
            cropped_page_text_region = self.sample_cropped_page_text_regions(
                page_image=page_image,
                shape_before_rotate=shape_before_rotate,
                rotate_angle=rotate_angle,
                page_char_mask=page_char_mask,
                page_char_height_score_map=page_char_height_score_map,
                page_char_gaussian_score_map=page_char_gaussian_score_map,
                centroid_strtree=centroid_strtree,
                centroid_xy_pair_to_labels=centroid_xy_pair_to_labels,
                deviate_strtree=deviate_strtree,
                deviate_xy_pair_to_labels=deviate_xy_pair_to_labels,
                rng=rng,
            )
            if cropped_page_text_region is not None:
                cropped_page_text_regions.append(cropped_page_text_region)
            run_count += 1

        return PageTextRegionCroppingStepOutput(
            cropped_page_text_regions=cropped_page_text_regions,
        )


page_text_region_cropping_step_factory = PipelineStepFactory(PageTextRegionCroppingStep)
class PageTextRegionCroppingStepConfig:
    """Configuration of PageTextRegionCroppingStep."""

    # Side length of the square crop core; the sampling step asserts the core box
    # is exactly core_size x core_size.
    core_size: int
    # Padding size passed to Cropper together with core_size.
    pad_size: int
    # Requested number of crops = round(factor * number of cropped pages).
    num_samples_factor_relative_to_num_cropped_pages: float = 1.0
    # A sampled crop is rejected if its core holds fewer centroid / deviate labels.
    num_centroid_points_min: int = 10
    num_deviate_points_min: int = 10
    # Padding fill value passed to Cropper.
    pad_value: int = 0
    # Whether to also build a DownsampledLabel for every crop.
    enable_downsample_labeling: bool = True
    # Integer factor; must evenly divide the crop size, pad_size and core_size.
    downsample_labeling_factor: int = 2
def __init__(
    self,
    core_size,
    pad_size,
    num_samples_factor_relative_to_num_cropped_pages=attr_dict['num_samples_factor_relative_to_num_cropped_pages'].default,
    num_centroid_points_min=attr_dict['num_centroid_points_min'].default,
    num_deviate_points_min=attr_dict['num_deviate_points_min'].default,
    pad_value=attr_dict['pad_value'].default,
    enable_downsample_labeling=attr_dict['enable_downsample_labeling'].default,
    downsample_labeling_factor=attr_dict['downsample_labeling_factor'].default
):
    """Initializer generated by attrs for PageTextRegionCroppingStepConfig.

    Each argument is stored unchanged on the attribute of the same name. The
    optional arguments' defaults are read from ``attr_dict`` (presumably attrs'
    field-name -> Attribute mapping, available while attrs builds this method).
    """
    self.core_size = core_size
    self.pad_size = pad_size
    self.num_samples_factor_relative_to_num_cropped_pages = num_samples_factor_relative_to_num_cropped_pages
    self.num_centroid_points_min = num_centroid_points_min
    self.num_deviate_points_min = num_deviate_points_min
    self.pad_value = pad_value
    self.enable_downsample_labeling = enable_downsample_labeling
    self.downsample_labeling_factor = downsample_labeling_factor
Method generated by attrs for class PageTextRegionCroppingStepConfig.
class PageTextRegionCroppingStepInput:
    """Outputs of upstream steps consumed by PageTextRegionCroppingStep."""

    # Only the number of cropped pages is read; it scales the number of samples.
    page_cropping_step_output: PageCroppingStepOutput
    # Supplies page_image, shape_before_rotate and rotate_angle.
    page_text_region_step_output: PageTextRegionStepOutput
    # Supplies the char mask, the score maps and the regression labels.
    page_text_region_label_step_output: PageTextRegionLabelStepOutput
def __init__(
    self,
    page_cropping_step_output,
    page_text_region_step_output,
    page_text_region_label_step_output
):
    """Initializer generated by attrs for PageTextRegionCroppingStepInput.

    Stores each argument unchanged on the attribute of the same name.
    """
    self.page_cropping_step_output = page_cropping_step_output
    self.page_text_region_step_output = page_text_region_step_output
    self.page_text_region_label_step_output = page_text_region_label_step_output
Method generated by attrs for class PageTextRegionCroppingStepInput.
class DownsampledLabel:
    """Labels of one crop at 1 / downsample_labeling_factor resolution."""

    # (crop_size // factor, crop_size // factor) as built by the sampling step.
    shape: Tuple[int, int]
    # Core-only maps resized to core_size // factor on each side.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Crop-frame labels converted with the downsample factor.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # Location of the core inside the downsampled crop.
    core_box: Box
def __init__(
    self,
    shape,
    page_char_mask,
    page_char_height_score_map,
    page_char_gaussian_score_map,
    page_char_regression_labels,
    core_box
):
    """Initializer generated by attrs for DownsampledLabel.

    Stores each argument unchanged on the attribute of the same name.
    """
    self.shape = shape
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_gaussian_score_map = page_char_gaussian_score_map
    self.page_char_regression_labels = page_char_regression_labels
    self.core_box = core_box
Method generated by attrs for class DownsampledLabel.
class CroppedPageTextRegion:
    """One sampled crop of the page together with its labels."""

    # Cropped image; the sampling step crops it without core_only, unlike the maps.
    page_image: Image
    # Maps cropped with core_only=True by the sampling step.
    page_char_mask: Mask
    page_char_height_score_map: ScoreMap
    page_char_gaussian_score_map: ScoreMap
    # Centroid labels followed by deviate labels, shifted into the crop frame.
    page_char_regression_labels: Sequence[PageCharRegressionLabel]
    # The cropper's core_box for this crop.
    core_box: Box
    # None when enable_downsample_labeling is False.
    downsampled_label: Optional[DownsampledLabel]
def __init__(
    self,
    page_image,
    page_char_mask,
    page_char_height_score_map,
    page_char_gaussian_score_map,
    page_char_regression_labels,
    core_box,
    downsampled_label
):
    """Initializer generated by attrs for CroppedPageTextRegion.

    Stores each argument unchanged on the attribute of the same name.
    """
    self.page_image = page_image
    self.page_char_mask = page_char_mask
    self.page_char_height_score_map = page_char_height_score_map
    self.page_char_gaussian_score_map = page_char_gaussian_score_map
    self.page_char_regression_labels = page_char_regression_labels
    self.core_box = core_box
    self.downsampled_label = downsampled_label
Method generated by attrs for class CroppedPageTextRegion.
class PageTextRegionCroppingStepOutput:
    """Output of PageTextRegionCroppingStep."""

    # Successfully sampled crops; may be fewer than requested when attempts are
    # rejected and the attempt budget runs out.
    cropped_page_text_regions: Sequence[CroppedPageTextRegion]
def __init__(self, cropped_page_text_regions):
    """Initializer generated by attrs for PageTextRegionCroppingStepOutput.

    Stores the argument unchanged on the attribute of the same name.
    """
    self.cropped_page_text_regions = cropped_page_text_regions
Method generated by attrs for class PageTextRegionCroppingStepOutput.
class PageTextRegionCroppingStep(
    PipelineStep[
        PageTextRegionCroppingStepConfig,
        PageTextRegionCroppingStepInput,
        PageTextRegionCroppingStepOutput,
    ]
):  # yapf: disable
    """Sample crops of the page text region together with the labels inside them.

    Crops are created by ``Cropper`` from ``core_size`` / ``pad_size`` /
    ``pad_value``. Regression labels are selected by whether their label point
    lies inside the crop core (page frame), then shifted into the crop frame.
    """

    @classmethod
    def build_strtree_for_page_char_regression_labels(
        cls,
        labels: Sequence[PageCharRegressionLabel],
    ):
        """Build a spatial index over the label points of ``labels``.

        Returns ``(strtree, xy_pair_to_labels)``. ``strtree`` holds ONE shapely
        point per distinct ``(x, y)`` label point. ``xy_pair_to_labels`` maps
        that pair to every label located there.

        Inserting one point per distinct pair (rather than one per label) keeps a
        tree query from yielding coincident labels k * k times when k labels
        share a point.
        """
        xy_pair_to_labels: DefaultDict[
            Tuple[int, int],
            List[PageCharRegressionLabel],
        ] = defaultdict(list)  # yapf: disable

        for label in labels:
            # Integer coordinates keep the (x, y) pair recovered from a shapely
            # point identical to the dict key.
            assert isinstance(label.label_point_x, int)
            assert isinstance(label.label_point_y, int)
            xy_pair_to_labels[(label.label_point_x, label.label_point_y)].append(label)

        shapely_points: List[ShapelyPoint] = [
            ShapelyPoint(*xy_pair) for xy_pair in xy_pair_to_labels
        ]
        strtree = STRtree(shapely_points)
        return strtree, xy_pair_to_labels

    @classmethod
    def _pick_labels_in_polygon(
        cls,
        shapely_polygon,
        strtree: STRtree,
        xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
    ) -> List[PageCharRegressionLabel]:
        """Return the labels whose label point intersects ``shapely_polygon``.

        ``strtree`` and ``xy_pair_to_labels`` are expected to come from
        ``build_strtree_for_page_char_regression_labels``.
        """
        picked_labels: List[PageCharRegressionLabel] = []
        # The tree query yields envelope-level candidates; confirm each exactly.
        for shapely_point in strtree.query(shapely_polygon):
            if not shapely_polygon.intersects(shapely_point):
                continue
            assert isinstance(shapely_point, ShapelyPoint)
            xy_pair = (int(shapely_point.x), int(shapely_point.y))
            picked_labels.extend(xy_pair_to_labels[xy_pair])
        return picked_labels

    def _create_cropper(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        rng: RandomGenerator,
    ) -> Cropper:
        """Create a randomly positioned cropper over ``page_image``.

        For a rotated page, the crop is first sampled within
        ``shape_before_rotate``. Its center is then mapped through the same
        rotation, and the cropper is built around that point on the rotated page.
        """
        if rotate_angle == 0:
            return Cropper.create(
                shape=page_image.shape,
                core_size=self.config.core_size,
                pad_size=self.config.pad_size,
                pad_value=self.config.pad_value,
                rng=rng,
            )

        cropper_before_rotate = Cropper.create(
            shape=shape_before_rotate,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            rng=rng,
        )
        origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
        center_point_before_rotate = origin_box_before_rotate.get_center_point()

        rotated_result = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=center_point_before_rotate,
        )
        assert rotated_result.shape == page_image.shape
        center_point = rotated_result.point
        assert center_point

        return Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=self.config.core_size,
            pad_size=self.config.pad_size,
            pad_value=self.config.pad_value,
            center_point=center_point,
        )

    def _build_downsampled_label(
        self,
        cropper: Cropper,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        page_char_regression_labels: Sequence[PageCharRegressionLabel],
    ) -> DownsampledLabel:
        """Downsample the core-only maps and crop-frame labels of one crop.

        ``page_char_mask`` and both score maps must already be cropped with
        ``core_only=True``. ``page_char_regression_labels`` must already be
        shifted into the crop frame.
        """
        factor = self.config.downsample_labeling_factor

        assert cropper.crop_size % factor == 0
        downsampled_size = cropper.crop_size // factor
        downsampled_shape = (downsampled_size, downsampled_size)

        assert self.config.pad_size % factor == 0
        assert self.config.core_size % factor == 0
        assert cropper.core_box.height == cropper.core_box.width == self.config.core_size

        downsampled_pad_size = self.config.pad_size // factor
        downsampled_core_size = self.config.core_size // factor

        # The end coordinate is inclusive (begin + size - 1).
        downsampled_core_begin = downsampled_pad_size
        downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
        downsampled_core_box = Box(
            up=downsampled_core_begin,
            down=downsampled_core_end,
            left=downsampled_core_begin,
            right=downsampled_core_end,
        )

        # INTER_AREA averages source pixels, the usual choice when shrinking.
        downsampled_page_char_mask = page_char_mask.to_box_detached().to_resized_mask(
            resized_height=downsampled_core_size,
            resized_width=downsampled_core_size,
            cv_resize_interpolation=cv.INTER_AREA,
        )
        downsampled_page_char_height_score_map = \
            page_char_height_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )
        downsampled_page_char_gaussian_score_map = \
            page_char_gaussian_score_map.to_box_detached().to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_regression_labels = [
            label.to_downsampled_page_char_regression_label(factor)
            for label in page_char_regression_labels
        ]

        return DownsampledLabel(
            shape=downsampled_shape,
            page_char_mask=downsampled_page_char_mask,
            page_char_height_score_map=downsampled_page_char_height_score_map,
            page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
            page_char_regression_labels=downsampled_page_char_regression_labels,
            core_box=downsampled_core_box,
        )

    def sample_cropped_page_text_regions(
        self,
        page_image: Image,
        shape_before_rotate: Tuple[int, int],
        rotate_angle: int,
        page_char_mask: Mask,
        page_char_height_score_map: ScoreMap,
        page_char_gaussian_score_map: ScoreMap,
        centroid_strtree: STRtree,
        centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
        deviate_strtree: STRtree,
        deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
        rng: RandomGenerator,
    ) -> Optional[CroppedPageTextRegion]:
        """Sample one crop of the page together with its labels.

        Returns ``None`` when the crop core contains fewer centroid labels than
        ``num_centroid_points_min`` or fewer deviate labels than
        ``num_deviate_points_min``.
        """
        cropper = self._create_cropper(
            page_image=page_image,
            shape_before_rotate=shape_before_rotate,
            rotate_angle=rotate_angle,
            rng=rng,
        )

        # Pick labels whose label point lies inside the crop core (page frame).
        origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()
        centroid_labels = self._pick_labels_in_polygon(
            origin_core_shapely_polygon,
            centroid_strtree,
            centroid_xy_pair_to_labels,
        )
        deviate_labels = self._pick_labels_in_polygon(
            origin_core_shapely_polygon,
            deviate_strtree,
            deviate_xy_pair_to_labels,
        )

        if len(centroid_labels) < self.config.num_centroid_points_min \
                or len(deviate_labels) < self.config.num_deviate_points_min:
            return None

        # Shift labels from the origin (page) frame into the target (crop) frame.
        offset_y = cropper.target_box.up - cropper.origin_box.up
        offset_x = cropper.target_box.left - cropper.origin_box.left
        shifted_labels = [
            label.to_shifted_page_char_regression_label(
                offset_y=offset_y,
                offset_x=offset_x,
            ) for label in itertools.chain(centroid_labels, deviate_labels)
        ]

        # Crop the image as a whole, and the label maps to the core only.
        cropped_page_image = cropper.crop_image(page_image)
        cropped_page_char_mask = cropper.crop_mask(
            page_char_mask,
            core_only=True,
        )
        cropped_page_char_height_score_map = cropper.crop_score_map(
            page_char_height_score_map,
            core_only=True,
        )
        cropped_page_char_gaussian_score_map = cropper.crop_score_map(
            page_char_gaussian_score_map,
            core_only=True,
        )

        downsampled_label: Optional[DownsampledLabel] = None
        if self.config.enable_downsample_labeling:
            downsampled_label = self._build_downsampled_label(
                cropper=cropper,
                page_char_mask=cropped_page_char_mask,
                page_char_height_score_map=cropped_page_char_height_score_map,
                page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
                page_char_regression_labels=shifted_labels,
            )

        return CroppedPageTextRegion(
            page_image=cropped_page_image,
            page_char_mask=cropped_page_char_mask,
            page_char_height_score_map=cropped_page_char_height_score_map,
            page_char_gaussian_score_map=cropped_page_char_gaussian_score_map,
            page_char_regression_labels=shifted_labels,
            core_box=cropper.core_box,
            downsampled_label=downsampled_label,
        )

    def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
        """Sample up to ``round(factor * num_cropped_pages)`` cropped regions.

        Rejected attempts (too few labels in the core) are retried. The total
        number of attempts is capped at ``max(3, 2 * num_samples)``, so fewer
        regions than requested may be returned.
        """
        page_cropping_step_output = input.page_cropping_step_output
        num_cropped_pages = len(page_cropping_step_output.cropped_pages)

        page_text_region_step_output = input.page_text_region_step_output
        page_image = page_text_region_step_output.page_image
        shape_before_rotate = page_text_region_step_output.shape_before_rotate
        rotate_angle = page_text_region_step_output.rotate_angle

        page_text_region_label_step_output = input.page_text_region_label_step_output
        page_char_mask = page_text_region_label_step_output.page_char_mask
        page_char_height_score_map = page_text_region_label_step_output.page_char_height_score_map
        page_char_gaussian_score_map = \
            page_text_region_label_step_output.page_char_gaussian_score_map
        page_char_regression_labels = \
            page_text_region_label_step_output.page_char_regression_labels

        centroid_strtree, centroid_xy_pair_to_labels = \
            self.build_strtree_for_page_char_regression_labels([
                label for label in page_char_regression_labels
                if label.tag == PageCharRegressionLabelTag.CENTROID
            ])
        deviate_strtree, deviate_xy_pair_to_labels = \
            self.build_strtree_for_page_char_regression_labels([
                label for label in page_char_regression_labels
                if label.tag == PageCharRegressionLabelTag.DEVIATE
            ])

        num_samples = round(
            self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
        )
        # Attempts may be rejected, so allow extra tries but keep the loop bounded.
        run_count_max = max(3, 2 * num_samples)

        cropped_page_text_regions: List[CroppedPageTextRegion] = []
        run_count = 0
        while len(cropped_page_text_regions) < num_samples and run_count < run_count_max:
            cropped_page_text_region = self.sample_cropped_page_text_regions(
                page_image=page_image,
                shape_before_rotate=shape_before_rotate,
                rotate_angle=rotate_angle,
                page_char_mask=page_char_mask,
                page_char_height_score_map=page_char_height_score_map,
                page_char_gaussian_score_map=page_char_gaussian_score_map,
                centroid_strtree=centroid_strtree,
                centroid_xy_pair_to_labels=centroid_xy_pair_to_labels,
                deviate_strtree=deviate_strtree,
                deviate_xy_pair_to_labels=deviate_xy_pair_to_labels,
                rng=rng,
            )
            if cropped_page_text_region is not None:
                cropped_page_text_regions.append(cropped_page_text_region)
            run_count += 1

        return PageTextRegionCroppingStepOutput(
            cropped_page_text_regions=cropped_page_text_regions,
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
    """Example generic mapping, parameterized by key type KT and value type VT."""

    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
@classmethod
def build_strtree_for_page_char_regression_labels(
    cls,
    labels: Sequence[PageCharRegressionLabel],
) -> Tuple[STRtree, DefaultDict[Tuple[int, int], List[PageCharRegressionLabel]]]:
    """Build a spatial index over the label points of ``labels``.

    Args:
        labels: Regression labels whose ``label_point_x`` / ``label_point_y``
            must be ints (asserted).

    Returns:
        ``(strtree, xy_pair_to_labels)`` where ``strtree`` holds exactly one
        ``ShapelyPoint`` per distinct ``(x, y)`` location, and
        ``xy_pair_to_labels`` maps each location to every label placed there,
        in input order.
    """
    shapely_points: List[ShapelyPoint] = []

    xy_pair_to_labels: DefaultDict[
        Tuple[int, int],
        List[PageCharRegressionLabel],
    ] = defaultdict(list)  # yapf: disable

    for label in labels:
        assert isinstance(label.label_point_x, int)
        assert isinstance(label.label_point_y, int)
        xy_pair = (label.label_point_x, label.label_point_y)
        # Insert a single tree point per location. Callers expand every query
        # hit into *all* labels stored at that location, so inserting one point
        # per label would emit k coincident labels k * k times.
        # (`in` on a defaultdict does not create the key.)
        if xy_pair not in xy_pair_to_labels:
            shapely_points.append(ShapelyPoint(*xy_pair))
        xy_pair_to_labels[xy_pair].append(label)

    strtree = STRtree(shapely_points)
    return strtree, xy_pair_to_labels
def sample_cropped_page_text_regions(
    self,
    page_image: Image,
    shape_before_rotate: Tuple[int, int],
    rotate_angle: int,
    page_char_mask: Mask,
    page_char_height_score_map: ScoreMap,
    page_char_gaussian_score_map: ScoreMap,
    centroid_strtree: STRtree,
    centroid_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
    deviate_strtree: STRtree,
    deviate_xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
    rng: RandomGenerator,
) -> Optional[CroppedPageTextRegion]:
    """Randomly crop one text region from the page together with its labels.

    When ``rotate_angle != 0``, the crop window is first drawn randomly on the
    un-rotated shape ``shape_before_rotate``. Its center point is then mapped
    through the same rotation, and the final cropper is built around that
    rotated center on ``page_image.shape``. Otherwise the crop window is drawn
    directly on ``page_image.shape``.

    Args:
        page_image: The page image to crop (padded crop).
        shape_before_rotate: Page shape before rotation.
        rotate_angle: Rotation angle applied to the page; 0 means none.
        page_char_mask: Page-level mask; only the crop core is kept.
        page_char_height_score_map: Page-level score map; only the core is kept.
        page_char_gaussian_score_map: Page-level score map; only the core is kept.
        centroid_strtree: Spatial index over CENTROID label points.
        centroid_xy_pair_to_labels: Maps a CENTROID point to its labels.
        deviate_strtree: Spatial index over DEVIATE label points.
        deviate_xy_pair_to_labels: Maps a DEVIATE point to its labels.
        rng: Random generator driving the crop position.

    Returns:
        A ``CroppedPageTextRegion`` whose regression labels are shifted into the
        crop's coordinate frame. Returns ``None`` when the crop core holds fewer
        than ``num_centroid_points_min`` centroid labels or fewer than
        ``num_deviate_points_min`` deviate labels.
    """
    config = self.config

    if rotate_angle != 0:
        cropper_before_rotate = Cropper.create(
            shape=shape_before_rotate,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            rng=rng,
        )
        origin_box_before_rotate = cropper_before_rotate.cropper_state.origin_box
        center_point_before_rotate = origin_box_before_rotate.get_center_point()

        # Map the randomly chosen center into the rotated page.
        rotated_result = rotate.distort(
            {'angle': rotate_angle},
            shapable_or_shape=shape_before_rotate,
            point=center_point_before_rotate,
        )
        assert rotated_result.shape == page_image.shape
        center_point = rotated_result.point
        assert center_point

        cropper = Cropper.create_from_center_point(
            shape=page_image.shape,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            center_point=center_point,
        )

    else:
        cropper = Cropper.create(
            shape=page_image.shape,
            core_size=config.core_size,
            pad_size=config.pad_size,
            pad_value=config.pad_value,
            rng=rng,
        )

    # Pick labels whose points lie inside the crop core (page coordinates).
    origin_core_shapely_polygon = cropper.origin_core_box.to_shapely_polygon()

    def collect_labels_in_core(
        strtree: STRtree,
        xy_pair_to_labels: Mapping[Tuple[int, int], Sequence[PageCharRegressionLabel]],
    ) -> List[PageCharRegressionLabel]:
        # Gather every label located inside the crop core, expanding each
        # distinct location exactly once.
        collected: List[PageCharRegressionLabel] = []
        visited_xy_pairs = set()
        for shapely_point in strtree.query(origin_core_shapely_polygon):
            # In Shapely < 2.0, query() is only a bounding-box pre-filter.
            if not origin_core_shapely_polygon.intersects(shapely_point):
                continue
            assert isinstance(shapely_point, ShapelyPoint)
            xy_pair = (int(shapely_point.x), int(shapely_point.y))
            # The tree may hold several points at one location. Each location
            # already maps to *all* of its labels, so expanding it again would
            # duplicate labels.
            if xy_pair in visited_xy_pairs:
                continue
            visited_xy_pairs.add(xy_pair)
            collected.extend(xy_pair_to_labels[xy_pair])
        return collected

    centroid_labels = collect_labels_in_core(centroid_strtree, centroid_xy_pair_to_labels)
    deviate_labels = collect_labels_in_core(deviate_strtree, deviate_xy_pair_to_labels)

    if len(centroid_labels) < config.num_centroid_points_min \
            or len(deviate_labels) < config.num_deviate_points_min:
        return None

    # Shift labels from page coordinates into the cropped (target) frame.
    offset_y = cropper.target_box.up - cropper.origin_box.up
    offset_x = cropper.target_box.left - cropper.origin_box.left
    shifted_centroid_labels = [
        centroid_label.to_shifted_page_char_regression_label(
            offset_y=offset_y,
            offset_x=offset_x,
        ) for centroid_label in centroid_labels
    ]
    shifted_deviate_labels = [
        deviate_label.to_shifted_page_char_regression_label(
            offset_y=offset_y,
            offset_x=offset_x,
        ) for deviate_label in deviate_labels
    ]

    # Crop the image (with padding) and the label maps (core only).
    page_image = cropper.crop_image(page_image)
    page_char_mask = cropper.crop_mask(
        page_char_mask,
        core_only=True,
    )
    page_char_height_score_map = cropper.crop_score_map(
        page_char_height_score_map,
        core_only=True,
    )
    page_char_gaussian_score_map = cropper.crop_score_map(
        page_char_gaussian_score_map,
        core_only=True,
    )

    downsampled_label: Optional[DownsampledLabel] = None
    if config.enable_downsample_labeling:
        downsample_labeling_factor = config.downsample_labeling_factor

        assert cropper.crop_size % downsample_labeling_factor == 0
        downsampled_size = cropper.crop_size // downsample_labeling_factor
        downsampled_shape = (downsampled_size, downsampled_size)

        assert config.pad_size % downsample_labeling_factor == 0
        assert config.core_size % downsample_labeling_factor == 0
        assert cropper.core_box.height == cropper.core_box.width == config.core_size

        downsampled_pad_size = config.pad_size // downsample_labeling_factor
        downsampled_core_size = config.core_size // downsample_labeling_factor

        # The core sits right after the padding; Box bounds here are inclusive,
        # hence the trailing "- 1".
        downsampled_core_begin = downsampled_pad_size
        downsampled_core_end = downsampled_core_begin + downsampled_core_size - 1
        downsampled_core_box = Box(
            up=downsampled_core_begin,
            down=downsampled_core_end,
            left=downsampled_core_begin,
            right=downsampled_core_end,
        )

        # The cropped maps cover the core only, so resize them to the
        # downsampled core size.
        downsampled_page_char_mask = page_char_mask.to_box_detached()
        downsampled_page_char_mask = \
            downsampled_page_char_mask.to_resized_mask(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_height_score_map = page_char_height_score_map.to_box_detached()
        downsampled_page_char_height_score_map = \
            downsampled_page_char_height_score_map.to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_gaussian_score_map = \
            page_char_gaussian_score_map.to_box_detached()
        downsampled_page_char_gaussian_score_map = \
            downsampled_page_char_gaussian_score_map.to_resized_score_map(
                resized_height=downsampled_core_size,
                resized_width=downsampled_core_size,
                cv_resize_interpolation=cv.INTER_AREA,
            )

        downsampled_page_char_regression_labels = [
            label.to_downsampled_page_char_regression_label(downsample_labeling_factor)
            for label in itertools.chain(shifted_centroid_labels, shifted_deviate_labels)
        ]

        downsampled_label = DownsampledLabel(
            shape=downsampled_shape,
            page_char_mask=downsampled_page_char_mask,
            page_char_height_score_map=downsampled_page_char_height_score_map,
            page_char_gaussian_score_map=downsampled_page_char_gaussian_score_map,
            page_char_regression_labels=downsampled_page_char_regression_labels,
            core_box=downsampled_core_box,
        )

    return CroppedPageTextRegion(
        page_image=page_image,
        page_char_mask=page_char_mask,
        page_char_height_score_map=page_char_height_score_map,
        page_char_gaussian_score_map=page_char_gaussian_score_map,
        page_char_regression_labels=shifted_centroid_labels + shifted_deviate_labels,
        core_box=cropper.core_box,
        downsampled_label=downsampled_label,
    )
def run(self, input: PageTextRegionCroppingStepInput, rng: RandomGenerator):
    """Sample cropped text regions from the current page.

    The target count is
    ``round(num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages)``.
    Sampling stops once that many regions have been collected, or after
    ``max(3, 2 * target)`` attempts, whichever happens first. Attempts for which
    ``sample_cropped_page_text_regions`` returns ``None`` are dropped.
    """
    num_cropped_pages = len(input.page_cropping_step_output.cropped_pages)

    region_output = input.page_text_region_step_output
    image = region_output.page_image
    unrotated_shape = region_output.shape_before_rotate
    angle = region_output.rotate_angle

    label_output = input.page_text_region_label_step_output
    char_mask = label_output.page_char_mask
    char_height_map = label_output.page_char_height_score_map
    char_gaussian_map = label_output.page_char_gaussian_score_map
    all_labels = label_output.page_char_regression_labels

    # Index CENTROID and DEVIATE labels separately so each crop can query them.
    centroid_labels = [
        lbl for lbl in all_labels if lbl.tag == PageCharRegressionLabelTag.CENTROID
    ]
    centroid_tree, centroid_lookup = \
        self.build_strtree_for_page_char_regression_labels(centroid_labels)

    deviate_labels = [
        lbl for lbl in all_labels if lbl.tag == PageCharRegressionLabelTag.DEVIATE
    ]
    deviate_tree, deviate_lookup = \
        self.build_strtree_for_page_char_regression_labels(deviate_labels)

    target = round(
        self.config.num_samples_factor_relative_to_num_cropped_pages * num_cropped_pages
    )

    regions: List[CroppedPageTextRegion] = []
    for _ in range(max(3, 2 * target)):
        if len(regions) >= target:
            break
        candidate = self.sample_cropped_page_text_regions(
            page_image=image,
            shape_before_rotate=unrotated_shape,
            rotate_angle=angle,
            page_char_mask=char_mask,
            page_char_height_score_map=char_height_map,
            page_char_gaussian_score_map=char_gaussian_map,
            centroid_strtree=centroid_tree,
            centroid_xy_pair_to_labels=centroid_lookup,
            deviate_strtree=deviate_tree,
            deviate_xy_pair_to_labels=deviate_lookup,
            rng=rng,
        )
        if candidate is not None:
            regions.append(candidate)

    return PageTextRegionCroppingStepOutput(cropped_page_text_regions=regions)