vkit.mechanism.distortion_policy.random_distortion

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import (
 15    Mapping,
 16    Any,
 17    Iterable,
 18    Union,
 19    Tuple,
 20    Optional,
 21    Sequence,
 22    List,
 23)
 24from collections import defaultdict
 25import logging
 26
 27import attrs
 28from numpy.random import Generator as RandomGenerator
 29
 30from vkit.utility import (
 31    dyn_structure,
 32    rng_choice_with_size,
 33    normalize_to_probs,
 34    PathType,
 35)
 36from vkit.element import (
 37    Shapable,
 38    Image,
 39    Point,
 40    PointList,
 41    PointTuple,
 42    Box,
 43    Polygon,
 44    Mask,
 45    ScoreMap,
 46)
 47from ..distortion.interface import Distortion, DistortionResult
 48from .opt import LEVEL_MIN, LEVEL_MAX
 49from .type import DistortionPolicy, DistortionPolicyFactory
 50from .photometric import (
 51    color,
 52    blur,
 53    noise,
 54    effect,
 55    streak,
 56)
 57from .geometric import (
 58    affine,
 59    mls,
 60    camera,
 61)
 62
# Module logger: emits debug messages when sampled policies conflict and a warning
# when no conflict-free sample could be drawn.
logger = logging.getLogger(__name__)
 64
 65
 66@attrs.define
 67class RandomDistortionDebug:
 68    distortion_names: List[str] = attrs.field(factory=list)
 69    distortion_levels: List[int] = attrs.field(factory=list)
 70    distortion_images: List[Image] = attrs.field(factory=list)
 71    distortion_configs: List[Any] = attrs.field(factory=list)
 72    distortion_states: List[Any] = attrs.field(factory=list)
 73
 74
 75@attrs.define
 76class RandomDistortionStageConfig:
 77    distortion_policies: Sequence[DistortionPolicy]
 78    distortion_policy_weights: Sequence[float]
 79    prob_enable: float
 80    num_distortions_min: int
 81    num_distortions_max: int
 82    inject_corner_points: bool = False
 83    conflict_control_keyword_groups: Sequence[Sequence[str]] = ()
 84    force_sample_level_in_full_range: bool = False
 85
 86
 87class RandomDistortionStage:
 88
 89    def __init__(self, config: RandomDistortionStageConfig):
 90        self.config = config
 91        self.distortion_policy_probs = normalize_to_probs(self.config.distortion_policy_weights)
 92
 93    def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]:
 94        num_distortions = rng.integers(
 95            self.config.num_distortions_min,
 96            self.config.num_distortions_max + 1,
 97        )
 98        if num_distortions <= 0:
 99            return ()
