vkit.mechanism.distortion.geometric.mls
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import Tuple, Optional

import numpy as np
from numpy.random import Generator as RandomGenerator
import attrs

from vkit.element import Point, PointTuple
from ..interface import DistortionConfig
from .grid_rendering.interface import (
    PointProjector,
    DistortionStateImageGridBased,
    DistortionImageGridBased,
)
from .grid_rendering.grid_creator import create_src_image_grid


@attrs.define
class SimilarityMlsConfig(DistortionConfig):
    '''
    Config of the similarity Moving Least Squares (MLS) distortion.
    '''
    # Handle positions in the src image.
    src_handle_points: PointTuple
    # Target position of each src handle, paired with src_handle_points by index.
    dst_handle_points: PointTuple
    # Passed to create_src_image_grid when building the src image grid.
    grid_size: int
    # Forwarded to initialize_image_grid_based(resize_as_src=...).
    resize_as_src: bool = False


class SimilarityMlsPointProjector(PointProjector):
    '''
    Project src points to dst points with the similarity variant of Moving
    Least Squares (MLS) image deformation, driven by paired handle points.
    '''

    def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
        '''
        src_handle_points: handle positions in the src image.
        dst_handle_points: target position of each src handle, paired by index.

        Raises ValueError if the two handle collections differ in length.
        '''
        self.src_handle_points = src_handle_points
        self.dst_handle_points = dst_handle_points

        self.src_handle_np_points = src_handle_points.to_smooth_np_array()
        self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()

        # zip() below would silently drop unpaired handles, while the numpy arrays
        # keep them, which would only surface later as a shape mismatch in project_point.
        num_src_handles = len(self.src_handle_np_points)
        num_dst_handles = len(self.dst_handle_np_points)
        if num_src_handles != num_dst_handles:
            raise ValueError(
                'src_handle_points and dst_handle_points must have the same length, '
                f'got {num_src_handles} and {num_dst_handles}.'
            )

        # Fast path for src points exactly on a handle.
        self.src_xy_pair_to_dst_point = {
            (src_point.smooth_x, src_point.smooth_y): dst_point
            for src_point, dst_point in zip(src_handle_points, dst_handle_points)
        }

    def project_point(self, src_point: Point):
        '''
        Calculate the corresponding dst point given the src point.
        Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf

        Implements the similarity MLS deformation with weights
        w_i = 1 / |p_i - v|^2, i.e. f(v) = sum_i q_hat_i (A_i / mu_s) + q*.
        A src point lying on a src handle maps to the paired dst handle.
        If all src handles coincide (mu_s == 0), the point is translated by
        q* - p* instead of producing NaN.
        '''
        src_x = src_point.smooth_x
        src_y = src_point.smooth_y

        src_xy_pair = (src_x, src_y)
        if src_xy_pair in self.src_xy_pair_to_dst_point:
            # Identity: MLS maps each src handle onto its dst handle.
            return self.src_xy_pair_to_dst_point[src_xy_pair]

        # Calculate the squared distance to src handles.
        # (N, 2)
        src_offsets = self.src_handle_np_points.copy()
        src_offsets[:, 0] -= src_x
        src_offsets[:, 1] -= src_y
        # (N)
        src_distance_squares = np.sum(np.square(src_offsets), axis=1)

        # The dict lookup above only matches exact coordinates. The point can still
        # land exactly on a handle after the subtraction (e.g. due to rounding in the
        # handle array's dtype), which would divide by zero below.
        # Map it to the paired dst handle instead.
        zero_distance_indices = np.flatnonzero(src_distance_squares == 0)
        if zero_distance_indices.size > 0:
            dst_x, dst_y = self.dst_handle_np_points[zero_distance_indices[0]]
            return Point.create(y=float(dst_y), x=float(dst_x))

        # Calculate weights based on distances. (N), finite since no distance is 0.
        with np.errstate(divide='raise'):
            src_distance_squares_inverse = 1 / src_distance_squares
            weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)

        # (2), the weighted centroids p* and q*.
        src_centroid = np.matmul(weights, self.src_handle_np_points)
        dst_centroid = np.matmul(weights, self.dst_handle_np_points)
        src_centroid_x, src_centroid_y = src_centroid

        # (N, 2), p_hat_i = p_i - p* and q_hat_i = q_i - q*.
        src_hat = self.src_handle_np_points - src_centroid
        dst_hat = self.dst_handle_np_points - dst_centroid

        # mu_s = sum_i w_i |p_hat_i|^2.
        mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
        if mu == 0:
            # All src handles coincide (e.g. a single handle), so p_hat_i == 0 for every i
            # and the similarity term would be 0 / 0 (NaN). Any rotation/scale fits the
            # handles equally well; use the identity and only translate by q* - p*.
            dst_x = src_x - src_centroid_x + dst_centroid[0]
            dst_y = src_y - src_centroid_y + dst_centroid[1]
            return Point.create(y=float(dst_y), x=float(dst_x))

        # (N, 2), p_hat_i^vert where (x, y)^vert = (-y, x).
        src_hat_vert = src_hat[:, [1, 0]]
        src_hat_vert[:, 0] *= -1

        # Calculate matrix A_i = w_i [p_hat_i; -p_hat_i^vert] [v - p*; -(v - p*)^vert]^T.
        src_mat_anchor = np.transpose(
            np.asarray(
                [
                    # v - p*
                    (
                        src_x - src_centroid_x,
                        src_y - src_centroid_y,
                    ),
                    # -(v - p*)^vert
                    (
                        src_y - src_centroid_y,
                        -(src_x - src_centroid_x),
                    ),
                ],
                dtype=np.float32,
            )
        )
        # (N, 2)
        src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
        src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
        # (N, 2, 2)
        src_mat = (
            np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
            * np.stack((src_mat_row0, src_mat_row1), axis=1)
        )

        # Calculate the point in dst: sum_i q_hat_i A_i / mu_s + q*.
        # (N, 2)
        dst_prod = np.squeeze(
            np.matmul(
                # (N, 1, 2)
                np.expand_dims(dst_hat, axis=1),
                # (N, 2, 2)
                src_mat,
            ),
            axis=1,
        )
        dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid

        return Point.create(y=float(dst_y), x=float(dst_x))


