vkit.mechanism.distortion.geometric.mls

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Tuple, Optional
 15
 16import numpy as np
 17from numpy.random import Generator as RandomGenerator
 18import attrs
 19
 20from vkit.element import Point, PointTuple
 21from ..interface import DistortionConfig
 22from .grid_rendering.interface import (
 23    PointProjector,
 24    DistortionStateImageGridBased,
 25    DistortionImageGridBased,
 26)
 27from .grid_rendering.grid_creator import create_src_image_grid
 28
 29
 30@attrs.define
 31class SimilarityMlsConfig(DistortionConfig):
 32    src_handle_points: PointTuple
 33    dst_handle_points: PointTuple
 34    grid_size: int
 35    resize_as_src: bool = False
 36
 37
 38class SimilarityMlsPointProjector(PointProjector):
 39
 40    def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
 41        self.src_handle_points = src_handle_points
 42        self.dst_handle_points = dst_handle_points
 43
 44        self.src_xy_pair_to_dst_point = {
 45            (src_point.smooth_x, src_point.smooth_y): dst_point
 46            for src_point, dst_point in zip(src_handle_points, dst_handle_points)
 47        }
 48
 49        self.src_handle_np_points = src_handle_points.to_smooth_np_array()
 50        self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()
 51
 52    def project_point(self, src_point: Point):
 53        '''
 54        Calculate the corresponding dst point given the src point.
 55        Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf
 56        '''
 57        src_xy_pair = (src_point.smooth_x, src_point.smooth_y)
 58
 59        if src_xy_pair in self.src_xy_pair_to_dst_point:
 60            # Identity.
 61            return self.src_xy_pair_to_dst_point[src_xy_pair]
 62
 63        # Calculate the distance to src handles.
 64        src_distance_squares = self.src_handle_np_points.copy()
 65        src_distance_squares[:, 0] -= src_point.smooth_x
 66        src_distance_squares[:, 1] -= src_point.smooth_y
 67        np.square(src_distance_squares, out=src_distance_squares)
 68        # (N), and should not contain 0.0.
 69        src_distance_squares = np.sum(src_distance_squares, axis=1)
 70
 71        # Calculate weights based on distances.
 72        # (N), and should not contain inf.
 73        with np.errstate(divide='raise'):
 74            src_distance_squares_inverse = 1 / src_distance_squares
 75            weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)
 76
 77        # (2), the weighted centroids.
 78        src_centroid = np.matmul(weights, self.src_handle_np_points)
 79        dst_centroid = np.matmul(weights, self.dst_handle_np_points)
 80
 81        # (N, 2)
 82        src_hat = self.src_handle_np_points - src_centroid
 83        dst_hat = self.dst_handle_np_points - dst_centroid
 84
 85        # (N, 2)
 86        src_hat_vert = src_hat[:, [1, 0]]
 87        src_hat_vert[:, 0] *= -1
 88
 89        # Calculate matrix A.
 90        src_centroid_x, src_centroid_y = src_centroid
 91        src_mat_anchor = np.transpose(
 92            np.asarray(
 93                [
 94                    # v - p*
 95                    (
 96                        src_point.smooth_x - src_centroid_x,
 97                        src_point.smooth_y - src_centroid_y,
 98                    ),
 99                    # -(v - p*)^vert
100                    (
101                        src_point.smooth_y - src_centroid_y,
102                        -(src_point.smooth_x - src_centroid_x),
103                    ),
104                ],
105                dtype=np.float32,
106            )
107        )
108        # (N, 2)
109        src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
110        src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
111        # (N, 2, 2)
112        src_mat = (
113            np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
114            * np.stack((src_mat_row0, src_mat_row1), axis=1)
115        )
116
117        # Calculate the point in dst.
118        # (N, 2)
119        dst_prod = np.squeeze(
120            # (N, 1, 2)
121            np.matmul(
122                # (N, 1, 2)
123                np.expand_dims(dst_hat, axis=1),
124                # (N, 2, 2)
125                src_mat,
126            ),
127            axis=1,
128        )
129        mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
130        dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid
131
132        dst_x = float(dst_x)
133        dst_y = float(dst_y)
134
135        return Point.create(y=dst_y, x=dst_x)
136
137
class SimilarityMlsState(DistortionStateImageGridBased[SimilarityMlsConfig]):
    '''
    Image-grid-based state of the similarity MLS distortion.

    Builds the src image grid from ``config.grid_size`` and initializes the grid-based
    state with a ``SimilarityMlsPointProjector`` over the configured handle points.
    '''

    def __init__(
        self,
        config: SimilarityMlsConfig,
        shape: Tuple[int, int],
        rng: Optional[RandomGenerator],
    ):
        '''
        :param config: Handle points, grid size and ``resize_as_src`` flag.
        :param shape: ``(height, width)`` of the image.
        :param rng: Not used; the distortion is fully determined by ``config``.
        '''
        image_height, image_width = shape
        src_image_grid = create_src_image_grid(image_height, image_width, config.grid_size)
        point_projector = SimilarityMlsPointProjector(
            config.src_handle_points,
            config.dst_handle_points,
        )
        self.initialize_image_grid_based(
            src_image_grid,
            point_projector,
            resize_as_src=config.resize_as_src,
        )

        # Kept for debugging only: the dst handles passed through shift_and_resize_point.
        self.dst_handle_points = [
            self.shift_and_resize_point(dst_handle_point)
            for dst_handle_point in config.dst_handle_points
        ]