100
101        num_retries = 5
102        while num_retries > 0:
103            distortion_policies = rng_choice_with_size(
104                rng,
105                self.config.distortion_policies,
106                size=num_distortions,
107                probs=self.distortion_policy_probs,
108                replace=False,
109            )
110
111            # Conflict analysis.
112            conflict_idx_to_count = defaultdict(int)
113            for distortion_policy in distortion_policies:
114                for conflict_idx, keywords in \
115                        enumerate(self.config.conflict_control_keyword_groups):
116                    match = False
117                    for keyword in keywords:
118                        if keyword in distortion_policy.name:
119                            match = True
120                            break
121                    if match:
122                        conflict_idx_to_count[conflict_idx] += 1
123                        break
124
125            no_conflict = True
126            for count in conflict_idx_to_count.values():
127                if count > 1:
128                    no_conflict = False
129                    logger.debug(
130                        'distortion policies conflict detected '
131                        f'conflict_idx_to_count={conflict_idx_to_count}'
132                    )
133                    break
134
135            if no_conflict:
136                return distortion_policies
137            else:
138                num_retries -= 1
139
140        logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.')
141        return ()
142
143    def apply_distortions(
144        self,
145        distortion_result: DistortionResult,
146        level_min: int,
147        level_max: int,
148        rng: RandomGenerator,
149        debug: Optional[RandomDistortionDebug] = None,
150    ):
151        if rng.random() > self.config.prob_enable:
152            return distortion_result
153
154        if self.config.inject_corner_points:
155            height, width = distortion_result.shape
156
157            step = min(height // 4, width // 4)
158            assert step > 0
159
160            ys = list(range(0, height, step))
161            if ys[-1] < height - 1:
162                ys.append(height - 1)
163
164            xs = list(range(0, width, step))
165            if xs[0] == 0:
166                xs.pop(0)
167            if xs[-1] == width - 1:
168                xs.pop()
169
170            corner_points = PointList()
171
172            for x in (0, width - 1):
173                for y in ys:
174                    corner_points.append(Point.create(y=y, x=x))
175            for y in (0, height - 1):
176                for x in xs:
177                    corner_points.append(Point.create(y=y, x=x))
178
179            distortion_result.corner_points = corner_points.to_point_tuple()
180
181        if self.config.force_sample_level_in_full_range:
182            level_min = LEVEL_MIN
183            level_max = LEVEL_MAX
184
185        distortion_policies = self.sample_distortion_policies(rng)
186
187        for distortion_policy in distortion_policies:
188            level = rng.integers(level_min, level_max + 1)
189
190            distortion_result = distortion_policy.distort(
191                level=level,
192                shapable_or_shape=distortion_result.shape,
193                image=distortion_result.image,
194                mask=distortion_result.mask,
195                score_map=distortion_result.score_map,
196                point=distortion_result.point,
197                points=distortion_result.points,
198                corner_points=distortion_result.corner_points,
199                polygon=distortion_result.polygon,
200                polygons=distortion_result.polygons,
201                rng=rng,
202                enable_debug=bool(debug),
203            )
204
205            if debug:
206                assert distortion_result.image
207                debug.distortion_images.append(distortion_result.image)
208                debug.distortion_names.append(distortion_policy.name)
209                debug.distortion_levels.append(level)
210                debug.distortion_configs.append(distortion_result.config)
211                debug.distortion_states.append(distortion_result.state)
212
213            distortion_result.config = None
214            distortion_result.state = None
215
216        return distortion_result
217
218
class RandomDistortion:
    """Chain of ``RandomDistortionStage``s followed by an optional page trim.

    All stages receive the same ``level_min``/``level_max``; a stage may still widen
    them to the full range via ``force_sample_level_in_full_range``.
    """

    def __init__(
        self,
        configs: Sequence[RandomDistortionStageConfig],
        level_min: int,
        level_max: int,
    ):
        self.stages = [RandomDistortionStage(config) for config in configs]
        self.level_min = level_min
        self.level_max = level_max

    @classmethod
    def get_distortion_result_all_points(cls, distortion_result: DistortionResult):
        """Yield every point carried by ``distortion_result``.

        Order: ``corner_points``, ``point``, ``points``, the vertices of ``polygon``,
        then the vertices of each entry of ``polygons``. Unset (falsy) fields are skipped.
        """
        if distortion_result.corner_points:
            yield from distortion_result.corner_points

        if distortion_result.point:
            yield distortion_result.point

        if distortion_result.points:
            yield from distortion_result.points

        if distortion_result.polygon:
            yield from distortion_result.polygon.points

        if distortion_result.polygons:
            for polygon in distortion_result.polygons:
                yield from polygon.points

    @classmethod
    def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult):
        """Return the bounding ``Box`` of all points carried by ``distortion_result``.

        Requires ``corner_points`` to be set, which guarantees at least one point.
        """
        assert distortion_result.corner_points

        all_points = cls.get_distortion_result_all_points(distortion_result)
        first_point = next(all_points)
        y_min = y_max = first_point.y
        x_min = x_max = first_point.x
        for point in all_points:
            y_min = min(y_min, point.y)
            y_max = max(y_max, point.y)
            x_min = min(x_min, point.x)
            x_max = max(x_max, point.x)
        return Box(up=y_min, down=y_max, left=x_min, right=x_max)

    @classmethod
    def trim_distortion_result(cls, distortion_result: DistortionResult):
        """Crop the canvas to the bounding box of all tracked points.

        Only acts when ``corner_points`` are present (they are injected by a stage with
        ``inject_corner_points``). Rasters are cropped to the bounding box, all geometric
        elements (including ``corner_points``) are shifted into the cropped frame, and
        ``shape`` is updated to the cropped size.
        """
        # Trim page if need.
        if not distortion_result.corner_points:
            return distortion_result

        height, width = distortion_result.shape
        box = cls.get_distortion_result_element_bounding_box(distortion_result)

        pad_up = box.up
        pad_down = height - 1 - box.down
        # NOTE: accept the rounding error.
        assert pad_up >= -1 and pad_down >= -1

        pad_left = box.left
        pad_right = width - 1 - box.right
        assert pad_left >= -1 and pad_right >= -1

        if pad_up <= 0 and pad_down <= 0 and pad_left <= 0 and pad_right <= 0:
            # Elements already span the whole canvas (up to rounding): nothing to trim.
            return distortion_result

        # Clamp the crop window to the canvas to absorb the rounding error.
        up = max(0, box.up)
        down = min(height - 1, box.down)
        left = max(0, box.left)
        right = min(width - 1, box.right)

        # Geometric elements move by the amount removed from the top/left borders.
        offset_y = -up
        offset_x = -left

        if distortion_result.image:
            distortion_result.image = distortion_result.image.to_cropped_image(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        if distortion_result.mask:
            distortion_result.mask = distortion_result.mask.to_cropped_mask(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        if distortion_result.score_map:
            distortion_result.score_map = distortion_result.score_map.to_cropped_score_map(
                up=up,
                down=down,
                left=left,
                right=right,
            )

        if distortion_result.point:
            distortion_result.point = distortion_result.point.to_shifted_point(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.points:
            distortion_result.points = distortion_result.points.to_shifted_points(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.polygon:
            distortion_result.polygon = distortion_result.polygon.to_shifted_polygon(
                offset_y=offset_y,
                offset_x=offset_x,
            )

        if distortion_result.polygons:
            distortion_result.polygons = [
                polygon.to_shifted_polygon(
                    offset_y=offset_y,
                    offset_x=offset_x,
                ) for polygon in distortion_result.polygons
            ]

        # Keep corner points in the same (cropped) frame as every other element.
        distortion_result.corner_points = distortion_result.corner_points.to_shifted_points(
            offset_y=offset_y,
            offset_x=offset_x,
        )

        # Without this, ``shape`` would still report the pre-crop size while the
        # rasters have already been cropped.
        distortion_result.shape = (down - up + 1, right - left + 1)

        return distortion_result

    def distort(
        self,
        rng: RandomGenerator,
        shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None,
        image: Optional[Image] = None,
        mask: Optional[Mask] = None,
        score_map: Optional[ScoreMap] = None,
        point: Optional[Point] = None,
        points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None,
        polygon: Optional[Polygon] = None,
        polygons: Optional[Iterable[Polygon]] = None,
        debug: Optional[RandomDistortionDebug] = None,
    ):
        """Run every stage on the given elements, then trim the page.

        The shape is resolved by ``Distortion.get_shape`` from ``shapable_or_shape`` or
        the given rasters.

        Args:
            rng: Random generator used by all stages.
            debug: If given, collects one entry per applied policy.

        Returns:
            The final ``DistortionResult``.
        """
        # Pack.
        shape = Distortion.get_shape(
            shapable_or_shape=shapable_or_shape,
            image=image,
            mask=mask,
            score_map=score_map,
        )
        distortion_result = DistortionResult(shape=shape)
        distortion_result.image = image
        distortion_result.mask = mask
        distortion_result.score_map = score_map
        distortion_result.point = point
        distortion_result.points = PointTuple(points) if points else None
        distortion_result.polygon = polygon
        if polygons:
            distortion_result.polygons = tuple(polygons)

        # Apply distortions.
        for stage in self.stages:
            distortion_result = stage.apply_distortions(
                distortion_result=distortion_result,
                level_min=self.level_min,
                level_max=self.level_max,
                rng=rng,
                debug=debug,
            )

        distortion_result = self.trim_distortion_result(distortion_result)

        return distortion_result
393
394
@attrs.define
class RandomDistortionFactoryConfig:
    """User-facing configuration consumed by ``RandomDistortionFactory.create``."""

    # ---- Photometric stage ----
    # Probability that the photometric stage runs.
    prob_photometric: float = 1.0
    # Inclusive bounds on how many photometric policies are applied.
    num_photometric_min: int = 0
    num_photometric_max: int = 2
    # Policies whose names share a keyword group are not sampled together.
    photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = attrs.Factory(
        lambda: [['blur', 'pixelation', 'jpeg'], ['noise']]
    )

    # ---- Geometric stage ----
    # Probability that the geometric stage runs.
    prob_geometric: float = 0.75
    # Move ``rotate`` out of the geometric pool into a final, always-on stage.
    force_post_rotate: bool = False

    # ---- Shared ----
    level_min: int = LEVEL_MIN
    level_max: int = LEVEL_MAX
    # Policy names excluded from every stage.
    disabled_policy_names: Sequence[str] = attrs.Factory(list)
    # Per-policy config overrides, keyed by policy name.
    name_to_policy_config: Mapping[str, Any] = attrs.Factory(dict)
    # Per-policy sampling-weight overrides, keyed by policy name.
    name_to_policy_weight: Mapping[str, float] = attrs.Factory(dict)
422
423
# Photometric policy factories grouped into families, as
# ``(factories, default_weights_sum)`` pairs.
# ``RandomDistortionFactory.unpack_policy_factories_and_default_weights_sum_pairs``
# splits each ``default_weights_sum`` evenly across its family. The number is
# therefore the family's total default sampling weight. Individual weights can be
# overridden through ``RandomDistortionFactoryConfig.name_to_policy_weight``.
# NOTE: the order below fixes the order of the sampling pool; reordering changes
# which policies a given RNG seed selects.
_PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS = (
    # Color: 10 factories sharing 10.0 (1.0 each).
    (
        (
            color.mean_shift_policy_factory,
            color.color_shift_policy_factory,
            color.brightness_shift_policy_factory,
            color.std_shift_policy_factory,
            color.boundary_equalization_policy_factory,
            color.histogram_equalization_policy_factory,
            color.complement_policy_factory,
            color.posterization_policy_factory,
            color.color_balance_policy_factory,
            color.channel_permutation_policy_factory,
        ),
        10.0,
    ),
    # Blur: 5 factories sharing 1.0 (0.2 each).
    (
        (
            blur.gaussian_blur_policy_factory,
            blur.defocus_blur_policy_factory,
            blur.motion_blur_policy_factory,
            blur.glass_blur_policy_factory,
            blur.zoom_in_blur_policy_factory,
        ),
        1.0,
    ),
    # Noise: 4 factories sharing 3.0 (0.75 each).
    (
        (
            noise.gaussion_noise_policy_factory,
            noise.poisson_noise_policy_factory,
            noise.impulse_noise_policy_factory,
            noise.speckle_noise_policy_factory,
        ),
        3.0,
    ),
    # Effects: 3 factories sharing 1.0.
    (
        (
            effect.jpeg_quality_policy_factory,
            effect.pixelation_policy_factory,
            effect.fog_policy_factory,
        ),
        1.0,
    ),
    # Streaks: 3 factories sharing 1.0.
    (
        (
            streak.line_streak_policy_factory,
            streak.rectangle_streak_policy_factory,
            streak.ellipse_streak_policy_factory,
        ),
        1.0,
    ),
)
476
# Geometric policy factories grouped into families, as
# ``(factories, default_weights_sum)`` pairs. Each sum is split evenly across its
# family when flattened, so affine, MLS and camera families are equally likely by
# default.
# NOTE: ``RandomDistortionFactory.create`` looks up the policy named ``'rotate'`` here
# when ``force_post_rotate`` is set.
_GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS = (
    # Affine: 5 factories sharing 1.0 (0.2 each).
    (
        (
            affine.shear_hori_policy_factory,
            affine.shear_vert_policy_factory,
            affine.rotate_policy_factory,
            affine.skew_hori_policy_factory,
            affine.skew_vert_policy_factory,
        ),
        1.0,
    ),
    # Moving least squares: a single factory with weight 1.0.
    (
        (mls.similarity_mls_policy_factory,),
        1.0,
    ),
    # Camera: 4 factories sharing 1.0 (0.25 each).
    (
        (
            camera.camera_plane_only_policy_factory,
            camera.camera_cubic_curve_policy_factory,
            camera.camera_plane_line_fold_policy_factory,
            camera.camera_plane_line_curve_policy_factory,
        ),
        1.0,
    ),
)
502
503
class RandomDistortionFactory:
    """Builds ``RandomDistortion`` pipelines from a ``RandomDistortionFactoryConfig``.

    Holds the flattened photometric and geometric policy factories together with their
    index-aligned default sampling weights.
    """

    @classmethod
    def unpack_policy_factories_and_default_weights_sum_pairs(
        cls,
        policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ]
    ):  # yapf: disable
        """Flatten ``(factories, weights_sum)`` groups into two index-aligned lists.

        Each group's ``weights_sum`` is split evenly across its factories. Empty groups
        are skipped, since they contribute no factories and would divide by zero.

        Returns:
            ``(flatten_policy_factories, flatten_policy_default_weights)``.
        """
        flatten_policy_factories: List[DistortionPolicyFactory] = []
        flatten_policy_default_weights: List[float] = []

        for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
            if not policy_factories:
                continue
            default_weight = default_weights_sum / len(policy_factories)
            flatten_policy_factories.extend(policy_factories)
            flatten_policy_default_weights.extend([default_weight] * len(policy_factories))

        assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
        return flatten_policy_factories, flatten_policy_default_weights

    def __init__(
        self,
        photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
        geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
            Tuple[
                Sequence[DistortionPolicyFactory],
                float,
            ]
        ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
    ):  # yapf: disable
        (
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            photometric_policy_factories_and_default_weights_sum_pairs
        )

        (
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
            geometric_policy_factories_and_default_weights_sum_pairs
        )

    @classmethod
    def create_policies_and_policy_weights(
        cls,
        policy_factories: Sequence[DistortionPolicyFactory],
        policy_default_weights: Sequence[float],
        config: RandomDistortionFactoryConfig,
    ):
        """Instantiate every non-disabled policy and resolve its sampling weight.

        A policy's config comes from ``config.name_to_policy_config`` (``None`` if
        absent). Its weight comes from ``config.name_to_policy_weight``, falling back
        to the default weight.

        Returns:
            ``(policies, policy_weights)`` as fresh, index-aligned lists.
        """
        disabled_policy_names = set(config.disabled_policy_names)

        policies: List[DistortionPolicy] = []
        policy_weights: List[float] = []

        for policy_factory, policy_default_weight in zip(policy_factories, policy_default_weights):
            if policy_factory.name in disabled_policy_names:
                continue

            policy_config = config.name_to_policy_config.get(policy_factory.name)
            policies.append(policy_factory.create(policy_config))

            policy_weights.append(
                config.name_to_policy_weight.get(policy_factory.name, policy_default_weight)
            )

        return policies, policy_weights

    def create(
        self,
        config: Optional[
            Union[
                Mapping[str, Any],
                PathType,
                RandomDistortionFactoryConfig,
            ]
        ] = None,
    ):  # yapf: disable
        """Build a ``RandomDistortion`` from ``config``.

        Args:
            config: A ``RandomDistortionFactoryConfig``, a mapping of its fields, a path,
                or ``None`` for defaults; normalized through ``dyn_structure``.

        Returns:
            A ``RandomDistortion`` with a photometric stage, a geometric stage that
            applies exactly one policy when enabled, and, if ``force_post_rotate`` is set,
            a final stage that always applies ``rotate`` at a level from the full range.

        Raises:
            ValueError: ``force_post_rotate`` is set but no geometric policy named
                ``'rotate'`` is available (e.g. it was disabled).
        """
        config = dyn_structure(
            config,
            RandomDistortionFactoryConfig,
            support_path_type=True,
            support_none_type=True,
        )

        stage_configs: List[RandomDistortionStageConfig] = []

        # Photometric stage: a random number of policies, with conflict control.
        (
            photometric_policies,
            photometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.photometric_policy_factories,
            self.photometric_policy_default_weights,
            config,
        )
        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=photometric_policies,
                distortion_policy_weights=photometric_policy_weights,
                prob_enable=config.prob_photometric,
                num_distortions_min=config.num_photometric_min,
                num_distortions_max=config.num_photometric_max,
                conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
            )
        )

        # Geometric stage: exactly one policy when enabled.
        (
            geometric_policies,
            geometric_policy_weights,
        ) = self.create_policies_and_policy_weights(
            self.geometric_policy_factories,
            self.geometric_policy_default_weights,
            config,
        )

        post_rotate_policy: Optional[DistortionPolicy] = None
        if config.force_post_rotate:
            rotate_policy_idx = next(
                (
                    geometric_policy_idx
                    for geometric_policy_idx, geometric_policy in enumerate(geometric_policies)
                    if geometric_policy.name == 'rotate'
                ),
                None,
            )
            if rotate_policy_idx is None:
                # Raise explicitly: an ``assert`` is stripped under ``python -O`` and
                # ``pop(-1)`` would then silently treat the last policy as rotate.
                raise ValueError(
                    'force_post_rotate=True requires the "rotate" geometric policy, '
                    'but it is disabled or missing.'
                )
            # Move rotate out of the geometric pool into its own final stage.
            post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
            geometric_policy_weights.pop(rotate_policy_idx)

        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=geometric_policies,
                distortion_policy_weights=geometric_policy_weights,
                prob_enable=config.prob_geometric,
                num_distortions_min=1,
                num_distortions_max=1,
                # Corner points let ``RandomDistortion.trim_distortion_result`` crop the
                # page once the post-rotation has been applied.
                inject_corner_points=config.force_post_rotate,
            )
        )
        if post_rotate_policy is not None:
            stage_configs.append(
                RandomDistortionStageConfig(
                    distortion_policies=[post_rotate_policy],
                    distortion_policy_weights=[1.0],
                    prob_enable=1.0,
                    num_distortions_min=1,
                    num_distortions_max=1,
                    force_sample_level_in_full_range=True,
                )
            )

        return RandomDistortion(
            configs=stage_configs,
            level_min=config.level_min,
            level_max=config.level_max,
        )
