vkit.mechanism.distortion_policy.random_distortion
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import (
    Mapping,
    Any,
    Iterable,
    Union,
    Tuple,
    Optional,
    Sequence,
    List,
)
from collections import defaultdict
import logging

import attrs
from numpy.random import Generator as RandomGenerator

from vkit.utility import (
    dyn_structure,
    rng_choice_with_size,
    normalize_to_probs,
    PathType,
)
from vkit.element import (
    Shapable,
    Image,
    Point,
    PointList,
    PointTuple,
    Box,
    Polygon,
    Mask,
    ScoreMap,
)
from ..distortion.interface import Distortion, DistortionResult
from .opt import LEVEL_MIN, LEVEL_MAX
from .type import DistortionPolicy, DistortionPolicyFactory
from .photometric import (
    color,
    blur,
    noise,
    effect,
    streak,
)
from .geometric import (
    affine,
    mls,
    camera,
)

logger = logging.getLogger(__name__)


@attrs.define
class RandomDistortionDebug:
    """Trace of the distortion steps recorded while running a random distortion.

    The five lists are parallel: entry ``i`` of each list describes the ``i``-th
    applied distortion step. Every list starts out as a fresh, empty list per instance.
    """

    # Name of the policy used at each step.
    distortion_names: List[str] = attrs.Factory(list)
    # Level sampled for each step.
    distortion_levels: List[int] = attrs.Factory(list)
    # Image produced by each step.
    distortion_images: List[Image] = attrs.Factory(list)
    # Distortion config reported by each step (opaque here).
    distortion_configs: List[Any] = attrs.Factory(list)
    # Distortion state reported by each step (opaque here).
    distortion_states: List[Any] = attrs.Factory(list)
@attrs.define
class RandomDistortionStageConfig:
    """Configuration of one stage of a random distortion pipeline.

    When the stage runs (with probability ``prob_enable``), it samples between
    ``num_distortions_min`` and ``num_distortions_max`` policies without replacement,
    weighted by ``distortion_policy_weights``, and applies them one after another.
    """

    # Candidate policies of this stage.
    distortion_policies: Sequence[DistortionPolicy]
    # Sampling weights aligned with ``distortion_policies`` (normalized to probabilities).
    distortion_policy_weights: Sequence[float]
    # Probability that the stage is applied at all.
    prob_enable: float
    # Inclusive range of how many policies are sampled per run.
    num_distortions_min: int
    num_distortions_max: int
    # If True, points along the image border are injected before distorting, so that the
    # distorted page region can be located (and trimmed) afterwards.
    inject_corner_points: bool = False
    # Keyword groups: at most one sampled policy may have a name containing a keyword of
    # the same group (substring match).
    conflict_control_keyword_groups: Sequence[Sequence[str]] = ()
    # If True, levels are sampled from [LEVEL_MIN, LEVEL_MAX] regardless of the level range
    # passed in by the pipeline.
    force_sample_level_in_full_range: bool = False