class SimilarityMlsState(DistortionStateImageGridBased[SimilarityMlsConfig]):
    '''
    Grid-based distortion state that warps the src image grid with
    SimilarityMlsPointProjector.
    '''

    def __init__(
        self,
        config: SimilarityMlsConfig,
        shape: Tuple[int, int],
        rng: Optional[RandomGenerator],
    ):
        # Only config and shape are read; rng is not used by this state.
        height, width = shape
        self.initialize_image_grid_based(
            create_src_image_grid(height, width, config.grid_size),
            SimilarityMlsPointProjector(
                config.src_handle_points,
                config.dst_handle_points,
            ),
            resize_as_src=config.resize_as_src,
        )

        # For debug only.
        self.dst_handle_points = list(map(self.shift_and_resize_point, config.dst_handle_points))


similarity_mls = DistortionImageGridBased(
    config_cls=SimilarityMlsConfig,
    state_cls=SimilarityMlsState,
)
class SimilarityMlsConfig(DistortionConfig):
    '''
    Config of the similarity Moving Least Squares (MLS) distortion.
    '''
    # Handle positions in the src image.
    src_handle_points: PointTuple
    # Target position of each src handle, paired with src_handle_points by index.
    dst_handle_points: PointTuple
    # Passed to create_src_image_grid when building the src image grid.
    grid_size: int
    # Forwarded to initialize_image_grid_based(resize_as_src=...).
    resize_as_src: bool = False
SimilarityMlsConfig( src_handle_points: vkit.element.point.PointTuple, dst_handle_points: vkit.element.point.PointTuple, grid_size: int, resize_as_src: bool = False)
def __init__(self, src_handle_points, dst_handle_points, grid_size, resize_as_src=attr_dict['resize_as_src'].default):
    # Generated by attrs for SimilarityMlsConfig: stores each constructor argument
    # on the instance; resize_as_src defaults to the field's declared default.
    self.src_handle_points = src_handle_points
    self.dst_handle_points = dst_handle_points
    self.grid_size = grid_size
    self.resize_as_src = resize_as_src
