vkit.utility.opt
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import (
    get_args,
    cast,
    Sequence,
    Optional,
    TypeVar,
    Any,
    Type,
    Union,
    Tuple,
    Mapping,
    List,
)
import subprocess
from os import PathLike
from collections import abc
import re

from numpy.random import Generator as RandomGenerator
import cv2 as cv
import iolite as io
import attrs
import cattrs
from cattrs.errors import ClassValidationError

from .type import PathType


def attrs_lazy_field():
    """Return an ``attrs`` field for a lazily-populated attribute.

    The field defaults to ``None``, is excluded from the generated
    ``__init__`` (``init=False``) and from the generated ``__repr__``.
    """
    return attrs.field(default=None, init=False, repr=False)


_T_FIELD = TypeVar('_T_FIELD')


def unwrap_optional_field(field: Optional[_T_FIELD]) -> _T_FIELD:
    """Narrow ``field`` from ``Optional[T]`` to ``T``.

    Raises:
        AssertionError: if ``field`` is ``None`` (when assertions are enabled).
    """
    assert field is not None, 'unwrap_optional_field() got None.'
    return field


def get_cattrs_converter_ignoring_init_equals_false():
    """Build a cattrs converter that omits ``init=False`` attrs fields when unstructuring.

    For every attrs class, the unstructure function is generated with an
    ``omit=True`` override for each field declared with ``init=False``
    (for example fields created by ``attrs_lazy_field``).
    """
    converter = cattrs.Converter()

    def _make_unstructure_fn(cl: Any):
        # Fields excluded from ``__init__`` are left out of the unstructured output.
        overrides = {
            field.name: cattrs.override(omit=True)
            for field in attrs.fields(cl)
            if not field.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **overrides)

    converter.register_unstructure_hook_factory(attrs.has, _make_unstructure_fn)
    return converter


def is_path_type(path: Any):
    """Return ``True`` if ``path`` is a ``str`` or an ``os.PathLike``."""
    return isinstance(path, (str, PathLike))  # type: ignore


def read_json_file(path: PathType):
    """Load the JSON file at ``path`` through ``iolite.read_json``.

    ``expandvars=True`` is forwarded to iolite (presumably expanding environment
    variables; see iolite for the exact semantics).
    """
    return io.read_json(path, expandvars=True)
def get_data_folder(file: PathType):
    """Resolve, and create if missing, the data folder associated with ``file``.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder`` with ``$VKIT_ROOT``,
    ``$VKIT_DATA`` and ``file`` as arguments. The helper's stripped stdout is
    used as the folder path, and the folder is ensured to exist via
    ``iolite.folder(..., touch=True)``.

    Returns:
        The data folder path as printed by the helper.

    Raises:
        AssertionError: if the helper exits non-zero or prints an empty path
            (when assertions are enabled).
    """
    # Local imports keep this block self-contained.
    import os
    import shlex

    # ``file`` is shell-quoted so that characters such as ``"``, ``$`` or
    # backticks in the path are passed literally rather than interpreted by the
    # shell. The ``$VKIT_*`` variables are still expanded by the shell.
    command = (
        '"$VKIT_ROOT/.direnv/bin/pyproject-data-folder" "$VKIT_ROOT" "$VKIT_DATA" '
        + shlex.quote(os.fspath(file))
    )
    proc = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        text=True,
    )
    assert proc.returncode == 0, (
        f'pyproject-data-folder failed with code {proc.returncode}: {proc.stderr}'
    )

    data_folder = proc.stdout.strip()
    assert data_folder, 'pyproject-data-folder printed an empty path.'

    io.folder(data_folder, touch=True)

    return data_folder


_T_ITEM = TypeVar('_T_ITEM')


def rng_choice(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    probs: Optional[Sequence[float]] = None,
) -> _T_ITEM:
    """Pick one element of ``items`` using ``rng``.

    ``probs``, if given, holds one probability per item and is forwarded to
    ``Generator.choice`` as ``p``; otherwise the choice is uniform.
    """
    idx = rng.choice(len(items), p=probs)
    return items[idx]