class RandomDistortionStage:
    """Apply a randomly sampled subset of the configured distortion policies."""

    # Number of attempts to draw a conflict-free combination of policies.
    _NUM_SAMPLE_RETRIES = 5

    def __init__(self, config: RandomDistortionStageConfig):
        self.config = config
        self.distortion_policy_probs = normalize_to_probs(self.config.distortion_policy_weights)
        # Sampling is without replacement, so only policies with a positive probability can
        # be drawn; this bounds how many policies a single run may request.
        self._num_sampleable_policies = sum(
            1 for prob in self.distortion_policy_probs if prob > 0
        )

    def _count_conflict_groups(
        self,
        distortion_policies: Sequence[DistortionPolicy],
    ) -> Mapping[int, int]:
        """Count, per conflict keyword group, how many of ``distortion_policies`` match it.

        A policy matches a group if its name contains any keyword of that group. Each policy
        is attributed to its first matching group only.
        """
        conflict_idx_to_count = defaultdict(int)
        for distortion_policy in distortion_policies:
            for conflict_idx, keywords in enumerate(self.config.conflict_control_keyword_groups):
                if any(keyword in distortion_policy.name for keyword in keywords):
                    conflict_idx_to_count[conflict_idx] += 1
                    break
        return conflict_idx_to_count

    def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]:
        """Sample the policies to apply in this run.

        The count is drawn uniformly from ``[num_distortions_min, num_distortions_max]`` and
        clamped to the number of policies that can actually be drawn without replacement.
        Up to ``_NUM_SAMPLE_RETRIES`` attempts are made to find a combination in which no
        conflict keyword group is matched by more than one policy.

        Returns:
            The sampled policies, or an empty tuple when nothing should be applied or no
            conflict-free combination was found.
        """
        num_distortions = int(
            rng.integers(
                self.config.num_distortions_min,
                self.config.num_distortions_max + 1,
            )
        )
        # Never request more distinct policies than can be drawn (e.g. after disabling some
        # policies or setting their weights to zero); otherwise sampling would raise.
        num_distortions = min(num_distortions, self._num_sampleable_policies)
        if num_distortions <= 0:
            return ()

        for _ in range(self._NUM_SAMPLE_RETRIES):
            distortion_policies = rng_choice_with_size(
                rng,
                self.config.distortion_policies,
                size=num_distortions,
                probs=self.distortion_policy_probs,
                replace=False,
            )

            conflict_idx_to_count = self._count_conflict_groups(distortion_policies)
            if all(count <= 1 for count in conflict_idx_to_count.values()):
                return distortion_policies

            logger.debug(
                'distortion policies conflict detected '
                f'conflict_idx_to_count={conflict_idx_to_count}'
            )

        logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.')
        return ()

    @classmethod
    def _create_corner_points(cls, height: int, width: int) -> PointTuple:
        """Create points along the four borders of a ``height x width`` image.

        Points are spaced ``min(height // 4, width // 4)`` pixels apart and always include
        the four corners, so their bounding box spans the whole image.
        """
        step = min(height // 4, width // 4)
        assert step > 0

        # Left/right borders, top to bottom, including both corners.
        ys = list(range(0, height, step))
        if ys[-1] < height - 1:
            ys.append(height - 1)

        # Top/bottom borders, excluding the corner columns already covered above.
        xs = [x for x in range(step, width, step) if x < width - 1]

        corner_points = PointList()
        for x in (0, width - 1):
            for y in ys:
                corner_points.append(Point.create(y=y, x=x))
        for y in (0, height - 1):
            for x in xs:
                corner_points.append(Point.create(y=y, x=x))

        return corner_points.to_point_tuple()

    def apply_distortions(
        self,
        distortion_result: DistortionResult,
        level_min: int,
        level_max: int,
        rng: RandomGenerator,
        debug: Optional[RandomDistortionDebug] = None,
    ):
        """Run this stage on ``distortion_result``.

        Args:
            distortion_result: Current state of the elements being distorted.
            level_min: Lower bound (inclusive) for the sampled level of each policy.
            level_max: Upper bound (inclusive) for the sampled level of each policy.
            rng: Random generator driving every random choice of the stage.
            debug: If given, one entry per applied policy is appended to it.

        Returns:
            The distorted result, or ``distortion_result`` unchanged if the stage is skipped.
        """
        # rng.random() lies in [0, 1), so ``>=`` enables the stage with probability exactly
        # ``prob_enable`` (never for 0.0, always for 1.0).
        if rng.random() >= self.config.prob_enable:
            return distortion_result

        if self.config.inject_corner_points:
            height, width = distortion_result.shape
            distortion_result.corner_points = self._create_corner_points(height, width)

        if self.config.force_sample_level_in_full_range:
            level_min = LEVEL_MIN
            level_max = LEVEL_MAX

        enable_debug = debug is not None

        for distortion_policy in self.sample_distortion_policies(rng):
            # Plain int (not a numpy scalar), matching ``RandomDistortionDebug.distortion_levels``.
            level = int(rng.integers(level_min, level_max + 1))

            distortion_result = distortion_policy.distort(
                level=level,
                shapable_or_shape=distortion_result.shape,
                image=distortion_result.image,
                mask=distortion_result.mask,
                score_map=distortion_result.score_map,
                point=distortion_result.point,
                points=distortion_result.points,
                corner_points=distortion_result.corner_points,
                polygon=distortion_result.polygon,
                polygons=distortion_result.polygons,
                rng=rng,
                enable_debug=enable_debug,
            )

            if debug is not None:
                assert distortion_result.image
                debug.distortion_images.append(distortion_result.image)
                debug.distortion_names.append(distortion_policy.name)
                debug.distortion_levels.append(level)
                debug.distortion_configs.append(distortion_result.config)
                debug.distortion_states.append(distortion_result.state)

            # Config/state are only meaningful for debugging; drop them between steps.
            distortion_result.config = None
            distortion_result.state = None

        return distortion_result


class RandomDistortion:
    """Pipeline of random distortion stages applied in order, followed by trimming."""

    def __init__(
        self,
        configs: Sequence[RandomDistortionStageConfig],
        level_min: int,
        level_max: int,
    ):
        self.stages = [RandomDistortionStage(config) for config in configs]
        self.level_min = level_min
        self.level_max = level_max

    @classmethod
    def get_distortion_result_all_points(cls, distortion_result: DistortionResult):
        """Yield every point tracked by ``distortion_result``.

        Covers corner points, the single point, the point list, and the vertices of the
        polygon(s); unset fields are skipped.
        """
        if distortion_result.corner_points:
            yield from distortion_result.corner_points

        if distortion_result.point:
            yield distortion_result.point

        if distortion_result.points:
            yield from distortion_result.points

        if distortion_result.polygon:
            yield from distortion_result.polygon.points

        if distortion_result.polygons:
            for polygon in distortion_result.polygons:
                yield from polygon.points

    @classmethod
    def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult):
        """Return the tight bounding box of all points tracked by ``distortion_result``.

        Corner points must be present, which also guarantees at least one point.
        """
        assert distortion_result.corner_points

        all_points = list(cls.get_distortion_result_all_points(distortion_result))
        ys = [point.y for point in all_points]
        xs = [point.x for point in all_points]
        return Box(up=min(ys), down=max(ys), left=min(xs), right=max(xs))

    @classmethod
    def trim_distortion_result(cls, distortion_result: DistortionResult):
        """Crop ``distortion_result`` to the bounding box of its tracked points.

        Only applies when corner points are present. Image, mask and score map are cropped,
        ``shape`` is updated to the cropped size, and all points/polygons (including the
        corner points) are shifted into the cropped coordinate frame.
        """
        if not distortion_result.corner_points:
            return distortion_result

        height, width = distortion_result.shape
        box = cls.get_distortion_result_element_bounding_box(distortion_result)

        # NOTE: accept the rounding error (points may lie at most one pixel outside).
        assert box.up >= -1 and height - 1 - box.down >= -1
        assert box.left >= -1 and width - 1 - box.right >= -1

        # Clamp the box into the image to absorb that rounding error.
        up = max(0, box.up)
        down = min(height - 1, box.down)
        left = max(0, box.left)
        right = min(width - 1, box.right)

        if up == 0 and down == height - 1 and left == 0 and right == width - 1:
            # Nothing to trim.
            return distortion_result

        offset_y = -up
        offset_x = -left

        if distortion_result.image:
            distortion_result.image = distortion_result.image.to_cropped_image(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        if distortion_result.mask:
            distortion_result.mask = distortion_result.mask.to_cropped_mask(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        if distortion_result.score_map:
            distortion_result.score_map = distortion_result.score_map.to_cropped_score_map(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        # Keep the recorded shape in sync with the cropped region (inclusive bounds).
        distortion_result.shape = (down - up + 1, right - left + 1)

        # Corner points live in the same frame as the other points; shift them too.
        distortion_result.corner_points = distortion_result.corner_points.to_shifted_points(
            offset_y=offset_y,
            offset_x=offset_x,
        )

        if distortion_result.point:
            distortion_result.point = distortion_result.point.to_shifted_point(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.points:
            distortion_result.points = distortion_result.points.to_shifted_points(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.polygon:
            distortion_result.polygon = distortion_result.polygon.to_shifted_polygon(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.polygons:
            distortion_result.polygons = [
                polygon.to_shifted_polygon(
                    offset_y=offset_y,
                    offset_x=offset_x,
                ) for polygon in distortion_result.polygons
            ]

        return distortion_result

    def distort(
        self,
        rng: RandomGenerator,
        shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None,
        image: Optional[Image] = None,
        mask: Optional[Mask] = None,
        score_map: Optional[ScoreMap] = None,
        point: Optional[Point] = None,
        points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None,
        polygon: Optional[Polygon] = None,
        polygons: Optional[Iterable[Polygon]] = None,
        debug: Optional[RandomDistortionDebug] = None,
    ):
        """Distort the given elements with every stage, then trim the result.

        The shape is resolved by ``Distortion.get_shape`` from ``shapable_or_shape`` or from
        the image/mask/score map. Passing ``debug`` records every applied step into it.

        Returns:
            The final ``DistortionResult``.
        """
        # Pack.
        shape = Distortion.get_shape(
            shapable_or_shape=shapable_or_shape,
            image=image,
            mask=mask,
            score_map=score_map,
        )
        distortion_result = DistortionResult(shape=shape)
        distortion_result.image = image
        distortion_result.mask = mask
        distortion_result.score_map = score_map
        distortion_result.point = point
        distortion_result.points = PointTuple(points) if points else None
        distortion_result.polygon = polygon
        if polygons:
            distortion_result.polygons = tuple(polygons)

        # Apply distortions.
        for stage in self.stages:
            distortion_result = stage.apply_distortions(
                distortion_result=distortion_result,
                level_min=self.level_min,
                level_max=self.level_max,
                rng=rng,
                debug=debug,
            )

        return self.trim_distortion_result(distortion_result)


@attrs.define
class RandomDistortionFactoryConfig:
    """Options used by ``RandomDistortionFactory.create`` to build a pipeline."""

    # Photometric stage: enable probability and inclusive range of sampled policy count.
    prob_photometric: float = 1.0
    num_photometric_min: int = 0
    num_photometric_max: int = 2
    # Groups of keywords; at most one photometric policy per group is sampled per run.
    photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = attrs.field(
        factory=lambda: [
            [
                'blur',
                'pixelation',
                'jpeg',
            ],
            [
                'noise',
            ],
        ]
    )
    # Geometric stage: enable probability. If ``force_post_rotate`` is set, the 'rotate'
    # policy is removed from the geometric stage and always applied in a final stage.
    prob_geometric: float = 0.75
    force_post_rotate: bool = False
    # Shared: level range, disabled policies, per-policy config and weight overrides.
    level_min: int = LEVEL_MIN
    level_max: int = LEVEL_MAX
    disabled_policy_names: Sequence[str] = attrs.field(factory=list)
    name_to_policy_config: Mapping[str, Any] = attrs.field(factory=dict)
    name_to_policy_weight: Mapping[str, float] = attrs.field(factory=dict)


# Each pair is (policy factories, total default weight shared evenly within the group).
_PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS = (
    (
        (
            color.mean_shift_policy_factory,
            color.color_shift_policy_factory,
            color.brightness_shift_policy_factory,
            color.std_shift_policy_factory,
            color.boundary_equalization_policy_factory,
            color.histogram_equalization_policy_factory,
            color.complement_policy_factory,
            color.posterization_policy_factory,
            color.color_balance_policy_factory,
            color.channel_permutation_policy_factory,
        ),
        10.0,
    ),
    (
        (
            blur.gaussian_blur_policy_factory,
            blur.defocus_blur_policy_factory,
            blur.motion_blur_policy_factory,
            blur.glass_blur_policy_factory,
            blur.zoom_in_blur_policy_factory,
        ),
        1.0,
    ),
    (
        (
            noise.gaussion_noise_policy_factory,
            noise.poisson_noise_policy_factory,
            noise.impulse_noise_policy_factory,
            noise.speckle_noise_policy_factory,
        ),
        3.0,
    ),
    (
        (
            effect.jpeg_quality_policy_factory,
            effect.pixelation_policy_factory,
            effect.fog_policy_factory,
        ),
        1.0,
    ),
    (
        (
            streak.line_streak_policy_factory,
            streak.rectangle_streak_policy_factory,
            streak.ellipse_streak_policy_factory,
        ),
        1.0,
    ),
)

_GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS = (
    (
        (
            affine.shear_hori_policy_factory,
            affine.shear_vert_policy_factory,
            affine.rotate_policy_factory,
            affine.skew_hori_policy_factory,
            affine.skew_vert_policy_factory,
        ),
        1.0,
    ),
    (
        (mls.similarity_mls_policy_factory,),
        1.0,
    ),
    (
        (
            camera.camera_plane_only_policy_factory,
            camera.camera_cubic_curve_policy_factory,
            camera.camera_plane_line_fold_policy_factory,
            camera.camera_plane_line_curve_policy_factory,
        ),
        1.0,
    ),
)


class RandomDistortionFactory:
    """Build ``RandomDistortion`` pipelines from photometric and geometric policy factories."""

    @classmethod
    def unpack_policy_factories_and_default_weights_sum_pairs(
        cls,
        policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ]
    ):  # yapf: disable
        """Flatten ``(factories, weights_sum)`` groups into parallel lists.

        Each group's ``weights_sum`` is split evenly among its factories. Empty groups
        contribute nothing (and no longer cause a division by zero).

        Returns:
            ``(flatten_policy_factories, flatten_policy_default_weights)`` of equal length.
        """
        flatten_policy_factories: List[DistortionPolicyFactory] = []
        flatten_policy_default_weights: List[float] = []

        for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
            num_policy_factories = len(policy_factories)
            if num_policy_factories == 0:
                continue
            default_weight = default_weights_sum / num_policy_factories
            flatten_policy_factories.extend(policy_factories)
            flatten_policy_default_weights.extend([default_weight] * num_policy_factories)

        assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
        return flatten_policy_factories, flatten_policy_default_weights

    def __init__(
        self,
        photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
        geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
    ):  # yapf: disable
        (
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            photometric_policy_factories_and_default_weights_sum_pairs
        )

        (
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            geometric_policy_factories_and_default_weights_sum_pairs
        )

    @classmethod
    def create_policies_and_policy_weights(
        cls,
        policy_factories: Sequence[DistortionPolicyFactory],
        policy_default_weights: Sequence[float],
        config: RandomDistortionFactoryConfig,
    ):
        """Instantiate the enabled policies and resolve their weights.

        Factories named in ``config.disabled_policy_names`` are skipped. Each remaining
        factory is created with ``config.name_to_policy_config.get(name)`` (``None`` when
        absent). Its weight is taken from ``config.name_to_policy_weight`` if present,
        otherwise the default weight.

        Returns:
            ``(policies, policy_weights)`` as parallel lists.
        """
        disabled_policy_names = set(config.disabled_policy_names)

        policies: List[DistortionPolicy] = []
        policy_weights: List[float] = []

        for policy_factory, policy_default_weight in zip(policy_factories, policy_default_weights):
            if policy_factory.name in disabled_policy_names:
                continue

            policy_config = config.name_to_policy_config.get(policy_factory.name)
            policies.append(policy_factory.create(policy_config))

            policy_weights.append(
                config.name_to_policy_weight.get(policy_factory.name, policy_default_weight)
            )

        return policies, policy_weights

    def create(
        self,
        config: Optional[
            Union[
                Mapping[str, Any],
                PathType,
                RandomDistortionFactoryConfig,
            ]
        ] = None,
    ):  # yapf: disable
        """Build a ``RandomDistortion`` pipeline.

        The pipeline consists of a photometric stage, a geometric stage applying exactly one
        policy, and, if ``force_post_rotate`` is set, a final stage that always applies the
        'rotate' policy with a level from the full range.

        Args:
            config: A ``RandomDistortionFactoryConfig``, a mapping, or a path accepted by
                ``dyn_structure``; ``None`` uses the defaults.

        Raises:
            ValueError: If ``force_post_rotate`` is set but the 'rotate' policy is missing
                (e.g. listed in ``disabled_policy_names``).
        """
        config = dyn_structure(
            config,
            RandomDistortionFactoryConfig,
            support_path_type=True,
            support_none_type=True,
        )

        stage_configs: List[RandomDistortionStageConfig] = []

        # Photometric.
        (
            photometric_policies,
            photometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
            config,
        )
        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=photometric_policies,
                distortion_policy_weights=photometric_policy_weights,
                prob_enable=config.prob_photometric,
                num_distortions_min=config.num_photometric_min,
                num_distortions_max=config.num_photometric_max,
                conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
            )
        )

        # Geometric.
        (
            geometric_policies,
            geometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
            config,
        )

        post_rotate_policy: Optional[DistortionPolicy] = None
        if config.force_post_rotate:
            rotate_policy_idx = next(
                (
                    idx for idx, geometric_policy in enumerate(geometric_policies)
                    if geometric_policy.name == 'rotate'
                ),
                None,
            )
            if rotate_policy_idx is None:
                raise ValueError(
                    "force_post_rotate=True requires the 'rotate' policy, "
                    'but it is missing or disabled.'
                )
            # Move 'rotate' out of the geometric stage into its own final stage.
            post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
            geometric_policy_weights.pop(rotate_policy_idx)

        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=geometric_policies,
                distortion_policy_weights=geometric_policy_weights,
                prob_enable=config.prob_geometric,
                num_distortions_min=1,
                num_distortions_max=1,
                inject_corner_points=config.force_post_rotate,
            )
        )
        if post_rotate_policy is not None:
            stage_configs.append(
                RandomDistortionStageConfig(
                    distortion_policies=[post_rotate_policy],
                    distortion_policy_weights=[1.0],
                    prob_enable=1.0,
                    num_distortions_min=1,
                    num_distortions_max=1,
                    force_sample_level_in_full_range=True,
                )
            )

        return RandomDistortion(
            configs=stage_configs,
            level_min=config.level_min,
            level_max=config.level_max,
        )


random_distortion_factory = RandomDistortionFactory()
class
RandomDistortionDebug:
68class RandomDistortionDebug: 69 distortion_names: List[str] = attrs.field(factory=list) 70 distortion_levels: List[int] = attrs.field(factory=list) 71 distortion_images: List[Image] = attrs.field(factory=list) 72 distortion_configs: List[Any] = attrs.field(factory=list) 73 distortion_states: List[Any] = attrs.field(factory=list)
RandomDistortionDebug( distortion_names: List[str] = NOTHING, distortion_levels: List[int] = NOTHING, distortion_images: List[vkit.element.image.Image] = NOTHING, distortion_configs: List[Any] = NOTHING, distortion_states: List[Any] = NOTHING)
2def __init__(self, distortion_names=NOTHING, distortion_levels=NOTHING, distortion_images=NOTHING, distortion_configs=NOTHING, distortion_states=NOTHING): 3 if distortion_names is not NOTHING: 4 self.distortion_names = distortion_names 5 else: 6 self.distortion_names = __attr_factory_distortion_names() 7 if distortion_levels is not NOTHING: 8 self.distortion_levels = distortion_levels 9 else: 10 self.distortion_levels = __attr_factory_distortion_levels() 11 if distortion_images is not NOTHING: 12 self.distortion_images = distortion_images 13 else: 14 self.distortion_images = __attr_factory_distortion_images() 15 if distortion_configs is not NOTHING: 16 self.distortion_configs = distortion_configs 17 else: 18 self.distortion_configs = __attr_factory_distortion_configs() 19 if distortion_states is not NOTHING: 20 self.distortion_states = distortion_states 21 else: 22 self.distortion_states = __attr_factory_distortion_states()
Method generated by attrs for class RandomDistortionDebug.
class
RandomDistortionStageConfig:
77class RandomDistortionStageConfig: 78 distortion_policies: Sequence[DistortionPolicy] 79 distortion_policy_weights: Sequence[float] 80 prob_enable: float 81 num_distortions_min: int 82 num_distortions_max: int 83 inject_corner_points: bool = False 84 conflict_control_keyword_groups: Sequence[Sequence[str]] = () 85 force_sample_level_in_full_range: bool = False
RandomDistortionStageConfig( distortion_policies: Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicy], distortion_policy_weights: Sequence[float], prob_enable: float, num_distortions_min: int, num_distortions_max: int, inject_corner_points: bool = False, conflict_control_keyword_groups: Sequence[Sequence[str]] = (), force_sample_level_in_full_range: bool = False)
2def __init__(self, distortion_policies, distortion_policy_weights, prob_enable, num_distortions_min, num_distortions_max, inject_corner_points=attr_dict['inject_corner_points'].default, conflict_control_keyword_groups=attr_dict['conflict_control_keyword_groups'].default, force_sample_level_in_full_range=attr_dict['force_sample_level_in_full_range'].default): 3 self.distortion_policies = distortion_policies 4 self.distortion_policy_weights = distortion_policy_weights 5 self.prob_enable = prob_enable 6 self.num_distortions_min = num_distortions_min 7 self.num_distortions_max = num_distortions_max 8 self.inject_corner_points = inject_corner_points 9 self.conflict_control_keyword_groups = conflict_control_keyword_groups 10 self.force_sample_level_in_full_range = force_sample_level_in_full_range
Method generated by attrs for class RandomDistortionStageConfig.
class
RandomDistortionStage:
88class RandomDistortionStage: 89 90 def __init__(self, config: RandomDistortionStageConfig): 91 self.config = config 92 self.distortion_policy_probs = normalize_to_probs(self.config.distortion_policy_weights) 93 94 def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]: 95 num_distortions = rng.integers( 96 self.config.num_distortions_min, 97 self.config.num_distortions_max + 1, 98 ) 99 if num_distortions <= 0: 100 return () 101 102 num_retries = 5 103 while num_retries > 0: 104 distortion_policies = rng_choice_with_size( 105 rng, 106 self.config.distortion_policies, 107 size=num_distortions, 108 probs=self.distortion_policy_probs, 109 replace=False, 110 ) 111 112 # Conflict analysis. 113 conflict_idx_to_count = defaultdict(int) 114 for distortion_policy in distortion_policies: 115 for conflict_idx, keywords in \ 116 enumerate(self.config.conflict_control_keyword_groups): 117 match = False 118 for keyword in keywords: 119 if keyword in distortion_policy.name: 120 match = True 121 break 122 if match: 123 conflict_idx_to_count[conflict_idx] += 1 124 break 125 126 no_conflict = True 127 for count in conflict_idx_to_count.values(): 128 if count > 1: 129 no_conflict = False 130 logger.debug( 131 'distortion policies conflict detected ' 132 f'conflict_idx_to_count={conflict_idx_to_count}' 133 ) 134 break 135 136 if no_conflict: 137 return distortion_policies 138 else: 139 num_retries -= 1 140 141 logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.') 142 return () 143 144 def apply_distortions( 145 self, 146 distortion_result: DistortionResult, 147 level_min: int, 148 level_max: int, 149 rng: RandomGenerator, 150 debug: Optional[RandomDistortionDebug] = None, 151 ): 152 if rng.random() > self.config.prob_enable: 153 return distortion_result 154 155 if self.config.inject_corner_points: 156 height, width = distortion_result.shape 157 158 step = min(height // 4, width // 4) 159 assert step > 0 160 161 
ys = list(range(0, height, step)) 162 if ys[-1] < height - 1: 163 ys.append(height - 1) 164 165 xs = list(range(0, width, step)) 166 if xs[0] == 0: 167 xs.pop(0) 168 if xs[-1] == width - 1: 169 xs.pop() 170 171 corner_points = PointList() 172 173 for x in (0, width - 1): 174 for y in ys: 175 corner_points.append(Point.create(y=y, x=x)) 176 for y in (0, height - 1): 177 for x in xs: 178 corner_points.append(Point.create(y=y, x=x)) 179 180 distortion_result.corner_points = corner_points.to_point_tuple() 181 182 if self.config.force_sample_level_in_full_range: 183 level_min = LEVEL_MIN 184 level_max = LEVEL_MAX 185 186 distortion_policies = self.sample_distortion_policies(rng) 187 188 for distortion_policy in distortion_policies: 189 level = rng.integers(level_min, level_max + 1) 190 191 distortion_result = distortion_policy.distort( 192 level=level, 193 shapable_or_shape=distortion_result.shape, 194 image=distortion_result.image, 195 mask=distortion_result.mask, 196 score_map=distortion_result.score_map, 197 point=distortion_result.point, 198 points=distortion_result.points, 199 corner_points=distortion_result.corner_points, 200 polygon=distortion_result.polygon, 201 polygons=distortion_result.polygons, 202 rng=rng, 203 enable_debug=bool(debug), 204 ) 205 206 if debug: 207 assert distortion_result.image 208 debug.distortion_images.append(distortion_result.image) 209 debug.distortion_names.append(distortion_policy.name) 210 debug.distortion_levels.append(level) 211 debug.distortion_configs.append(distortion_result.config) 212 debug.distortion_states.append(distortion_result.state) 213 214 distortion_result.config = None 215 distortion_result.state = None 216 217 return distortion_result
RandomDistortionStage( config: vkit.mechanism.distortion_policy.random_distortion.RandomDistortionStageConfig)
def
sample_distortion_policies( self, rng: numpy.random._generator.Generator) -> Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicy]:
94 def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]: 95 num_distortions = rng.integers( 96 self.config.num_distortions_min, 97 self.config.num_distortions_max + 1, 98 ) 99 if num_distortions <= 0: 100 return () 101 102 num_retries = 5 103 while num_retries > 0: 104 distortion_policies = rng_choice_with_size( 105 rng, 106 self.config.distortion_policies, 107 size=num_distortions, 108 probs=self.distortion_policy_probs, 109 replace=False, 110 ) 111 112 # Conflict analysis. 113 conflict_idx_to_count = defaultdict(int) 114 for distortion_policy in distortion_policies: 115 for conflict_idx, keywords in \ 116 enumerate(self.config.conflict_control_keyword_groups): 117 match = False 118 for keyword in keywords: 119 if keyword in distortion_policy.name: 120 match = True 121 break 122 if match: 123 conflict_idx_to_count[conflict_idx] += 1 124 break 125 126 no_conflict = True 127 for count in conflict_idx_to_count.values(): 128 if count > 1: 129 no_conflict = False 130 logger.debug( 131 'distortion policies conflict detected ' 132 f'conflict_idx_to_count={conflict_idx_to_count}' 133 ) 134 break 135 136 if no_conflict: 137 return distortion_policies 138 else: 139 num_retries -= 1 140 141 logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.') 142 return ()
def
apply_distortions( self, distortion_result: vkit.mechanism.distortion.interface.DistortionResult, level_min: int, level_max: int, rng: numpy.random._generator.Generator, debug: Union[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionDebug, NoneType] = None):
144 def apply_distortions( 145 self, 146 distortion_result: DistortionResult, 147 level_min: int, 148 level_max: int, 149 rng: RandomGenerator, 150 debug: Optional[RandomDistortionDebug] = None, 151 ): 152 if rng.random() > self.config.prob_enable: 153 return distortion_result 154 155 if self.config.inject_corner_points: 156 height, width = distortion_result.shape 157 158 step = min(height // 4, width // 4) 159 assert step > 0 160 161 ys = list(range(0, height, step)) 162 if ys[-1] < height - 1: 163 ys.append(height - 1) 164 165 xs = list(range(0, width, step)) 166 if xs[0] == 0: 167 xs.pop(0) 168 if xs[-1] == width - 1: 169 xs.pop() 170 171 corner_points = PointList() 172 173 for x in (0, width - 1): 174 for y in ys: 175 corner_points.append(Point.create(y=y, x=x)) 176 for y in (0, height - 1): 177 for x in xs: 178 corner_points.append(Point.create(y=y, x=x)) 179 180 distortion_result.corner_points = corner_points.to_point_tuple() 181 182 if self.config.force_sample_level_in_full_range: 183 level_min = LEVEL_MIN 184 level_max = LEVEL_MAX 185 186 distortion_policies = self.sample_distortion_policies(rng) 187 188 for distortion_policy in distortion_policies: 189 level = rng.integers(level_min, level_max + 1) 190 191 distortion_result = distortion_policy.distort( 192 level=level, 193 shapable_or_shape=distortion_result.shape, 194 image=distortion_result.image, 195 mask=distortion_result.mask, 196 score_map=distortion_result.score_map, 197 point=distortion_result.point, 198 points=distortion_result.points, 199 corner_points=distortion_result.corner_points, 200 polygon=distortion_result.polygon, 201 polygons=distortion_result.polygons, 202 rng=rng, 203 enable_debug=bool(debug), 204 ) 205 206 if debug: 207 assert distortion_result.image 208 debug.distortion_images.append(distortion_result.image) 209 debug.distortion_names.append(distortion_policy.name) 210 debug.distortion_levels.append(level) 211 debug.distortion_configs.append(distortion_result.config) 212 
debug.distortion_states.append(distortion_result.state) 213 214 distortion_result.config = None 215 distortion_result.state = None 216 217 return distortion_result
class
RandomDistortion:
220class RandomDistortion: 221 222 def __init__( 223 self, 224 configs: Sequence[RandomDistortionStageConfig], 225 level_min: int, 226 level_max: int, 227 ): 228 self.stages = [RandomDistortionStage(config) for config in configs] 229 self.level_min = level_min 230 self.level_max = level_max 231 232 @classmethod 233 def get_distortion_result_all_points(cls, distortion_result: DistortionResult): 234 if distortion_result.corner_points: 235 yield from distortion_result.corner_points 236 237 if distortion_result.point: 238 yield distortion_result.point 239 240 if distortion_result.points: 241 yield from distortion_result.points 242 243 if distortion_result.polygon: 244 yield from distortion_result.polygon.points 245 246 if distortion_result.polygons: 247 for polygon in distortion_result.polygons: 248 yield from polygon.points 249 250 @classmethod 251 def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult): 252 assert distortion_result.corner_points 253 254 all_points = cls.get_distortion_result_all_points(distortion_result) 255 point = next(all_points) 256 y_min = point.y 257 y_max = point.y 258 x_min = point.x 259 x_max = point.x 260 for point in all_points: 261 y_min = min(y_min, point.y) 262 y_max = max(y_max, point.y) 263 x_min = min(x_min, point.x) 264 x_max = max(x_max, point.x) 265 return Box(up=y_min, down=y_max, left=x_min, right=x_max) 266 267 @classmethod 268 def trim_distortion_result(cls, distortion_result: DistortionResult): 269 # Trim page if need. 270 if not distortion_result.corner_points: 271 return distortion_result 272 273 height, width = distortion_result.shape 274 box = cls.get_distortion_result_element_bounding_box(distortion_result) 275 276 pad_up = box.up 277 pad_down = height - 1 - box.down 278 # NOTE: accept the rounding error. 
279 assert pad_up >= -1 and pad_down >= -1 280 281 pad_left = box.left 282 pad_right = width - 1 - box.right 283 assert pad_left >= -1 and pad_right >= -1 284 285 if pad_up <= 0 and pad_down <= 0 and pad_left <= 0 and pad_right <= 0: 286 return distortion_result 287 288 # Deal with rounding error. 289 up = max(0, box.up) 290 down = min(height - 1, box.down) 291 left = max(0, box.left) 292 right = min(width - 1, box.right) 293 294 pad_up = max(0, pad_up) 295 pad_down = max(0, pad_down) 296 pad_left = max(0, pad_left) 297 pad_right = max(0, pad_right) 298 299 if distortion_result.image: 300 distortion_result.image = distortion_result.image.to_cropped_image( 301 up=up, 302 down=down, 303 left=left, 304 right=right, 305 ) 306 307 if distortion_result.mask: 308 distortion_result.mask = distortion_result.mask.to_cropped_mask( 309 up=up, 310 down=down, 311 left=left, 312 right=right, 313 ) 314 315 if distortion_result.score_map: 316 distortion_result.score_map = distortion_result.score_map.to_cropped_score_map( 317 up=up, 318 down=down, 319 left=left, 320 right=right, 321 ) 322 323 if distortion_result.point: 324 distortion_result.point = distortion_result.point.to_shifted_point( 325 offset_y=-pad_up, 326 offset_x=-pad_left, 327 ) 328 329 if distortion_result.points: 330 distortion_result.points = distortion_result.points.to_shifted_points( 331 offset_y=-pad_up, 332 offset_x=-pad_left, 333 ) 334 335 if distortion_result.polygon: 336 distortion_result.polygon = distortion_result.polygon.to_shifted_polygon( 337 offset_y=-pad_up, 338 offset_x=-pad_left, 339 ) 340 341 if distortion_result.polygons: 342 distortion_result.polygons = [ 343 polygon.to_shifted_polygon( 344 offset_y=-pad_up, 345 offset_x=-pad_left, 346 ) for polygon in distortion_result.polygons 347 ] 348 349 return distortion_result 350 351 def distort( 352 self, 353 rng: RandomGenerator, 354 shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None, 355 image: Optional[Image] = None, 356 mask: 
Optional[Mask] = None, 357 score_map: Optional[ScoreMap] = None, 358 point: Optional[Point] = None, 359 points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None, 360 polygon: Optional[Polygon] = None, 361 polygons: Optional[Iterable[Polygon]] = None, 362 debug: Optional[RandomDistortionDebug] = None, 363 ): 364 # Pack. 365 shape = Distortion.get_shape( 366 shapable_or_shape=shapable_or_shape, 367 image=image, 368 mask=mask, 369 score_map=score_map, 370 ) 371 distortion_result = DistortionResult(shape=shape) 372 distortion_result.image = image 373 distortion_result.mask = mask 374 distortion_result.score_map = score_map 375 distortion_result.point = point 376 distortion_result.points = PointTuple(points) if points else None 377 distortion_result.polygon = polygon 378 if polygons: 379 distortion_result.polygons = tuple(polygons) 380 381 # Apply distortions. 382 for stage in self.stages: 383 distortion_result = stage.apply_distortions( 384 distortion_result=distortion_result, 385 level_min=self.level_min, 386 level_max=self.level_max, 387 rng=rng, 388 debug=debug, 389 ) 390 391 distortion_result = self.trim_distortion_result(distortion_result) 392 393 return distortion_result
RandomDistortion( configs: Sequence[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionStageConfig], level_min: int, level_max: int)
@classmethod
def
get_distortion_result_all_points( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
@classmethod
def get_distortion_result_all_points(cls, distortion_result: DistortionResult):
    """Lazily yield every point carried by ``distortion_result``.

    Order: corner points, the single point, the point sequence, the vertices
    of the single polygon, then the vertices of each polygon in ``polygons``.
    Fields that are empty/falsy are skipped.
    """
    corner_points = distortion_result.corner_points
    if corner_points:
        yield from corner_points

    single_point = distortion_result.point
    if single_point:
        yield single_point

    point_seq = distortion_result.points
    if point_seq:
        yield from point_seq

    single_polygon = distortion_result.polygon
    if single_polygon:
        yield from single_polygon.points

    for each_polygon in (distortion_result.polygons or ()):
        yield from each_polygon.points
@classmethod
def
get_distortion_result_element_bounding_box( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
@classmethod
def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult):
    """Return the tight ``Box`` enclosing every point of ``distortion_result``.

    Requires ``corner_points`` to be set (asserted); the points walked are those
    produced by ``get_distortion_result_all_points``.
    """
    assert distortion_result.corner_points

    point_iter = cls.get_distortion_result_all_points(distortion_result)

    # Seed all four extents with the first point, then widen with the rest.
    seed = next(point_iter)
    up, down, left, right = seed.y, seed.y, seed.x, seed.x
    for pt in point_iter:
        up = min(up, pt.y)
        down = max(down, pt.y)
        left = min(left, pt.x)
        right = max(right, pt.x)

    return Box(up=up, down=down, left=left, right=right)
@classmethod
def
trim_distortion_result( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
@classmethod
def trim_distortion_result(cls, distortion_result: DistortionResult):
    """Crop away the empty margin around the distorted elements.

    Only acts when ``corner_points`` is set. The bounding box of all points
    (see ``get_distortion_result_element_bounding_box``) is clamped to the
    canvas and used as an inclusive crop window:

    * image / mask / score_map are cropped to the window;
    * point / points / polygon / polygons are shifted by the window origin;
    * ``shape`` is updated to the window size, so it matches the cropped data.

    Returns the (mutated) ``distortion_result``. It is returned untouched when
    ``corner_points`` is not set, or when the box already reaches (or, within
    rounding, passes) every canvas border.
    """
    # Trim page if need.
    if not distortion_result.corner_points:
        return distortion_result

    height, width = distortion_result.shape
    box = cls.get_distortion_result_element_bounding_box(distortion_result)

    pad_up = box.up
    pad_down = height - 1 - box.down
    # NOTE: accept the rounding error (the box may overshoot a border by 1).
    assert pad_up >= -1 and pad_down >= -1

    pad_left = box.left
    pad_right = width - 1 - box.right
    assert pad_left >= -1 and pad_right >= -1

    if pad_up <= 0 and pad_down <= 0 and pad_left <= 0 and pad_right <= 0:
        return distortion_result

    # Deal with rounding error: clamp the crop window to the canvas.
    up = max(0, box.up)
    down = min(height - 1, box.down)
    left = max(0, box.left)
    right = min(width - 1, box.right)

    # Vector elements move by the crop origin.
    offset_y = -up
    offset_x = -left

    if distortion_result.image:
        distortion_result.image = distortion_result.image.to_cropped_image(
            up=up,
            down=down,
            left=left,
            right=right,
        )

    if distortion_result.mask:
        distortion_result.mask = distortion_result.mask.to_cropped_mask(
            up=up,
            down=down,
            left=left,
            right=right,
        )

    if distortion_result.score_map:
        distortion_result.score_map = distortion_result.score_map.to_cropped_score_map(
            up=up,
            down=down,
            left=left,
            right=right,
        )

    if distortion_result.point:
        distortion_result.point = distortion_result.point.to_shifted_point(
            offset_y=offset_y,
            offset_x=offset_x,
        )

    if distortion_result.points:
        distortion_result.points = distortion_result.points.to_shifted_points(
            offset_y=offset_y,
            offset_x=offset_x,
        )

    if distortion_result.polygon:
        distortion_result.polygon = distortion_result.polygon.to_shifted_polygon(
            offset_y=offset_y,
            offset_x=offset_x,
        )

    if distortion_result.polygons:
        # A tuple, matching how ``distort`` packs its ``polygons`` input.
        distortion_result.polygons = tuple(
            polygon.to_shifted_polygon(
                offset_y=offset_y,
                offset_x=offset_x,
            ) for polygon in distortion_result.polygons
        )

    # The canvas is now the inclusive [up, down] x [left, right] window; keep
    # ``shape`` consistent with the cropped image / mask / score map.
    distortion_result.shape = (down - up + 1, right - left + 1)

    # NOTE(review): corner_points are not shifted here and remain in the
    # pre-trim coordinate frame — confirm no caller relies on them afterwards.
    return distortion_result
def
distort( self, rng: numpy.random._generator.Generator, shapable_or_shape: Union[vkit.element.type.Shapable, Tuple[int, int], NoneType] = None, image: Union[vkit.element.image.Image, NoneType] = None, mask: Union[vkit.element.mask.Mask, NoneType] = None, score_map: Union[vkit.element.score_map.ScoreMap, NoneType] = None, point: Union[vkit.element.point.Point, NoneType] = None, points: Union[vkit.element.point.PointList, vkit.element.point.PointTuple, Iterable[vkit.element.point.Point], NoneType] = None, polygon: Union[vkit.element.polygon.Polygon, NoneType] = None, polygons: Union[Iterable[vkit.element.polygon.Polygon], NoneType] = None, debug: Union[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionDebug, NoneType] = None):
def distort(
    self,
    rng: RandomGenerator,
    shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None,
    image: Optional[Image] = None,
    mask: Optional[Mask] = None,
    score_map: Optional[ScoreMap] = None,
    point: Optional[Point] = None,
    points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None,
    polygon: Optional[Polygon] = None,
    polygons: Optional[Iterable[Polygon]] = None,
    debug: Optional[RandomDistortionDebug] = None,
):
    """Run every configured stage over the given elements.

    The shape is resolved by ``Distortion.get_shape`` from ``shapable_or_shape``
    or the raster inputs. All inputs are packed into one ``DistortionResult``,
    passed through ``self.stages`` in order, and finally handed to
    ``trim_distortion_result`` (which does nothing unless the result carries
    ``corner_points``).

    ``debug``, when given, is forwarded to every stage.
    """
    result = DistortionResult(
        shape=Distortion.get_shape(
            shapable_or_shape=shapable_or_shape,
            image=image,
            mask=mask,
            score_map=score_map,
        ),
    )
    result.image = image
    result.mask = mask
    result.score_map = score_map
    result.point = point
    if points:
        result.points = PointTuple(points)
    else:
        result.points = None
    result.polygon = polygon
    if polygons:
        result.polygons = tuple(polygons)

    # Each stage consumes the current result and returns the next one.
    for stage in self.stages:
        result = stage.apply_distortions(
            distortion_result=result,
            level_min=self.level_min,
            level_max=self.level_max,
            rng=rng,
            debug=debug,
        )

    return self.trim_distortion_result(result)
class
RandomDistortionFactoryConfig:
class RandomDistortionFactoryConfig:
    """Options consumed by ``RandomDistortionFactory.create``.

    NOTE: the field order below fixes the positional order of the generated
    ``__init__``; do not reorder fields.
    """

    # ---- Photometric stage ----
    # Passed as the stage's ``prob_enable``.
    prob_photometric: float = 1.0
    # Passed as the stage's ``num_distortions_min`` / ``num_distortions_max``.
    num_photometric_min: int = 0
    num_photometric_max: int = 2
    # Passed as the stage's ``conflict_control_keyword_groups``; presumably
    # policies matching keywords of the same group are not combined — confirm
    # in RandomDistortionStage.
    photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = attrs.field(
        factory=lambda: [['blur', 'pixelation', 'jpeg'], ['noise']]
    )

    # ---- Geometric stage ----
    # Passed as the geometric stage's ``prob_enable``.
    prob_geometric: float = 0.75
    # Move the 'rotate' policy into its own final, always-enabled stage.
    force_post_rotate: bool = False

    # ---- Shared by all stages ----
    level_min: int = LEVEL_MIN
    level_max: int = LEVEL_MAX
    # Policy names excluded from every stage.
    disabled_policy_names: Sequence[str] = attrs.field(factory=list)
    # Per-policy config handed to the policy factory's ``create``, by name.
    name_to_policy_config: Mapping[str, Any] = attrs.field(factory=dict)
    # Per-policy weight overriding the default weight, by name.
    name_to_policy_weight: Mapping[str, float] = attrs.field(factory=dict)
RandomDistortionFactoryConfig( prob_photometric: float = 1.0, num_photometric_min: int = 0, num_photometric_max: int = 2, photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = NOTHING, prob_geometric: float = 0.75, force_post_rotate: bool = False, level_min: int = 1, level_max: int = 10, disabled_policy_names: Sequence[str] = NOTHING, name_to_policy_config: Mapping[str, Any] = NOTHING, name_to_policy_weight: Mapping[str, float] = NOTHING)
def __init__(
    self,
    prob_photometric=attr_dict['prob_photometric'].default,
    num_photometric_min=attr_dict['num_photometric_min'].default,
    num_photometric_max=attr_dict['num_photometric_max'].default,
    photometric_conflict_control_keyword_groups=NOTHING,
    prob_geometric=attr_dict['prob_geometric'].default,
    force_post_rotate=attr_dict['force_post_rotate'].default,
    level_min=attr_dict['level_min'].default,
    level_max=attr_dict['level_max'].default,
    disabled_policy_names=NOTHING,
    name_to_policy_config=NOTHING,
    name_to_policy_weight=NOTHING
):
    """Initializer generated by attrs for RandomDistortionFactoryConfig.

    Plain fields default to the value recorded in ``attr_dict``. Fields declared
    with ``attrs.field(factory=...)`` default to the ``NOTHING`` sentinel and,
    when omitted, are filled by calling their factory, so each instance gets
    its own fresh container rather than a shared mutable default.
    """
    self.prob_photometric = prob_photometric
    self.num_photometric_min = num_photometric_min
    self.num_photometric_max = num_photometric_max
    # Factory-backed field: build a fresh default only when not supplied.
    if photometric_conflict_control_keyword_groups is not NOTHING:
        self.photometric_conflict_control_keyword_groups = photometric_conflict_control_keyword_groups
    else:
        self.photometric_conflict_control_keyword_groups = __attr_factory_photometric_conflict_control_keyword_groups()
    self.prob_geometric = prob_geometric
    self.force_post_rotate = force_post_rotate
    self.level_min = level_min
    self.level_max = level_max
    if disabled_policy_names is not NOTHING:
        self.disabled_policy_names = disabled_policy_names
    else:
        self.disabled_policy_names = __attr_factory_disabled_policy_names()
    if name_to_policy_config is not NOTHING:
        self.name_to_policy_config = name_to_policy_config
    else:
        self.name_to_policy_config = __attr_factory_name_to_policy_config()
    if name_to_policy_weight is not NOTHING:
        self.name_to_policy_weight = name_to_policy_weight
    else:
        self.name_to_policy_weight = __attr_factory_name_to_policy_weight()
Method generated by attrs for class RandomDistortionFactoryConfig.
class
RandomDistortionFactory:
class RandomDistortionFactory:
    """Build ``RandomDistortion`` pipelines from weighted policy factories.

    Factories are supplied as ``(factories, default_weights_sum)`` groups. Each
    group's weight sum is split evenly among its members. Photometric groups
    feed the first stage and geometric groups the second. When
    ``force_post_rotate`` is set, a third stage is appended that runs only the
    'rotate' policy.
    """

    @classmethod
    def unpack_policy_factories_and_default_weights_sum_pairs(
        cls,
        policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ]
    ):  # yapf: disable
        """Flatten ``(factories, weights_sum)`` groups into two parallel lists.

        Each factory of a group gets ``weights_sum / len(factories)`` as its
        default weight. Empty groups contribute nothing (rather than dividing by
        zero).

        Returns:
            ``(flatten_policy_factories, flatten_policy_default_weights)``.
        """
        flatten_policy_factories: List[DistortionPolicyFactory] = []
        flatten_policy_default_weights: List[float] = []

        for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
            num_factories = len(policy_factories)
            if num_factories == 0:
                # Nothing to weight in this group.
                continue
            default_weight = default_weights_sum / num_factories
            flatten_policy_factories.extend(policy_factories)
            flatten_policy_default_weights.extend([default_weight] * num_factories)

        assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
        return flatten_policy_factories, flatten_policy_default_weights

    def __init__(
        self,
        photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
        geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
    ):  # yapf: disable
        """Flatten both group lists into per-factory lists and default weights."""
        (
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            photometric_policy_factories_and_default_weights_sum_pairs
        )

        (
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            geometric_policy_factories_and_default_weights_sum_pairs
        )

    @classmethod
    def create_policies_and_policy_weights(
        cls,
        policy_factories: Sequence[DistortionPolicyFactory],
        policy_default_weights: Sequence[float],
        config: RandomDistortionFactoryConfig,
    ):
        """Instantiate the enabled policies and pick their weights.

        Factories named in ``config.disabled_policy_names`` are skipped. Each
        remaining factory is created with ``config.name_to_policy_config.get(name)``
        (``None`` when absent). Its weight is the override from
        ``config.name_to_policy_weight`` when present, else its default weight.

        Returns:
            ``(policies, policy_weights)``, two lists of equal length.
        """
        disabled_policy_names = set(config.disabled_policy_names)

        policies: List[DistortionPolicy] = []
        policy_weights: List[float] = []

        for policy_factory, policy_default_weight in zip(policy_factories, policy_default_weights):
            name = policy_factory.name
            if name in disabled_policy_names:
                continue

            policies.append(policy_factory.create(config.name_to_policy_config.get(name)))
            policy_weights.append(config.name_to_policy_weight.get(name, policy_default_weight))

        return policies, policy_weights

    def create(
        self,
        config: Optional[
            Union[
                Mapping[str, Any],
                PathType,
                RandomDistortionFactoryConfig,
            ]
        ] = None,
    ):  # yapf: disable
        """Build a ``RandomDistortion`` from ``config``.

        ``config`` may be a ``RandomDistortionFactoryConfig``, a mapping, a path
        or ``None`` (all defaults); it is normalized with ``dyn_structure``.

        Stages:
          1. Photometric: ``prob_photometric`` enable probability, between
             ``num_photometric_min`` and ``num_photometric_max`` distortions.
          2. Geometric: ``prob_geometric`` enable probability, exactly one
             distortion; corner points are injected iff ``force_post_rotate``.
          3. Only with ``force_post_rotate``: the 'rotate' policy, detached
             from stage 2, always enabled and sampled over the full level range.

        Raises:
            ValueError: ``force_post_rotate`` is set but no enabled geometric
                policy is named 'rotate'.
        """
        config = dyn_structure(
            config,
            RandomDistortionFactoryConfig,
            support_path_type=True,
            support_none_type=True,
        )

        stage_configs: List[RandomDistortionStageConfig] = []

        # Photometric.
        (
            photometric_policies,
            photometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
            config,
        )
        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=photometric_policies,
                distortion_policy_weights=photometric_policy_weights,
                prob_enable=config.prob_photometric,
                num_distortions_min=config.num_photometric_min,
                num_distortions_max=config.num_photometric_max,
                conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
            )
        )

        # Geometric.
        (
            geometric_policies,
            geometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
            config,
        )

        post_rotate_policy = None
        if config.force_post_rotate:
            # Detach 'rotate' (and its weight) so it can run last, on its own.
            rotate_policy_idx = next(
                (
                    idx for idx, geometric_policy in enumerate(geometric_policies)
                    if geometric_policy.name == 'rotate'
                ),
                None,
            )
            if rotate_policy_idx is None:
                # Explicit error rather than an assert: under `python -O` an
                # assert disappears and pop(-1) would detach the wrong policy.
                raise ValueError(
                    "force_post_rotate=True requires the 'rotate' geometric policy, "
                    "but it is missing or listed in disabled_policy_names."
                )
            post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
            geometric_policy_weights.pop(rotate_policy_idx)

        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=geometric_policies,
                distortion_policy_weights=geometric_policy_weights,
                prob_enable=config.prob_geometric,
                num_distortions_min=1,
                num_distortions_max=1,
                inject_corner_points=config.force_post_rotate,
            )
        )
        if post_rotate_policy is not None:
            stage_configs.append(
                RandomDistortionStageConfig(
                    distortion_policies=[post_rotate_policy],
                    distortion_policy_weights=[1.0],
                    prob_enable=1.0,
                    num_distortions_min=1,
                    num_distortions_max=1,
                    force_sample_level_in_full_range=True,
                )
            )

        return RandomDistortion(
            configs=stage_configs,
            level_min=config.level_min,
            level_max=config.level_max,
        )