Method generated by attrs for class SimilarityMlsConfig.
Inherited Members
class
SimilarityMlsPointProjector(vkit.mechanism.distortion.geometric.grid_rendering.point_projector.PointProjector):
class SimilarityMlsPointProjector(PointProjector):
    '''
    Project src points to dst points with the similarity variant of Moving
    Least Squares (MLS) image deformation, driven by paired handle points.
    '''

    def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
        '''
        src_handle_points: handle positions in the src image.
        dst_handle_points: target position of each src handle, paired by index.

        Raises ValueError if the two handle collections differ in length.
        '''
        self.src_handle_points = src_handle_points
        self.dst_handle_points = dst_handle_points

        self.src_handle_np_points = src_handle_points.to_smooth_np_array()
        self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()

        # zip() below would silently drop unpaired handles, while the numpy arrays
        # keep them, which would only surface later as a shape mismatch in project_point.
        num_src_handles = len(self.src_handle_np_points)
        num_dst_handles = len(self.dst_handle_np_points)
        if num_src_handles != num_dst_handles:
            raise ValueError(
                'src_handle_points and dst_handle_points must have the same length, '
                f'got {num_src_handles} and {num_dst_handles}.'
            )

        # Fast path for src points exactly on a handle.
        self.src_xy_pair_to_dst_point = {
            (src_point.smooth_x, src_point.smooth_y): dst_point
            for src_point, dst_point in zip(src_handle_points, dst_handle_points)
        }

    def project_point(self, src_point: Point):
        '''
        Calculate the corresponding dst point given the src point.
        Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf

        Implements the similarity MLS deformation with weights
        w_i = 1 / |p_i - v|^2, i.e. f(v) = sum_i q_hat_i (A_i / mu_s) + q*.
        A src point lying on a src handle maps to the paired dst handle.
        If all src handles coincide (mu_s == 0), the point is translated by
        q* - p* instead of producing NaN.
        '''
        src_x = src_point.smooth_x
        src_y = src_point.smooth_y

        src_xy_pair = (src_x, src_y)
        if src_xy_pair in self.src_xy_pair_to_dst_point:
            # Identity: MLS maps each src handle onto its dst handle.
            return self.src_xy_pair_to_dst_point[src_xy_pair]

        # Calculate the squared distance to src handles.
        # (N, 2)
        src_offsets = self.src_handle_np_points.copy()
        src_offsets[:, 0] -= src_x
        src_offsets[:, 1] -= src_y
        # (N)
        src_distance_squares = np.sum(np.square(src_offsets), axis=1)

        # The dict lookup above only matches exact coordinates. The point can still
        # land exactly on a handle after the subtraction (e.g. due to rounding in the
        # handle array's dtype), which would divide by zero below.
        # Map it to the paired dst handle instead.
        zero_distance_indices = np.flatnonzero(src_distance_squares == 0)
        if zero_distance_indices.size > 0:
            dst_x, dst_y = self.dst_handle_np_points[zero_distance_indices[0]]
            return Point.create(y=float(dst_y), x=float(dst_x))

        # Calculate weights based on distances. (N), finite since no distance is 0.
        with np.errstate(divide='raise'):
            src_distance_squares_inverse = 1 / src_distance_squares
            weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)

        # (2), the weighted centroids p* and q*.
        src_centroid = np.matmul(weights, self.src_handle_np_points)
        dst_centroid = np.matmul(weights, self.dst_handle_np_points)
        src_centroid_x, src_centroid_y = src_centroid

        # (N, 2), p_hat_i = p_i - p* and q_hat_i = q_i - q*.
        src_hat = self.src_handle_np_points - src_centroid
        dst_hat = self.dst_handle_np_points - dst_centroid

        # mu_s = sum_i w_i |p_hat_i|^2.
        mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
        if mu == 0:
            # All src handles coincide (e.g. a single handle), so p_hat_i == 0 for every i
            # and the similarity term would be 0 / 0 (NaN). Any rotation/scale fits the
            # handles equally well; use the identity and only translate by q* - p*.
            dst_x = src_x - src_centroid_x + dst_centroid[0]
            dst_y = src_y - src_centroid_y + dst_centroid[1]
            return Point.create(y=float(dst_y), x=float(dst_x))

        # (N, 2), p_hat_i^vert where (x, y)^vert = (-y, x).
        src_hat_vert = src_hat[:, [1, 0]]
        src_hat_vert[:, 0] *= -1

        # Calculate matrix A_i = w_i [p_hat_i; -p_hat_i^vert] [v - p*; -(v - p*)^vert]^T.
        src_mat_anchor = np.transpose(
            np.asarray(
                [
                    # v - p*
                    (
                        src_x - src_centroid_x,
                        src_y - src_centroid_y,
                    ),
                    # -(v - p*)^vert
                    (
                        src_y - src_centroid_y,
                        -(src_x - src_centroid_x),
                    ),
                ],
                dtype=np.float32,
            )
        )
        # (N, 2)
        src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
        src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
        # (N, 2, 2)
        src_mat = (
            np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
            * np.stack((src_mat_row0, src_mat_row1), axis=1)
        )

        # Calculate the point in dst: sum_i q_hat_i A_i / mu_s + q*.
        # (N, 2)
        dst_prod = np.squeeze(
            np.matmul(
                # (N, 1, 2)
                np.expand_dims(dst_hat, axis=1),
                # (N, 2, 2)
                src_mat,
            ),
            axis=1,
        )
        dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid

        return Point.create(y=float(dst_y), x=float(dst_x))