158
159
# Public distortion entry point: binds the config type to the grid-based state class.
similarity_mls = DistortionImageGridBased(
    state_cls=SimilarityMlsState,
    config_cls=SimilarityMlsConfig,
)
class SimilarityMlsConfig(vkit.mechanism.distortion.interface.DistortionConfig):
class SimilarityMlsConfig(DistortionConfig):
    '''
    Config of the similarity moving-least-squares (MLS) distortion.
    '''
    # Handle points in the src image.
    src_handle_points: PointTuple
    # Where each src handle should land, paired with src_handle_points by position.
    dst_handle_points: PointTuple
    # Forwarded to ``create_src_image_grid`` when building the src image grid.
    grid_size: int
    # Forwarded as ``resize_as_src`` to ``initialize_image_grid_based``.
    resize_as_src: bool = False
SimilarityMlsConfig( src_handle_points: vkit.element.point.PointTuple, dst_handle_points: vkit.element.point.PointTuple, grid_size: int, resize_as_src: bool = False)
def __init__(self, src_handle_points, dst_handle_points, grid_size, resize_as_src=attr_dict['resize_as_src'].default):
    '''
    attrs-generated initializer: stores each argument on the instance.

    ``resize_as_src`` defaults to the attrs field default (``False``).
    '''
    self.src_handle_points = src_handle_points
    self.dst_handle_points = dst_handle_points
    self.grid_size = grid_size
    self.resize_as_src = resize_as_src