669
670
# Module-level default factory, built from the default photometric and geometric
# policy tables; call ``random_distortion_factory.create(config)`` to get a pipeline.
random_distortion_factory = RandomDistortionFactory()
class RandomDistortionDebug:
class RandomDistortionDebug:
    """Debug trace of applied distortions (fields are declared as attrs fields).

    ``RandomDistortionStage.apply_distortions`` appends one entry to every list per
    applied distortion policy, in application order.
    """
    # Names of the applied policies.
    distortion_names: List[str] = attrs.field(factory=list)
    # Level sampled for each applied policy.
    distortion_levels: List[int] = attrs.field(factory=list)
    # Image right after each applied policy.
    distortion_images: List[Image] = attrs.field(factory=list)
    # ``DistortionResult.config`` captured after each policy.
    distortion_configs: List[Any] = attrs.field(factory=list)
    # ``DistortionResult.state`` captured after each policy.
    distortion_states: List[Any] = attrs.field(factory=list)
RandomDistortionDebug( distortion_names: List[str] = NOTHING, distortion_levels: List[int] = NOTHING, distortion_images: List[vkit.element.image.Image] = NOTHING, distortion_configs: List[Any] = NOTHING, distortion_states: List[Any] = NOTHING)
def __init__(self, distortion_names=NOTHING, distortion_levels=NOTHING, distortion_images=NOTHING, distortion_configs=NOTHING, distortion_states=NOTHING):
    """Initializer generated by attrs for ``RandomDistortionDebug``.

    Every argument left at the ``NOTHING`` sentinel is filled by calling that field's
    ``__attr_factory_*`` factory, so each such field gets a value from a fresh call.
    Explicitly passed values are stored as given.
    """
    if distortion_names is not NOTHING:
        self.distortion_names = distortion_names
    else:
        self.distortion_names = __attr_factory_distortion_names()
    if distortion_levels is not NOTHING:
        self.distortion_levels = distortion_levels
    else:
        self.distortion_levels = __attr_factory_distortion_levels()
    if distortion_images is not NOTHING:
        self.distortion_images = distortion_images
    else:
        self.distortion_images = __attr_factory_distortion_images()
    if distortion_configs is not NOTHING:
        self.distortion_configs = distortion_configs
    else:
        self.distortion_configs = __attr_factory_distortion_configs()
    if distortion_states is not NOTHING:
        self.distortion_states = distortion_states
    else:
        self.distortion_states = __attr_factory_distortion_states()