SimilarityMlsPointProjector( src_handle_points: vkit.element.point.PointTuple, dst_handle_points: vkit.element.point.PointTuple)
def __init__(self, src_handle_points: PointTuple, dst_handle_points: PointTuple):
    '''
    src_handle_points: handle positions in the src image.
    dst_handle_points: target position of each src handle, paired by index.

    Raises ValueError if the two handle collections differ in length.
    '''
    self.src_handle_points = src_handle_points
    self.dst_handle_points = dst_handle_points

    self.src_handle_np_points = src_handle_points.to_smooth_np_array()
    self.dst_handle_np_points = dst_handle_points.to_smooth_np_array()

    # zip() below would silently drop unpaired handles, while the numpy arrays
    # keep them, which would only surface later as a shape mismatch in project_point.
    num_src_handles = len(self.src_handle_np_points)
    num_dst_handles = len(self.dst_handle_np_points)
    if num_src_handles != num_dst_handles:
        raise ValueError(
            'src_handle_points and dst_handle_points must have the same length, '
            f'got {num_src_handles} and {num_dst_handles}.'
        )

    # Fast path for src points exactly on a handle.
    self.src_xy_pair_to_dst_point = {
        (src_point.smooth_x, src_point.smooth_y): dst_point
        for src_point, dst_point in zip(src_handle_points, dst_handle_points)
    }
def project_point(self, src_point: Point):
    '''
    Calculate the corresponding dst point given the src point.
    Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf

    Implements the similarity MLS deformation with weights
    w_i = 1 / |p_i - v|^2, i.e. f(v) = sum_i q_hat_i (A_i / mu_s) + q*.
    A src point lying on a src handle maps to the paired dst handle.
    If all src handles coincide (mu_s == 0), the point is translated by
    q* - p* instead of producing NaN.
    '''
    src_x = src_point.smooth_x
    src_y = src_point.smooth_y

    src_xy_pair = (src_x, src_y)
    if src_xy_pair in self.src_xy_pair_to_dst_point:
        # Identity: MLS maps each src handle onto its dst handle.
        return self.src_xy_pair_to_dst_point[src_xy_pair]

    # Calculate the squared distance to src handles.
    # (N, 2)
    src_offsets = self.src_handle_np_points.copy()
    src_offsets[:, 0] -= src_x
    src_offsets[:, 1] -= src_y
    # (N)
    src_distance_squares = np.sum(np.square(src_offsets), axis=1)

    # The dict lookup above only matches exact coordinates. The point can still
    # land exactly on a handle after the subtraction (e.g. due to rounding in the
    # handle array's dtype), which would divide by zero below.
    # Map it to the paired dst handle instead.
    zero_distance_indices = np.flatnonzero(src_distance_squares == 0)
    if zero_distance_indices.size > 0:
        dst_x, dst_y = self.dst_handle_np_points[zero_distance_indices[0]]
        return Point.create(y=float(dst_y), x=float(dst_x))

    # Calculate weights based on distances. (N), finite since no distance is 0.
    with np.errstate(divide='raise'):
        src_distance_squares_inverse = 1 / src_distance_squares
        weights = src_distance_squares_inverse / np.sum(src_distance_squares_inverse)

    # (2), the weighted centroids p* and q*.
    src_centroid = np.matmul(weights, self.src_handle_np_points)
    dst_centroid = np.matmul(weights, self.dst_handle_np_points)
    src_centroid_x, src_centroid_y = src_centroid

    # (N, 2), p_hat_i = p_i - p* and q_hat_i = q_i - q*.
    src_hat = self.src_handle_np_points - src_centroid
    dst_hat = self.dst_handle_np_points - dst_centroid

    # mu_s = sum_i w_i |p_hat_i|^2.
    mu = np.sum(src_distance_squares_inverse * np.sum(src_hat * src_hat, axis=1))
    if mu == 0:
        # All src handles coincide (e.g. a single handle), so p_hat_i == 0 for every i
        # and the similarity term would be 0 / 0 (NaN). Any rotation/scale fits the
        # handles equally well; use the identity and only translate by q* - p*.
        dst_x = src_x - src_centroid_x + dst_centroid[0]
        dst_y = src_y - src_centroid_y + dst_centroid[1]
        return Point.create(y=float(dst_y), x=float(dst_x))

    # (N, 2), p_hat_i^vert where (x, y)^vert = (-y, x).
    src_hat_vert = src_hat[:, [1, 0]]
    src_hat_vert[:, 0] *= -1

    # Calculate matrix A_i = w_i [p_hat_i; -p_hat_i^vert] [v - p*; -(v - p*)^vert]^T.
    src_mat_anchor = np.transpose(
        np.asarray(
            [
                # v - p*
                (
                    src_x - src_centroid_x,
                    src_y - src_centroid_y,
                ),
                # -(v - p*)^vert
                (
                    src_y - src_centroid_y,
                    -(src_x - src_centroid_x),
                ),
            ],
            dtype=np.float32,
        )
    )
    # (N, 2)
    src_mat_row0 = np.matmul(src_hat, src_mat_anchor)
    src_mat_row1 = np.matmul(-src_hat_vert, src_mat_anchor)
    # (N, 2, 2)
    src_mat = (
        np.expand_dims(np.expand_dims(src_distance_squares_inverse, axis=1), axis=1)
        * np.stack((src_mat_row0, src_mat_row1), axis=1)
    )

    # Calculate the point in dst: sum_i q_hat_i A_i / mu_s + q*.
    # (N, 2)
    dst_prod = np.squeeze(
        np.matmul(
            # (N, 1, 2)
            np.expand_dims(dst_hat, axis=1),
            # (N, 2, 2)
            src_mat,
        ),
        axis=1,
    )
    dst_x, dst_y = np.sum(dst_prod, axis=0) / mu + dst_centroid

    return Point.create(y=float(dst_y), x=float(dst_x))