Method generated by attrs for class SimilarityMlsConfig.

 39class SimilarityMlsPointProjector(PointProjector):
 40
 41    def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
 42        self.src_handle_points = src_handle_points
 43        self.dst_handle_points = dst_handle_points
 44
 45        self.src_xy_pair_to_dst_point = {
 46            (src_point.smooth_x, src_point.smooth_y): dst_point
 47            for src_point, dst_point in zip(src_handle_points, dst_handle_points)
 48        }
 49
 50        self.src_handle_np_points = src_handle_points.to_smooth_np_array()
 51        self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()
 52
 53    def project_point(self, src_point: Point):
 54        '''
 55        Calculate the corresponding dst point given the src point.
 56        Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf
 57        '''
 58        src_xy_pair = (src_point.smooth_x, src_point.smooth_y)
 59
 60        if src_xy_pair in self.src_xy_pair_to_dst_point:
 61            # Identity.
 62            return self.src_xy_pair_to_dst_point[src_xy_pair]
 63
 64        # Calculate the distance to src handles.
 65        src_distance_squares = self.src_handle_np_points.copy()
 66        src_distance_squares[:, 0] -= src_point.smooth_x
 67        src_distance_squares[:, 1] -= src_point.smooth_y
 68        np.square(src_distance_squares, out=src_distance_squares)
 69        # (N), and should not contain 0.0.
 70        src_distance_squares = np.sum(src_distance_squares, axis=1)
 71
 72        # Calculate weights based on distances.
 73        # (N), and should not contain inf.
 74        with np.errstate(divide='raise'):
 75            src_distance_squares_inverse = 1 / src_distance_squares
 76            weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)
 77
 78        # (2), the weighted centroids.
 79        src_centroid = np.matmul(weights, self.src_handle_np_points)
 80        dst_centroid = np.matmul(weights, self.dst_handle_np_points)
 81
 82        # (N, 2)
 83        src_hat = self.src_handle_np_points - src_centroid
 84        dst_hat = self.dst_handle_np_points - dst_centroid
 85
 86        # (N, 2)
 87        src_hat_vert = src_hat[:, [1, 0]]
 88        src_hat_vert[:, 0] *= -1
 89
 90        # Calculate matrix A.
 91        src_centroid_x, src_centroid_y = src_centroid
 92        src_mat_anchor = np.transpose(
 93            np.asarray(
 94                [
 95                    # v - p*
 96                    (
 97                        src_point.smooth_x - src_centroid_x,
 98                        src_point.smooth_y - src_centroid_y,
 99                    ),
100                    # -(v - p*)^vert
101                    (
102                        src_point.smooth_y - src_centroid_y,
103                        -(src_point.smooth_x - src_centroid_x),
104                    ),
105                ],
106                dtype=np.float32,
107            )
108        )
109        # (N, 2)
110        src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
111        src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
112        # (N, 2, 2)
113        src_mat = (
114            np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
115            * np.stack((src_mat_row0, src_mat_row1), axis=1)
116        )
117
118        # Calculate the point in dst.
119        # (N, 2)
120        dst_prod = np.squeeze(
121            # (N, 1, 2)
122            np.matmul(
123                # (N, 1, 2)
124                np.expand_dims(dst_hat, axis=1),
125                # (N, 2, 2)
126                src_mat,
127            ),
128            axis=1,
129        )
130        mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
131        dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid
132
133        dst_x = float(dst_x)
134        dst_y = float(dst_y)
135
136        return Point.create(y=dst_y, x=dst_x)
SimilarityMlsPointProjector( src_handle_points: vkit.element.point.PointTuple, dst_handle_points: vkit.element.point.PointTuple)
41    def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
42        self.src_handle_points = src_handle_points
43        self.dst_handle_points = dst_handle_points
44
45        self.src_xy_pair_to_dst_point = {
46            (src_point.smooth_x, src_point.smooth_y): dst_point
47            for src_point, dst_point in zip(src_handle_points, dst_handle_points)
48        }
49
50        self.src_handle_np_points = src_handle_points.to_smooth_np_array()
51        self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()
def project_point(self, src_point: vkit.element.point.Point):
 53    def project_point(self, src_point: Point):
 54        '''
 55        Calculate the corresponding dst point given the src point.
 56        Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf
 57        '''
 58        src_xy_pair = (src_point.smooth_x, src_point.smooth_y)
 59
 60        if src_xy_pair in self.src_xy_pair_to_dst_point:
 61            # Identity.
 62            return self.src_xy_pair_to_dst_point[src_xy_pair]
 63
 64        # Calculate the distance to src handles.
 65        src_distance_squares = self.src_handle_np_points.copy()
 66        src_distance_squares[:, 0] -= src_point.smooth_x
 67        src_distance_squares[:, 1] -= src_point.smooth_y
 68        np.square(src_distance_squares, out=src_distance_squares)
 69        # (N), and should not contain 0.0.
 70        src_distance_squares = np.sum(src_distance_squares, axis=1)
 71
 72        # Calculate weights based on distances.
 73        # (N), and should not contain inf.
 74        with np.errstate(divide='raise'):
 75            src_distance_squares_inverse = 1 / src_distance_squares
 76            weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)
 77
 78        # (2), the weighted centroids.
 79        src_centroid = np.matmul(weights, self.src_handle_np_points)
 80        dst_centroid = np.matmul(weights, self.dst_handle_np_points)
 81
 82        # (N, 2)
 83        src_hat = self.src_handle_np_points - src_centroid
 84        dst_hat = self.dst_handle_np_points - dst_centroid
 85
 86        # (N, 2)
 87        src_hat_vert = src_hat[:, [1, 0]]
 88        src_hat_vert[:, 0] *= -1
 89
 90        # Calculate matrix A.
 91        src_centroid_x, src_centroid_y = src_centroid
 92        src_mat_anchor = np.transpose(
 93            np.asarray(
 94                [
 95                    # v - p*
 96                    (
 97                        src_point.smooth_x - src_centroid_x,
 98                        src_point.smooth_y - src_centroid_y,
 99                    ),
100                    # -(v - p*)^vert
101                    (
102                        src_point.smooth_y - src_centroid_y,
103                        -(src_point.smooth_x - src_centroid_x),
104                    ),
105                ],
106                dtype=np.float32,
107            )
108        )
109        # (N, 2)
110        src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
111        src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
112        # (N, 2, 2)
113        src_mat = (
114            np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
115            * np.stack((src_mat_row0, src_mat_row1), axis=1)
116        )
117
118        # Calculate the point in dst.
119        # (N, 2)
120        dst_prod = np.squeeze(
121            # (N, 1, 2)
122            np.matmul(
123                # (N, 1, 2)
124                np.expand_dims(dst_hat, axis=1),
125                # (N, 2, 2)
126                src_mat,
127            ),
128            axis=1,
129        )
130        mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
131        dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid
132
133        dst_x = float(dst_x)
134        dst_y = float(dst_y)
135
136        return Point.create(y=dst_y, x=dst_x)