RandomDistortionFactory( photometric_policy_factories_and_default_weights_sum_pairs: Sequence[Tuple[Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], float]] = (((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36e50>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36df0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36eb0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd362b0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36220>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36f70>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3bdc0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b340>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b100>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b280>), 10.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36790>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd4f9d0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b9a0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b400>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b220>), 1.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd36be0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd3b3a0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd4f4c0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd4fca0>), 3.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 
0x7f083bd4fe80>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd62c10>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd62130>), 1.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd62f70>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd4fb50>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd735e0>), 1.0)), geometric_policy_factories_and_default_weights_sum_pairs: Sequence[Tuple[Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], float]] = (((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd73280>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd070d0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd732e0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd737f0>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd95cd0>), 1.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd95850>,), 1.0), ((<vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd73b80>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd95d90>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd95d30>, <vkit.mechanism.distortion_policy.type.DistortionPolicyFactory object at 0x7f083bd95610>), 1.0)))
def __init__(
    self,
    photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
        Tuple[
            Sequence[DistortionPolicyFactory],
            float,
        ]
    ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
    geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
        Tuple[
            Sequence[DistortionPolicyFactory],
            float,
        ]
    ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
):  # yapf: disable
    """Store flattened factories and default weights for both stage kinds.

    Each argument is a sequence of ``(factories, default_weights_sum)`` groups,
    flattened by ``unpack_policy_factories_and_default_weights_sum_pairs``.
    """
    unpack = self.unpack_policy_factories_and_default_weights_sum_pairs

    photometric_factories, photometric_weights = unpack(
        photometric_policy_factories_and_default_weights_sum_pairs
    )
    self.photometric_policy_factories = photometric_factories
    self.photometric_policy_default_weights = photometric_weights

    geometric_factories, geometric_weights = unpack(
        geometric_policy_factories_and_default_weights_sum_pairs
    )
    self.geometric_policy_factories = geometric_factories
    self.geometric_policy_default_weights = geometric_weights