Calculate the corresponding dst point given the src point. Paper: https://people.engr.tamu.edu/schaefer/research/mls.pdf
class
SimilarityMlsState(vkit.mechanism.distortion.geometric.grid_rendering.interface.DistortionStateImageGridBased[vkit.mechanism.distortion.geometric.mls.SimilarityMlsConfig]):
class SimilarityMlsState(DistortionStateImageGridBased[SimilarityMlsConfig]):
    '''
    Grid-based distortion state that warps the src image grid with
    SimilarityMlsPointProjector, built from the config's handle points.
    '''

    def __init__(
        self,
        config: SimilarityMlsConfig,
        shape: Tuple[int, int],
        rng: Optional[RandomGenerator],
    ):
        # Only config and shape are read below; rng is accepted but unused.
        height, width = shape
        src_image_grid = create_src_image_grid(height, width, config.grid_size)
        point_projector = SimilarityMlsPointProjector(
            config.src_handle_points,
            config.dst_handle_points,
        )
        self.initialize_image_grid_based(
            src_image_grid,
            point_projector,
            resize_as_src=config.resize_as_src,
        )

        # Debug aid: each dst handle passed through shift_and_resize_point.
        self.dst_handle_points = [
            self.shift_and_resize_point(dst_handle_point)
            for dst_handle_point in config.dst_handle_points
        ]
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
SimilarityMlsState( config: vkit.mechanism.distortion.geometric.mls.SimilarityMlsConfig, shape: Tuple[int, int], rng: Union[numpy.random._generator.Generator, NoneType])
def __init__(
    self,
    config: SimilarityMlsConfig,
    shape: Tuple[int, int],
    rng: Optional[RandomGenerator],
):
    '''
    Build the src image grid for an image of the given (height, width) shape
    and attach a SimilarityMlsPointProjector for the config's handle points.
    rng is accepted but not used.
    '''
    height, width = shape
    src_image_grid = create_src_image_grid(height, width, config.grid_size)
    point_projector = SimilarityMlsPointProjector(
        config.src_handle_points,
        config.dst_handle_points,
    )
    self.initialize_image_grid_based(
        src_image_grid,
        point_projector,
        resize_as_src=config.resize_as_src,
    )

    # Debug aid: each dst handle passed through shift_and_resize_point.
    self.dst_handle_points = [
        self.shift_and_resize_point(dst_handle_point)
        for dst_handle_point in config.dst_handle_points
    ]