Calculate the corresponding dst point given the src point. Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf

class SimilarityMlsState(DistortionStateImageGridBased[SimilarityMlsConfig]):
    '''
    Image-grid-based state of the similarity MLS distortion.

    Builds the src image grid from ``config.grid_size`` and initializes the grid-based
    state with a ``SimilarityMlsPointProjector`` over the configured handle points.
    '''

    def __init__(
        self,
        config: SimilarityMlsConfig,
        shape: Tuple[int, int],
        rng: Optional[RandomGenerator],
    ):
        '''
        :param config: Handle points, grid size and ``resize_as_src`` flag.
        :param shape: ``(height, width)`` of the image.
        :param rng: Not used; the distortion is fully determined by ``config``.
        '''
        image_height, image_width = shape
        src_image_grid = create_src_image_grid(image_height, image_width, config.grid_size)
        point_projector = SimilarityMlsPointProjector(
            config.src_handle_points,
            config.dst_handle_points,
        )
        self.initialize_image_grid_based(
            src_image_grid,
            point_projector,
            resize_as_src=config.resize_as_src,
        )

        # Kept for debugging only: the dst handles passed through shift_and_resize_point.
        self.dst_handle_points = [
            self.shift_and_resize_point(dst_handle_point)
            for dst_handle_point in config.dst_handle_points
        ]

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

SimilarityMlsState( config: vkit.mechanism.distortion.geometric.mls.SimilarityMlsConfig, shape: Tuple[int, int], rng: Union[numpy.random._generator.Generator, NoneType])
141    def __init__(
142        self,
143        config: SimilarityMlsConfig,
144        shape: Tuple[int, int],
145        rng: Optional[RandomGenerator],
146    ):
147        height, width = shape
148        self.initialize_image_grid_based(
149            create_src_image_grid(height, width, config.grid_size),
150            SimilarityMlsPointProjector(
151                config.src_handle_points,
152                config.dst_handle_points,
153            ),
154            resize_as_src=config.resize_as_src,
155        )
156
157        # For debug only.
158        self.dst_handle_points = list(map(self.shift_and_resize_point, config.dst_handle_points))