def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Pick ``size`` elements of ``items`` using ``rng``.

    Sampling is WITH replacement by default. Pass ``replace=False`` to draw
    distinct positions; numpy then requires ``size <= len(items)``.
    ``probs`` is forwarded to ``Generator.choice`` as ``p``.
    """
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]


def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return a new tuple holding ``items`` in an order shuffled by ``rng``.

    ``items`` itself is not modified; only a list of indices is shuffled.
    """
    indices = list(range(len(items)))
    rng.shuffle(indices)
    return tuple(items[idx] for idx in indices)


# Candidate interpolation flags sampled by ``sample_cv_resize_interpolation``.
_CV_INTER_FLAGS = cast(
    Sequence[int],
    (
        # NOTE: Keep the EXACT version.
        # cv.INTER_NEAREST,
        # NOTE: this one is Any (presumably in the cv2 type stubs, hence the cast).
        cv.INTER_NEAREST_EXACT,
        # NOTE: Keep the EXACT version.
        # cv.INTER_LINEAR,
        cv.INTER_LINEAR_EXACT,
        cv.INTER_CUBIC,
        cv.INTER_LANCZOS4,
    ),
)


def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Sample an OpenCV interpolation flag uniformly from ``_CV_INTER_FLAGS``.

    If ``include_cv_inter_area`` is true, ``cv.INTER_AREA`` is added to the
    candidates.
    """
    flags = _CV_INTER_FLAGS
    if include_cv_inter_area:
        flags = (*_CV_INTER_FLAGS, cv.INTER_AREA)
    return rng_choice(rng, flags)


_T_TARGET = TypeVar('_T_TARGET')

# Converter used by ``dyn_structure``; ``forbid_extra_keys=True`` rejects unknown keys.
_cattrs = cattrs.GenConverter(forbid_extra_keys=True)

# Do nothing for config field.
# Values annotated as ``Optional[Union[Mapping[str, Any], str, PathLike]]``
# are passed through unchanged when structuring.
_cattrs.register_structure_hook(
    Optional[Union[Mapping[str, Any], str, PathLike]],
    lambda d, _: d,
)


def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Convert a loosely-typed value into an instance of ``target_cls``.

    Args:
        dyn_object: The value to convert. Depending on the flags it may be
            ``None``, a path to a JSON file, an existing ``target_cls``
            instance, a mapping or a sequence.
        target_cls: The class, or type annotation such as ``Sequence[int]``,
            to produce.
        support_path_type: If ``dyn_object`` is a ``str``/``PathLike``, load it
            with ``read_json_file`` first.
        force_path_type: Like ``support_path_type``, but ``dyn_object`` must be
            a path.
        support_none_type: If ``dyn_object`` is ``None``, return
            ``target_cls()``.

    Returns:
        ``dyn_object`` itself if it already is a ``target_cls`` instance,
        otherwise the structured result.

    Raises:
        AssertionError: if ``force_path_type`` is set and ``dyn_object`` is not
            a path (when assertions are enabled).
        NotImplementedError: if the (possibly loaded) value is neither a
            mapping nor a sequence.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        dyn_object_is_path_type = is_path_type(dyn_object)
        if force_path_type:
            assert dyn_object_is_path_type
        if dyn_object_is_path_type:
            dyn_object = read_json_file(dyn_object)

    try:
        isinstance_target_cls = isinstance(dyn_object, target_cls)
    except TypeError:
        # target_cls could be type annotation like Sequence[int].
        isinstance_target_cls = False

    if isinstance_target_cls:
        # Already the requested type; nothing to do.
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure; in such
            # cases, fall back to initializing the class manually.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        # NOTE: ``str`` is also an ``abc.Sequence``, so a string that was not
        # consumed as a path above is structured here.
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError(
        f'Cannot structure {type(dyn_object).__name__} into {target_cls}.'
    )


def normalize_to_probs(weights: Sequence[float]):
    """Scale ``weights`` so they sum to 1 and return them as a new list.

    Raises:
        ZeroDivisionError: if ``weights`` are non-empty Python numbers summing
            to zero.
    """
    total = sum(weights)
    probs = [weight / total for weight in weights]
    return probs


_T_KEY = TypeVar('_T_KEY')


def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split key/weight pairs into keys and normalized probabilities.

    ``key_weight_items`` is either a sequence of ``(key, weight)`` pairs or a
    mapping from key to weight. The input order is preserved.

    Raises:
        NotImplementedError: if ``key_weight_items`` is neither a sequence nor
            a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = list(key_weight_items)
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = list(key_weight_items.items())
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]

    probs = normalize_to_probs(weights)
    return keys, probs


def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert a CamelCase name to snake_case.

    An underscore is inserted before every uppercase letter except at the start,
    then everything is lower-cased: ``'FooBarConfig'`` -> ``'foo_bar_config'``.
    Consecutive capitals are split individually (``'HTTPServer'`` ->
    ``'h_t_t_p_server'``).
    """
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()