@classmethod
def
unpack_policy_factories_and_default_weights_sum_pairs( cls, policy_factories_and_default_weights_sum_pairs: Sequence[Tuple[Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], float]]):
@classmethod
def unpack_policy_factories_and_default_weights_sum_pairs(
    cls,
    policy_factories_and_default_weights_sum_pairs: Sequence[
        Tuple[
            Sequence[DistortionPolicyFactory],
            float,
        ]
    ]
):  # yapf: disable
    """Flatten ``(factories, weights_sum)`` groups into two parallel lists.

    Each factory of a group gets ``weights_sum / len(factories)`` as its default
    weight. Empty groups contribute nothing (rather than dividing by zero).

    Returns:
        ``(flatten_policy_factories, flatten_policy_default_weights)``.
    """
    flatten_policy_factories: List[DistortionPolicyFactory] = []
    flatten_policy_default_weights: List[float] = []

    for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
        num_factories = len(policy_factories)
        if num_factories == 0:
            # Nothing to weight in this group.
            continue
        default_weight = default_weights_sum / num_factories
        flatten_policy_factories.extend(policy_factories)
        flatten_policy_default_weights.extend([default_weight] * num_factories)

    assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
    return flatten_policy_factories, flatten_policy_default_weights
@classmethod
def
create_policies_and_policy_weights( cls, policy_factories: Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], policy_default_weights: Sequence[float], config: vkit.mechanism.distortion_policy.random_distortion.RandomDistortionFactoryConfig):
@classmethod
def create_policies_and_policy_weights(
    cls,
    policy_factories: Sequence[DistortionPolicyFactory],
    policy_default_weights: Sequence[float],
    config: RandomDistortionFactoryConfig,
):
    """Instantiate the enabled policies and choose their weights.

    Factories named in ``config.disabled_policy_names`` are skipped. Each kept
    factory is created with ``config.name_to_policy_config.get(name)`` (``None``
    when absent). Its weight is taken from ``config.name_to_policy_weight``
    when present, otherwise its default weight is used.

    Returns:
        ``(policies, policy_weights)``, two lists of equal length.
    """
    skipped_names = set(config.disabled_policy_names)
    weight_overrides = config.name_to_policy_weight

    policies: List[DistortionPolicy] = []
    policy_weights: List[float] = []
    for factory, default_weight in zip(policy_factories, policy_default_weights):
        name = factory.name
        if name not in skipped_names:
            policies.append(factory.create(config.name_to_policy_config.get(name)))
            policy_weights.append(weight_overrides.get(name, default_weight))

    return policies, policy_weights