Method generated by attrs for class RandomDistortionDebug.

class RandomDistortionStageConfig:
class RandomDistortionStageConfig:
    """Configuration of a single ``RandomDistortionStage`` (fields are attrs fields)."""
    # Candidate policies and their index-aligned, unnormalized sampling weights.
    distortion_policies: Sequence[DistortionPolicy]
    distortion_policy_weights: Sequence[float]
    # Probability that the stage runs at all.
    prob_enable: float
    # Inclusive bounds on how many distinct policies are drawn per run.
    num_distortions_min: int
    num_distortions_max: int
    # Seed the result with border points so the page can be trimmed afterwards.
    inject_corner_points: bool = False
    # Policies whose names contain a keyword from the same group are not sampled together.
    conflict_control_keyword_groups: Sequence[Sequence[str]] = ()
    # Ignore the caller's level range and sample within [LEVEL_MIN, LEVEL_MAX].
    force_sample_level_in_full_range: bool = False
RandomDistortionStageConfig( distortion_policies: Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicy], distortion_policy_weights: Sequence[float], prob_enable: float, num_distortions_min: int, num_distortions_max: int, inject_corner_points: bool = False, conflict_control_keyword_groups: Sequence[Sequence[str]] = (), force_sample_level_in_full_range: bool = False)
def __init__(self, distortion_policies, distortion_policy_weights, prob_enable, num_distortions_min, num_distortions_max, inject_corner_points=attr_dict['inject_corner_points'].default, conflict_control_keyword_groups=attr_dict['conflict_control_keyword_groups'].default, force_sample_level_in_full_range=attr_dict['force_sample_level_in_full_range'].default):
    """Initializer generated by attrs for ``RandomDistortionStageConfig``.

    The five leading arguments are required. The optional ones default to the values
    declared on the class (looked up through ``attr_dict``). All values are stored as
    given, without conversion or validation.
    """
    self.distortion_policies = distortion_policies
    self.distortion_policy_weights = distortion_policy_weights
    self.prob_enable = prob_enable
    self.num_distortions_min = num_distortions_min
    self.num_distortions_max = num_distortions_max
    self.inject_corner_points = inject_corner_points
    self.conflict_control_keyword_groups = conflict_control_keyword_groups
    self.force_sample_level_in_full_range = force_sample_level_in_full_range

Method generated by attrs for class RandomDistortionStageConfig.

class RandomDistortionStage:
 88class RandomDistortionStage:
 89
 90    def __init__(self, config: RandomDistortionStageConfig):
 91        self.config = config
 92        self.distortion_policy_probs = normalize_to_probs(self.config.distortion_policy_weights)
 93
 94    def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]:
 95        num_distortions = rng.integers(
 96            self.config.num_distortions_min,
 97            self.config.num_distortions_max + 1,
 98        )
 99        if num_distortions <= 0:
100            return ()
101
102        num_retries = 5
103        while num_retries > 0:
104            distortion_policies = rng_choice_with_size(
105                rng,
106                self.config.distortion_policies,
107                size=num_distortions,
108                probs=self.distortion_policy_probs,
109                replace=False,
110            )
111
112            # Conflict analysis.
113            conflict_idx_to_count = defaultdict(int)
114            for distortion_policy in distortion_policies:
115                for conflict_idx, keywords in \
116                        enumerate(self.config.conflict_control_keyword_groups):
117                    match = False
118                    for keyword in keywords:
119                        if keyword in distortion_policy.name:
120                            match = True
121                            break
122                    if match:
123                        conflict_idx_to_count[conflict_idx] += 1
124                        break
125
126            no_conflict = True
127            for count in conflict_idx_to_count.values():
128                if count > 1:
129                    no_conflict = False
130                    logger.debug(
131                        'distortion policies conflict detected '
132                        f'conflict_idx_to_count={conflict_idx_to_count}'
133                    )
134                    break
135
136            if no_conflict:
137                return distortion_policies
138            else:
139                num_retries -= 1
140
141        logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.')
142        return ()
143
144    def apply_distortions(
145        self,
146        distortion_result: DistortionResult,
147        level_min: int,
148        level_max: int,
149        rng: RandomGenerator,
150        debug: Optional[RandomDistortionDebug] = None,
151    ):
152        if rng.random() > self.config.prob_enable:
153            return distortion_result
154
155        if self.config.inject_corner_points:
156            height, width = distortion_result.shape
157
158            step = min(height // 4, width // 4)
159            assert step > 0
160
161            ys = list(range(0, height, step))
162            if ys[-1] < height - 1:
163                ys.append(height - 1)
164
165            xs = list(range(0, width, step))
166            if xs[0] == 0:
167                xs.pop(0)
168            if xs[-1] == width - 1:
169                xs.pop()
170
171            corner_points = PointList()
172
173            for x in (0, width - 1):
174                for y in ys:
175                    corner_points.append(Point.create(y=y, x=x))
176            for y in (0, height - 1):
177                for x in xs:
178                    corner_points.append(Point.create(y=y, x=x))
179
180            distortion_result.corner_points = corner_points.to_point_tuple()
181
182        if self.config.force_sample_level_in_full_range:
183            level_min = LEVEL_MIN
184            level_max = LEVEL_MAX
185
186        distortion_policies = self.sample_distortion_policies(rng)
187
188        for distortion_policy in distortion_policies:
189            level = rng.integers(level_min, level_max + 1)
190
191            distortion_result = distortion_policy.distort(
192                level=level,
193                shapable_or_shape=distortion_result.shape,
194                image=distortion_result.image,
195                mask=distortion_result.mask,
196                score_map=distortion_result.score_map,
197                point=distortion_result.point,
198                points=distortion_result.points,
199                corner_points=distortion_result.corner_points,
200                polygon=distortion_result.polygon,
201                polygons=distortion_result.polygons,
202                rng=rng,
203                enable_debug=bool(debug),
204            )
205
206            if debug:
207                assert distortion_result.image
208                debug.distortion_images.append(distortion_result.image)
209                debug.distortion_names.append(distortion_policy.name)
210                debug.distortion_levels.append(level)
211                debug.distortion_configs.append(distortion_result.config)
212                debug.distortion_states.append(distortion_result.state)
213
214            distortion_result.config = None
215            distortion_result.state = None
216
217        return distortion_result
90    def __init__(self, config: RandomDistortionStageConfig):
91        self.config = config
92        self.distortion_policy_probs = normalize_to_probs(self.config.distortion_policy_weights)
def sample_distortion_policies( self, rng: numpy.random._generator.Generator) -> Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicy]:
 94    def sample_distortion_policies(self, rng: RandomGenerator) -> Sequence[DistortionPolicy]:
 95        num_distortions = rng.integers(
 96            self.config.num_distortions_min,
 97            self.config.num_distortions_max + 1,
 98        )
 99        if num_distortions <= 0:
100            return ()
101
102        num_retries = 5
103        while num_retries > 0:
104            distortion_policies = rng_choice_with_size(
105                rng,
106                self.config.distortion_policies,
107                size=num_distortions,
108                probs=self.distortion_policy_probs,
109                replace=False,
110            )
111
112            # Conflict analysis.
113            conflict_idx_to_count = defaultdict(int)
114            for distortion_policy in distortion_policies:
115                for conflict_idx, keywords in \
116                        enumerate(self.config.conflict_control_keyword_groups):
117                    match = False
118                    for keyword in keywords:
119                        if keyword in distortion_policy.name:
120                            match = True
121                            break
122                    if match:
123                        conflict_idx_to_count[conflict_idx] += 1
124                        break
125
126            no_conflict = True
127            for count in conflict_idx_to_count.values():
128                if count > 1:
129                    no_conflict = False
130                    logger.debug(
131                        'distortion policies conflict detected '
132                        f'conflict_idx_to_count={conflict_idx_to_count}'
133                    )
134                    break
135
136            if no_conflict:
137                return distortion_policies
138            else:
139                num_retries -= 1
140
141        logger.warning(f'Cannot sample distortion policies with num_distortion={num_distortions}.')
142        return ()
def apply_distortions( self, distortion_result: vkit.mechanism.distortion.interface.DistortionResult, level_min: int, level_max: int, rng: numpy.random._generator.Generator, debug: Union[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionDebug, NoneType] = None):
144    def apply_distortions(
145        self,
146        distortion_result: DistortionResult,
147        level_min: int,
148        level_max: int,
149        rng: RandomGenerator,
150        debug: Optional[RandomDistortionDebug] = None,
151    ):
152        if rng.random() > self.config.prob_enable:
153            return distortion_result
154
155        if self.config.inject_corner_points:
156            height, width = distortion_result.shape
157
158            step = min(height // 4, width // 4)
159            assert step > 0
160
161            ys = list(range(0, height, step))
162            if ys[-1] < height - 1:
163                ys.append(height - 1)
164
165            xs = list(range(0, width, step))
166            if xs[0] == 0:
167                xs.pop(0)
168            if xs[-1] == width - 1:
169                xs.pop()
170
171            corner_points = PointList()
172
173            for x in (0, width - 1):
174                for y in ys:
175                    corner_points.append(Point.create(y=y, x=x))
176            for y in (0, height - 1):
177                for x in xs:
178                    corner_points.append(Point.create(y=y, x=x))
179
180            distortion_result.corner_points = corner_points.to_point_tuple()
181
182        if self.config.force_sample_level_in_full_range:
183            level_min = LEVEL_MIN
184            level_max = LEVEL_MAX
185
186        distortion_policies = self.sample_distortion_policies(rng)
187
188        for distortion_policy in distortion_policies:
189            level = rng.integers(level_min, level_max + 1)
190
191            distortion_result = distortion_policy.distort(
192                level=level,
193                shapable_or_shape=distortion_result.shape,
194                image=distortion_result.image,
195                mask=distortion_result.mask,
196                score_map=distortion_result.score_map,
197                point=distortion_result.point,
198                points=distortion_result.points,
199                corner_points=distortion_result.corner_points,
200                polygon=distortion_result.polygon,
201                polygons=distortion_result.polygons,
202                rng=rng,
203                enable_debug=bool(debug),
204            )
205
206            if debug:
207                assert distortion_result.image
208                debug.distortion_images.append(distortion_result.image)
209                debug.distortion_names.append(distortion_policy.name)
210                debug.distortion_levels.append(level)
211                debug.distortion_configs.append(distortion_result.config)
212                debug.distortion_states.append(distortion_result.state)
213
214            distortion_result.config = None
215            distortion_result.state = None
216
217        return distortion_result
class RandomDistortion:
220class RandomDistortion:
221
222    def __init__(
223        self,
224        configs: Sequence[RandomDistortionStageConfig],
225        level_min: int,
226        level_max: int,
227    ):
228        self.stages = [RandomDistortionStage(config) for config in configs]
229        self.level_min = level_min
230        self.level_max = level_max
231
232    @classmethod
233    def get_distortion_result_all_points(cls, distortion_result: DistortionResult):
234        if distortion_result.corner_points:
235            yield from distortion_result.corner_points
236
237        if distortion_result.point:
238            yield distortion_result.point
239
240        if distortion_result.points:
241            yield from distortion_result.points
242
243        if distortion_result.polygon:
244            yield from distortion_result.polygon.points
245
246        if distortion_result.polygons:
247            for polygon in distortion_result.polygons:
248                yield from polygon.points
249
250    @classmethod
251    def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult):
252        assert distortion_result.corner_points
253
254        all_points = cls.get_distortion_result_all_points(distortion_result)
255        point = next(all_points)
256        y_min = point.y
257        y_max = point.y
258        x_min = point.x
259        x_max = point.x
260        for point in all_points:
261            y_min = min(y_min, point.y)
262            y_max = max(y_max, point.y)
263            x_min = min(x_min, point.x)
264            x_max = max(x_max, point.x)
265        return Box(up=y_min, down=y_max, left=x_min, right=x_max)
266
267    @classmethod
268    def trim_distortion_result(cls, distortion_result: DistortionResult):
269        # Trim page if need.
270        if not distortion_result.corner_points:
271            return distortion_result
272
273        height, width = distortion_result.shape
274        box = cls.get_distortion_result_element_bounding_box(distortion_result)
275
276        pad_up = box.up
277        pad_down = height - 1 - box.down
278        # NOTE: accept the rounding error.
279        assert pad_up >= -1 and pad_down >= -1
280
281        pad_left = box.left
282        pad_right = width - 1 - box.right
283        assert pad_left >= -1 and pad_right >= -1
284
285        if pad_up <= 0 and pad_down <= 0 and pad_left <= 0 and pad_right <= 0:
286            return distortion_result
287
288        # Deal with rounding error.
289        up = max(0, box.up)
290        down = min(height - 1, box.down)
291        left = max(0, box.left)
292        right = min(width - 1, box.right)
293
294        pad_up = max(0, pad_up)
295        pad_down = max(0, pad_down)
296        pad_left = max(0, pad_left)
297        pad_right = max(0, pad_right)
298
299        if distortion_result.image:
300            distortion_result.image = distortion_result.image.to_cropped_image(
301                up=up,
302                down=down,
303                left=left,
304                right=right,
305            )
306
307        if distortion_result.mask:
308            distortion_result.mask = distortion_result.mask.to_cropped_mask(
309                up=up,
310                down=down,
311                left=left,
312                right=right,
313            )
314
315        if distortion_result.score_map:
316            distortion_result.score_map = distortion_result.score_map.to_cropped_score_map(
317                up=up,
318                down=down,
319                left=left,
320                right=right,
321            )
322
323        if distortion_result.point:
324            distortion_result.point = distortion_result.point.to_shifted_point(
325                offset_y=-pad_up,
326                offset_x=-pad_left,
327            )
328
329        if distortion_result.points:
330            distortion_result.points = distortion_result.points.to_shifted_points(
331                offset_y=-pad_up,
332                offset_x=-pad_left,
333            )
334
335        if distortion_result.polygon:
336            distortion_result.polygon = distortion_result.polygon.to_shifted_polygon(
337                offset_y=-pad_up,
338                offset_x=-pad_left,
339            )
340
341        if distortion_result.polygons:
342            distortion_result.polygons = [
343                polygon.to_shifted_polygon(
344                    offset_y=-pad_up,
345                    offset_x=-pad_left,
346                ) for polygon in distortion_result.polygons
347            ]
348
349        return distortion_result
350
351    def distort(
352        self,
353        rng: RandomGenerator,
354        shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None,
355        image: Optional[Image] = None,
356        mask: Optional[Mask] = None,
357        score_map: Optional[ScoreMap] = None,
358        point: Optional[Point] = None,
359        points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None,
360        polygon: Optional[Polygon] = None,
361        polygons: Optional[Iterable[Polygon]] = None,
362        debug: Optional[RandomDistortionDebug] = None,
363    ):
364        # Pack.
365        shape = Distortion.get_shape(
366            shapable_or_shape=shapable_or_shape,
367            image=image,
368            mask=mask,
369            score_map=score_map,
370        )
371        distortion_result = DistortionResult(shape=shape)
372        distortion_result.image = image
373        distortion_result.mask = mask
374        distortion_result.score_map = score_map
375        distortion_result.point = point
376        distortion_result.points = PointTuple(points) if points else None
377        distortion_result.polygon = polygon
378        if polygons:
379            distortion_result.polygons = tuple(polygons)
380
381        # Apply distortions.
382        for stage in self.stages:
383            distortion_result = stage.apply_distortions(
384                distortion_result=distortion_result,
385                level_min=self.level_min,
386                level_max=self.level_max,
387                rng=rng,
388                debug=debug,
389            )
390
391        distortion_result = self.trim_distortion_result(distortion_result)
392
393        return distortion_result
RandomDistortion( configs: Sequence[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionStageConfig], level_min: int, level_max: int)
222    def __init__(
223        self,
224        configs: Sequence[RandomDistortionStageConfig],
225        level_min: int,
226        level_max: int,
227    ):
228        self.stages = [RandomDistortionStage(config) for config in configs]
229        self.level_min = level_min
230        self.level_max = level_max
@classmethod
def get_distortion_result_all_points( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
232    @classmethod
233    def get_distortion_result_all_points(cls, distortion_result: DistortionResult):
234        if distortion_result.corner_points:
235            yield from distortion_result.corner_points
236
237        if distortion_result.point:
238            yield distortion_result.point
239
240        if distortion_result.points:
241            yield from distortion_result.points
242
243        if distortion_result.polygon:
244            yield from distortion_result.polygon.points
245
246        if distortion_result.polygons:
247            for polygon in distortion_result.polygons:
248                yield from polygon.points
@classmethod
def get_distortion_result_element_bounding_box( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
250    @classmethod
251    def get_distortion_result_element_bounding_box(cls, distortion_result: DistortionResult):
252        assert distortion_result.corner_points
253
254        all_points = cls.get_distortion_result_all_points(distortion_result)
255        point = next(all_points)
256        y_min = point.y
257        y_max = point.y
258        x_min = point.x
259        x_max = point.x
260        for point in all_points:
261            y_min = min(y_min, point.y)
262            y_max = max(y_max, point.y)
263            x_min = min(x_min, point.x)
264            x_max = max(x_max, point.x)
265        return Box(up=y_min, down=y_max, left=x_min, right=x_max)
@classmethod
def trim_distortion_result( cls, distortion_result: vkit.mechanism.distortion.interface.DistortionResult):
267    @classmethod
268    def trim_distortion_result(cls, distortion_result: DistortionResult):
269        # Trim page if need.
270        if not distortion_result.corner_points:
271            return distortion_result
272
273        height, width = distortion_result.shape
274        box = cls.get_distortion_result_element_bounding_box(distortion_result)
275
276        pad_up = box.up
277        pad_down = height - 1 - box.down
278        # NOTE: accept the rounding error.
279        assert pad_up >= -1 and pad_down >= -1
280
281        pad_left = box.left
282        pad_right = width - 1 - box.right
283        assert pad_left >= -1 and pad_right >= -1
284
285        if pad_up <= 0 and pad_down <= 0 and pad_left <= 0 and pad_right <= 0:
286            return distortion_result
287
288        # Deal with rounding error.
289        up = max(0, box.up)
290        down = min(height - 1, box.down)
291        left = max(0, box.left)
292        right = min(width - 1, box.right)
293
294        pad_up = max(0, pad_up)
295        pad_down = max(0, pad_down)
296        pad_left = max(0, pad_left)
297        pad_right = max(0, pad_right)
298
299        if distortion_result.image:
300            distortion_result.image = distortion_result.image.to_cropped_image(
301                up=up,
302                down=down,
303                left=left,
304                right=right,
305            )
306
307        if distortion_result.mask:
308            distortion_result.mask = distortion_result.mask.to_cropped_mask(
309                up=up,
310                down=down,
311                left=left,
312                right=right,
313            )
314
315        if distortion_result.score_map:
316            distortion_result.score_map = distortion_result.score_map.to_cropped_score_map(
317                up=up,
318                down=down,
319                left=left,
320                right=right,
321            )
322
323        if distortion_result.point:
324            distortion_result.point = distortion_result.point.to_shifted_point(
325                offset_y=-pad_up,
326                offset_x=-pad_left,
327            )
328
329        if distortion_result.points:
330            distortion_result.points = distortion_result.points.to_shifted_points(
331                offset_y=-pad_up,
332                offset_x=-pad_left,
333            )
334
335        if distortion_result.polygon:
336            distortion_result.polygon = distortion_result.polygon.to_shifted_polygon(
337                offset_y=-pad_up,
338                offset_x=-pad_left,
339            )
340
341        if distortion_result.polygons:
342            distortion_result.polygons = [
343                polygon.to_shifted_polygon(
344                    offset_y=-pad_up,
345                    offset_x=-pad_left,
346                ) for polygon in distortion_result.polygons
347            ]
348
349        return distortion_result
def distort( self, rng: numpy.random._generator.Generator, shapable_or_shape: Union[vkit.element.type.Shapable, Tuple[int, int], NoneType] = None, image: Union[vkit.element.image.Image, NoneType] = None, mask: Union[vkit.element.mask.Mask, NoneType] = None, score_map: Union[vkit.element.score_map.ScoreMap, NoneType] = None, point: Union[vkit.element.point.Point, NoneType] = None, points: Union[vkit.element.point.PointList, vkit.element.point.PointTuple, Iterable[vkit.element.point.Point], NoneType] = None, polygon: Union[vkit.element.polygon.Polygon, NoneType] = None, polygons: Union[Iterable[vkit.element.polygon.Polygon], NoneType] = None, debug: Union[vkit.mechanism.distortion_policy.random_distortion.RandomDistortionDebug, NoneType] = None):
351    def distort(
352        self,
353        rng: RandomGenerator,
354        shapable_or_shape: Optional[Union[Shapable, Tuple[int, int]]] = None,
355        image: Optional[Image] = None,
356        mask: Optional[Mask] = None,
357        score_map: Optional[ScoreMap] = None,
358        point: Optional[Point] = None,
359        points: Optional[Union[PointList, PointTuple, Iterable[Point]]] = None,
360        polygon: Optional[Polygon] = None,
361        polygons: Optional[Iterable[Polygon]] = None,
362        debug: Optional[RandomDistortionDebug] = None,
363    ):
364        # Pack.
365        shape = Distortion.get_shape(
366            shapable_or_shape=shapable_or_shape,
367            image=image,
368            mask=mask,
369            score_map=score_map,
370        )
371        distortion_result = DistortionResult(shape=shape)
372        distortion_result.image = image
373        distortion_result.mask = mask
374        distortion_result.score_map = score_map
375        distortion_result.point = point
376        distortion_result.points = PointTuple(points) if points else None
377        distortion_result.polygon = polygon
378        if polygons:
379            distortion_result.polygons = tuple(polygons)
380
381        # Apply distortions.
382        for stage in self.stages:
383            distortion_result = stage.apply_distortions(
384                distortion_result=distortion_result,
385                level_min=self.level_min,
386                level_max=self.level_max,
387                rng=rng,
388                debug=debug,
389            )
390
391        distortion_result = self.trim_distortion_result(distortion_result)
392
393        return distortion_result
class RandomDistortionFactoryConfig:
397class RandomDistortionFactoryConfig:
398    # Photometric.
399    prob_photometric: float = 1.0
400    num_photometric_min: int = 0
401    num_photometric_max: int = 2
402    photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = attrs.field(
403        factory=lambda: [
404            [
405                'blur',
406                'pixelation',
407                'jpeg',
408            ],
409            [
410                'noise',
411            ],
412        ]
413    )
414    # Geometric.
415    prob_geometric: float = 0.75
416    force_post_rotate: bool = False
417    # Shared.
418    level_min: int = LEVEL_MIN
419    level_max: int = LEVEL_MAX
420    disabled_policy_names: Sequence[str] = attrs.field(factory=list)
421    name_to_policy_config: Mapping[str, Any] = attrs.field(factory=dict)
422    name_to_policy_weight: Mapping[str, float] = attrs.field(factory=dict)
RandomDistortionFactoryConfig( prob_photometric: float = 1.0, num_photometric_min: int = 0, num_photometric_max: int = 2, photometric_conflict_control_keyword_groups: Sequence[Sequence[str]] = [['blur', 'pixelation', 'jpeg'], ['noise']], prob_geometric: float = 0.75, force_post_rotate: bool = False, level_min: int = 1, level_max: int = 10, disabled_policy_names: Sequence[str] = [], name_to_policy_config: Mapping[str, Any] = {}, name_to_policy_weight: Mapping[str, float] = {})
The list/dict defaults above are produced by attrs factories (shown as the NOTHING sentinel in the generated __init__), so every instance receives its own fresh object.
 2def __init__(self, prob_photometric=attr_dict['prob_photometric'].default, num_photometric_min=attr_dict['num_photometric_min'].default, num_photometric_max=attr_dict['num_photometric_max'].default, photometric_conflict_control_keyword_groups=NOTHING, prob_geometric=attr_dict['prob_geometric'].default, force_post_rotate=attr_dict['force_post_rotate'].default, level_min=attr_dict['level_min'].default, level_max=attr_dict['level_max'].default, disabled_policy_names=NOTHING, name_to_policy_config=NOTHING, name_to_policy_weight=NOTHING):
 3    self.prob_photometric = prob_photometric
 4    self.num_photometric_min = num_photometric_min
 5    self.num_photometric_max = num_photometric_max
 6    if photometric_conflict_control_keyword_groups is not NOTHING:
 7        self.photometric_conflict_control_keyword_groups = photometric_conflict_control_keyword_groups
 8    else:
 9        self.photometric_conflict_control_keyword_groups = __attr_factory_photometric_conflict_control_keyword_groups()
10    self.prob_geometric = prob_geometric
11    self.force_post_rotate = force_post_rotate
12    self.level_min = level_min
13    self.level_max = level_max
14    if disabled_policy_names is not NOTHING:
15        self.disabled_policy_names = disabled_policy_names
16    else:
17        self.disabled_policy_names = __attr_factory_disabled_policy_names()
18    if name_to_policy_config is not NOTHING:
19        self.name_to_policy_config = name_to_policy_config
20    else:
21        self.name_to_policy_config = __attr_factory_name_to_policy_config()
22    if name_to_policy_weight is not NOTHING:
23        self.name_to_policy_weight = name_to_policy_weight
24    else:
25        self.name_to_policy_weight = __attr_factory_name_to_policy_weight()

Method generated by attrs for class RandomDistortionFactoryConfig.

class RandomDistortionFactory:
505class RandomDistortionFactory:
506
507    @classmethod
508    def unpack_policy_factories_and_default_weights_sum_pairs(
509        cls,
510        policy_factories_and_default_weights_sum_pairs: Sequence[
511            Tuple[
512                Sequence[DistortionPolicyFactory],
513                float,
514            ]
515        ]
516    ):  # yapf: disable
517        flatten_policy_factories: List[DistortionPolicyFactory] = []
518        flatten_policy_default_weights: List[float] = []
519
520        for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
521            default_weight = default_weights_sum / len(policy_factories)
522            flatten_policy_factories.extend(policy_factories)
523            flatten_policy_default_weights.extend([default_weight] * len(policy_factories))
524
525        assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
526        return flatten_policy_factories, flatten_policy_default_weights
527
528    def __init__(
529        self,
530        photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
531            Tuple[
532                Sequence[DistortionPolicyFactory],
533                float,
534            ]
535        ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
536        geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
537            Tuple[
538                Sequence[DistortionPolicyFactory],
539                float,
540            ]
541        ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
542    ):  # yapf: disable
543        (
544            self.photometric_policy_factories,
545            self.photometric_policy_default_weights,
546        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
547            photometric_policy_factories_and_default_weights_sum_pairs
548        )
549
550        (
551            self.geometric_policy_factories,
552            self.geometric_policy_default_weights,
553        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
554            geometric_policy_factories_and_default_weights_sum_pairs
555        )
556
557    @classmethod
558    def create_policies_and_policy_weights(
559        cls,
560        policy_factories: Sequence[DistortionPolicyFactory],
561        policy_default_weights: Sequence[float],
562        config: RandomDistortionFactoryConfig,
563    ):
564        disabled_policy_names = set(config.disabled_policy_names)
565
566        policies: List[DistortionPolicy] = []
567        policy_weights: List[float] = []
568
569        for policy_factory, policy_default_weight in zip(policy_factories, policy_default_weights):
570            if policy_factory.name in disabled_policy_names:
571                continue
572
573            policy_config = config.name_to_policy_config.get(policy_factory.name)
574            policies.append(policy_factory.create(policy_config))
575
576            policy_weight = policy_default_weight
577            if policy_factory.name in config.name_to_policy_weight:
578                policy_weight = config.name_to_policy_weight[policy_factory.name]
579            policy_weights.append(policy_weight)
580
581        return policies, policy_weights
582
583    def create(
584        self,
585        config: Optional[
586            Union[
587                Mapping[str, Any],
588                PathType,
589                RandomDistortionFactoryConfig,
590            ]
591        ] = None,
592    ):  # yapf: disable
593        config = dyn_structure(
594            config,
595            RandomDistortionFactoryConfig,
596            support_path_type=True,
597            support_none_type=True,
598        )
599
600        stage_configs: List[RandomDistortionStageConfig] = []
601
602        # Photometric.
603        (
604            photometric_policies,
605            photometric_policy_weights,
606        ) = self.create_policies_and_policy_weights(
607            self.photometric_policy_factories,
608            self.photometric_policy_default_weights,
609            config,
610        )
611        stage_configs.append(
612            RandomDistortionStageConfig(
613                distortion_policies=photometric_policies,
614                distortion_policy_weights=photometric_policy_weights,
615                prob_enable=config.prob_photometric,
616                num_distortions_min=config.num_photometric_min,
617                num_distortions_max=config.num_photometric_max,
618                conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
619            )
620        )
621
622        # Geometric.
623        (
624            geometric_policies,
625            geometric_policy_weights,
626        ) = self.create_policies_and_policy_weights(
627            self.geometric_policy_factories,
628            self.geometric_policy_default_weights,
629            config,
630        )
631
632        post_rotate_policy = None
633        if config.force_post_rotate:
634            rotate_policy_idx = -1
635            for geometric_policy_idx, geometric_policy in enumerate(geometric_policies):
636                if geometric_policy.name == 'rotate':
637                    rotate_policy_idx = geometric_policy_idx
638                    break
639            assert rotate_policy_idx >= 0
640            post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
641            geometric_policy_weights.pop(rotate_policy_idx)
642
643        stage_configs.append(
644            RandomDistortionStageConfig(
645                distortion_policies=geometric_policies,
646                distortion_policy_weights=geometric_policy_weights,
647                prob_enable=config.prob_geometric,
648                num_distortions_min=1,
649                num_distortions_max=1,
650                inject_corner_points=config.force_post_rotate,
651            )
652        )
653        if post_rotate_policy:
654            stage_configs.append(
655                RandomDistortionStageConfig(
656                    distortion_policies=[post_rotate_policy],
657                    distortion_policy_weights=[1.0],
658                    prob_enable=1.0,
659                    num_distortions_min=1,
660                    num_distortions_max=1,
661                    force_sample_level_in_full_range=True,
662                )
663            )
664
665        return RandomDistortion(
666            configs=stage_configs,
667            level_min=config.level_min,
668            level_max=config.level_max,
669        )
RandomDistortionFactory( photometric_policy_factories_and_default_weights_sum_pairs: Sequence[Tuple[Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], float]] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS, geometric_policy_factories_and_default_weights_sum_pairs: Sequence[Tuple[Sequence[vkit.mechanism.distortion_policy.type.DistortionPolicyFactory], float]] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS)
Default values: _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS holds 5 (factories, weights_sum) groups containing 10, 5, 4, 3 and 3 DistortionPolicyFactory objects, with weight sums 10.0, 1.0, 3.0, 1.0 and 1.0; _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS holds 3 groups containing 5, 1 and 4 factories, each with weight sum 1.0.
528    def __init__(
529        self,
530        photometric_policy_factories_and_default_weights_sum_pairs: Sequence[
531            Tuple[
532                Sequence[DistortionPolicyFactory],
533                float,
534            ]
535        ] = _PHOTOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
536        geometric_policy_factories_and_default_weights_sum_pairs: Sequence[
537            Tuple[
538                Sequence[DistortionPolicyFactory],
539                float,
540            ]
541        ] = _GEOMETRIC_POLICY_FACTORIES_AND_DEFAULT_WEIGHTS_SUM_PAIRS,
542    ):  # yapf: disable
543        (
544            self.photometric_policy_factories,
545            self.photometric_policy_default_weights,
546        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
547            photometric_policy_factories_and_default_weights_sum_pairs
548        )
549
550        (
551            self.geometric_policy_factories,
552            self.geometric_policy_default_weights,
553        ) = self.unpack_policy_factories_and_default_weights_sum_pairs(
554            geometric_policy_factories_and_default_weights_sum_pairs
555        )
@classmethod
def unpack_policy_factories_and_default_weights_sum_pairs(
    cls,
    policy_factories_and_default_weights_sum_pairs: Sequence[
        Tuple[
            Sequence[DistortionPolicyFactory],
            float,
        ]
    ]
) -> Tuple[List[DistortionPolicyFactory], List[float]]:  # yapf: disable
    """Flatten grouped policy factories into parallel factory/weight lists.

    Each group contributes its factories in their original order. The
    group's ``default_weights_sum`` is split evenly among those factories,
    so every factory in the group gets
    ``default_weights_sum / len(policy_factories)``.

    Empty groups are skipped. They have no factory to carry a weight, and
    dividing by their length would otherwise raise ``ZeroDivisionError``.

    :param policy_factories_and_default_weights_sum_pairs:
        Sequence of ``(policy_factories, default_weights_sum)`` pairs.
    :return: ``(flatten_policy_factories, flatten_policy_default_weights)``,
        two lists of equal length where index ``i`` of one corresponds to
        index ``i`` of the other.
    """
    flatten_policy_factories: List[DistortionPolicyFactory] = []
    flatten_policy_default_weights: List[float] = []

    for policy_factories, default_weights_sum in policy_factories_and_default_weights_sum_pairs:
        num_policy_factories = len(policy_factories)
        if num_policy_factories == 0:
            # Nothing to distribute this group's weight over.
            continue

        default_weight = default_weights_sum / num_policy_factories
        flatten_policy_factories.extend(policy_factories)
        flatten_policy_default_weights.extend([default_weight] * num_policy_factories)

    # Internal invariant: the two outputs are index-aligned.
    assert len(flatten_policy_factories) == len(flatten_policy_default_weights)
    return flatten_policy_factories, flatten_policy_default_weights
@classmethod
def create_policies_and_policy_weights(
    cls,
    policy_factories: Sequence[DistortionPolicyFactory],
    policy_default_weights: Sequence[float],
    config: RandomDistortionFactoryConfig,
):
    """Instantiate the enabled policies and pick their sampling weights.

    ``policy_factories`` and ``policy_default_weights`` are walked in
    lockstep.
    - Factories whose ``name`` appears in ``config.disabled_policy_names``
      are dropped.
    - Every remaining factory is created with
      ``config.name_to_policy_config.get(name)``, which is ``None`` when no
      override exists.
    - Its weight is ``config.name_to_policy_weight[name]`` when present,
      otherwise the matching default weight.

    :return: ``(policies, policy_weights)``, two index-aligned lists.
    """
    skipped_names = set(config.disabled_policy_names)

    policies: List[DistortionPolicy] = []
    policy_weights: List[float] = []

    for factory, fallback_weight in zip(policy_factories, policy_default_weights):
        name = factory.name
        if name not in skipped_names:
            policies.append(factory.create(config.name_to_policy_config.get(name)))
            # A configured weight overrides the default for this policy.
            policy_weights.append(config.name_to_policy_weight.get(name, fallback_weight))

    return policies, policy_weights
def create(
    self,
    config: Optional[
        Union[
            Mapping[str, Any],
            PathType,
            RandomDistortionFactoryConfig,
        ]
    ] = None,
):  # yapf: disable
    """Build a ``RandomDistortion`` from the stored factories and ``config``.

    ``config`` is normalized into a ``RandomDistortionFactoryConfig`` via
    ``dyn_structure``, which here also accepts a path or ``None``. The
    resulting distortion is made of these stages, in order:

    1. Photometric: enabled with ``config.prob_photometric``. It samples
       between ``num_photometric_min`` and ``num_photometric_max``
       distortions, honouring
       ``photometric_conflict_control_keyword_groups``.
    2. Geometric: enabled with ``config.prob_geometric``. It samples
       exactly one distortion.
    3. Post-rotate: only when ``config.force_post_rotate`` is set. The
       'rotate' policy is removed from the geometric stage and run as its
       own stage, which is always enabled and passes
       ``force_sample_level_in_full_range=True``.

    :param config: Mapping, path, config instance, or ``None``.
    :return: ``RandomDistortion`` over the stages above, using
        ``config.level_min`` / ``config.level_max``.
    :raises ValueError: if ``force_post_rotate`` is set but no enabled
        geometric policy is named 'rotate'.
    """
    config = dyn_structure(
        config,
        RandomDistortionFactoryConfig,
        support_path_type=True,
        support_none_type=True,
    )

    stage_configs: List[RandomDistortionStageConfig] = []

    # Photometric.
    (
        photometric_policies,
        photometric_policy_weights,
    ) = self.create_policies_and_policy_weights(
        self.photometric_policy_factories,
        self.photometric_policy_default_weights,
        config,
    )
    stage_configs.append(
        RandomDistortionStageConfig(
            distortion_policies=photometric_policies,
            distortion_policy_weights=photometric_policy_weights,
            prob_enable=config.prob_photometric,
            num_distortions_min=config.num_photometric_min,
            num_distortions_max=config.num_photometric_max,
            conflict_control_keyword_groups=config.photometric_conflict_control_keyword_groups,
        )
    )

    # Geometric.
    (
        geometric_policies,
        geometric_policy_weights,
    ) = self.create_policies_and_policy_weights(
        self.geometric_policy_factories,
        self.geometric_policy_default_weights,
        config,
    )

    post_rotate_policy = None
    if config.force_post_rotate:
        rotate_policy_idx = next(
            (
                geometric_policy_idx
                for geometric_policy_idx, geometric_policy in enumerate(geometric_policies)
                if geometric_policy.name == 'rotate'
            ),
            None,
        )
        # An explicit check instead of `assert`: with asserts stripped (-O),
        # pop(-1) would silently move an unrelated policy into the
        # post-rotate stage.
        if rotate_policy_idx is None:
            raise ValueError(
                "force_post_rotate requires the 'rotate' geometric policy, "
                "but it is not available (is it in disabled_policy_names?)."
            )
        post_rotate_policy = geometric_policies.pop(rotate_policy_idx)
        # Keep weights index-aligned with the policies.
        geometric_policy_weights.pop(rotate_policy_idx)

    stage_configs.append(
        RandomDistortionStageConfig(
            distortion_policies=geometric_policies,
            distortion_policy_weights=geometric_policy_weights,
            prob_enable=config.prob_geometric,
            num_distortions_min=1,
            num_distortions_max=1,
            inject_corner_points=config.force_post_rotate,
        )
    )
    if post_rotate_policy is not None:
        stage_configs.append(
            RandomDistortionStageConfig(
                distortion_policies=[post_rotate_policy],
                distortion_policy_weights=[1.0],
                prob_enable=1.0,
                num_distortions_min=1,
                num_distortions_max=1,
                force_sample_level_in_full_range=True,
            )
        )

    return RandomDistortion(
        configs=stage_configs,
        level_min=config.level_min,
        level_max=config.level_max,
    )