def get_config_class_snake_case_name(class_name: str):
    """Return the snake_case form of ``class_name`` without a trailing ``_config``.

    For example, ``'FooBarConfig'`` -> ``'foo_bar'``.
    """
    snake_case_name = convert_camel_case_name_to_snake_case_name(class_name)
    suffix = '_config'
    if snake_case_name.endswith(suffix):
        snake_case_name = snake_case_name[:-len(suffix)]
    return snake_case_name


def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first entry of ``cls.__orig_bases__``.

    For example, for ``class Foo(Bar[int, str])`` this returns ``(int, str)``.
    """
    return get_args(cls.__orig_bases__[0])  # type: ignore
def
attrs_lazy_field():
def
unwrap_optional_field(field: Union[~_T_FIELD, NoneType]) -> ~_T_FIELD:
def
get_cattrs_converter_ignoring_init_equals_false():
def get_cattrs_converter_ignoring_init_equals_false():
    """Build a cattrs converter that omits ``init=False`` attrs fields when unstructuring.

    For every attrs class, the unstructure function is generated with an
    ``omit=True`` override for each field declared with ``init=False``.
    """
    converter = cattrs.Converter()

    def _make_unstructure_fn(cl):
        # Fields excluded from ``__init__`` are left out of the unstructured output.
        overrides = {
            field.name: cattrs.override(omit=True)
            for field in attrs.fields(cl)
            if not field.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **overrides)

    converter.register_unstructure_hook_factory(attrs.has, _make_unstructure_fn)
    return converter
def
is_path_type(path: Any):
def
read_json_file(path: Union[str, os.PathLike]):
def
get_data_folder(file: Union[str, os.PathLike]):
def get_data_folder(file: PathType):
    """Resolve, and create if missing, the data folder associated with ``file``.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder`` with ``$VKIT_ROOT``,
    ``$VKIT_DATA`` and ``file`` as arguments. The helper's stripped stdout is
    used as the folder path, and the folder is ensured to exist via
    ``iolite.folder(..., touch=True)``.

    Raises:
        AssertionError: if the helper exits non-zero or prints an empty path
            (when assertions are enabled).
    """
    # Local imports keep this block self-contained.
    import os
    import shlex

    # ``file`` is shell-quoted so that characters such as ``"``, ``$`` or
    # backticks in the path are passed literally rather than interpreted by the
    # shell. The ``$VKIT_*`` variables are still expanded by the shell.
    command = (
        '"$VKIT_ROOT/.direnv/bin/pyproject-data-folder" "$VKIT_ROOT" "$VKIT_DATA" '
        + shlex.quote(os.fspath(file))
    )
    proc = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        text=True,
    )
    assert proc.returncode == 0, (
        f'pyproject-data-folder failed with code {proc.returncode}: {proc.stderr}'
    )

    data_folder = proc.stdout.strip()
    assert data_folder, 'pyproject-data-folder printed an empty path.'

    io.folder(data_folder, touch=True)

    return data_folder
def
rng_choice( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM], probs: Union[Sequence[float], NoneType] = None) -> ~_T_ITEM:
def
rng_choice_with_size( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM], size: int, probs: Union[Sequence[float], NoneType] = None, replace: bool = True) -> Sequence[~_T_ITEM]:
def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Pick ``size`` elements of ``items`` using ``rng``.

    Sampling is WITH replacement by default. Pass ``replace=False`` to draw
    distinct positions; numpy then requires ``size <= len(items)``.
    ``probs`` is forwarded to ``Generator.choice`` as ``p``.
    """
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]
def
rng_shuffle( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM]) -> Sequence[~_T_ITEM]:
def
sample_cv_resize_interpolation( rng: numpy.random._generator.Generator, include_cv_inter_area: bool = False):
def
dyn_structure( dyn_object: Any, target_cls: Type[~_T_TARGET], support_path_type: bool = False, force_path_type: bool = False, support_none_type: bool = False) -> ~_T_TARGET:
def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Convert a loosely-typed value into an instance of ``target_cls``.

    Args:
        dyn_object: The value to convert. Depending on the flags it may be
            ``None``, a path to a JSON file, an existing ``target_cls``
            instance, a mapping or a sequence.
        target_cls: The class, or type annotation such as ``Sequence[int]``,
            to produce.
        support_path_type: If ``dyn_object`` is a ``str``/``PathLike``, load it
            with ``read_json_file`` first.
        force_path_type: Like ``support_path_type``, but ``dyn_object`` must be
            a path.
        support_none_type: If ``dyn_object`` is ``None``, return
            ``target_cls()``.

    Returns:
        ``dyn_object`` itself if it already is a ``target_cls`` instance,
        otherwise the structured result.

    Raises:
        AssertionError: if ``force_path_type`` is set and ``dyn_object`` is not
            a path (when assertions are enabled).
        NotImplementedError: if the (possibly loaded) value is neither a
            mapping nor a sequence.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        dyn_object_is_path_type = is_path_type(dyn_object)
        if force_path_type:
            assert dyn_object_is_path_type
        if dyn_object_is_path_type:
            dyn_object = read_json_file(dyn_object)

    try:
        isinstance_target_cls = isinstance(dyn_object, target_cls)
    except TypeError:
        # target_cls could be type annotation like Sequence[int].
        isinstance_target_cls = False

    if isinstance_target_cls:
        # Already the requested type; nothing to do.
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure; in such
            # cases, fall back to initializing the class manually.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        # NOTE: ``str`` is also an ``abc.Sequence``, so a string that was not
        # consumed as a path above is structured here.
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError(
        f'Cannot structure {type(dyn_object).__name__} into {target_cls}.'
    )
def
normalize_to_probs(weights: Sequence[float]):
def
normalize_to_keys_and_probs( key_weight_items: Union[Sequence[Tuple[~_T_KEY, float]], Mapping[~_T_KEY, float]]) -> Tuple[Sequence[~_T_KEY], Sequence[float]]:
def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split key/weight pairs into keys and normalized probabilities.

    ``key_weight_items`` is either a sequence of ``(key, weight)`` pairs or a
    mapping from key to weight. The input order is preserved.

    Raises:
        NotImplementedError: if ``key_weight_items`` is neither a sequence nor
            a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = list(key_weight_items)
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = list(key_weight_items.items())
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]

    probs = normalize_to_probs(weights)
    return keys, probs
def
convert_camel_case_name_to_snake_case_name(name: str):
def
get_config_class_snake_case_name(class_name: str):
def
get_generic_classes(cls: Type[Any]):