def
create( self, config: Union[Mapping[str, Any], str, os.PathLike, vkit.mechanism.distortion_policy.random_distortion.RandomDistortionFactoryConfig, NoneType] = None):
def create(
    self,
    config: Optional[
        Union[
            Mapping[str, Any],
            PathType,
            RandomDistortionFactoryConfig,
        ]
    ] = None,
):  # yapf: disable
    """Build a ``RandomDistortion`` from ``config``.

    ``config`` may be a ``RandomDistortionFactoryConfig``, a mapping, a path or
    ``None`` (all defaults); it is normalized with ``dyn_structure``.

    Stages:
      1. Photometric: ``prob_photometric`` enable probability, between
         ``num_photometric_min`` and ``num_photometric_max`` distortions.
      2. Geometric: ``prob_geometric`` enable probability, exactly one
         distortion; corner points are injected iff ``force_post_rotate``.
      3. Only with ``force_post_rotate``: the 'rotate' policy, detached from
         stage 2, always enabled and sampled over the full level range.

    Raises:
        ValueError: ``force_post_rotate`` is set but no enabled geometric
            policy is named 'rotate'.
    """
    config = dyn_structure(
        config,
        RandomDistortionFactoryConfig,
        support_path_type=True,
        support_none_type=True,
    )

    stage_configs: List[RandomDistortionStageConfig] = []

    # Photometric.
    (
        photometric_policies,
        photometric_policy_weights,
    ) = self.create_policies_and_policy_weights(
        self.photometric_policy_factories,
        self.photometric_policy_default_weights,
        config,
    )
    stage_configs.append(
        RandomDistortionStageConfig(
            distortion_policies=photometric_policies,
            distortion_policy_weights=photometric_policy_weights,
            prob_enable=config.prob_photometric,
            num_distortions_min=config.num_photometric_min,
            num_distortions_max=config.num_photometric_max,
            conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
        )
    )

    # Geometric.
    (
        geometric_policies,
        geometric_policy_weights,
    ) = self.create_policies_and_policy_weights(
        self.geometric_policy_factories,
        self.geometric_policy_default_weights,
        config,
    )

    post_rotate_policy = None
    if config.force_post_rotate:
        # Detach 'rotate' (and its weight) so it can run last, on its own.
        rotate_policy_idx = next(
            (
                idx for idx, geometric_policy in enumerate(geometric_policies)
                if geometric_policy.name == 'rotate'
            ),
            None,
        )
        if rotate_policy_idx is None:
            # Explicit error rather than an assert: under `python -O` an assert
            # disappears and pop(-1) would detach the wrong policy.
            raise ValueError(
                "force_post_rotate=True requires the 'rotate' geometric policy, "
                "but it is missing or listed in disabled_policy_names."
            )
        post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
        geometric_policy_weights.pop(rotate_policy_idx)

    stage_configs.append(
        RandomDistortionStageConfig(
            distortion_policies=geometric_policies,
            distortion_policy_weights=geometric_policy_weights,
            prob_enable=config.prob_geometric,
            num_distortions_min=1,
            num_distortions_max=1,
            inject_corner_points=config.force_post_rotate,
        )
    )
    if post_rotate_policy is not None:
        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=[post_rotate_policy],
                distortion_policy_weights=[1.0],
                prob_enable=1.0,
                num_distortions_min=1,
                num_distortions_max=1,
                force_sample_level_in_full_range=True,
            )
        )

    return RandomDistortion(
        configs=stage_configs,
        level_min=config.level_min,
        level_max=config.level_max,
    )