vkit.utility.opt
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import (
    get_args,
    cast,
    Sequence,
    Optional,
    TypeVar,
    Any,
    Type,
    Union,
    Tuple,
    Mapping,
    List,
)
import shlex
import subprocess
from os import PathLike
from collections import abc
import re

from numpy.random import Generator as RandomGenerator
import cv2 as cv
import iolite as io
import attrs
import cattrs
from cattrs.errors import ClassValidationError

from vkit.utility import PathType


def attrs_lazy_field():
    """Return an attrs field that defaults to None and is excluded from __init__ and repr.

    NOTE(review): presumably used for attributes that are filled in lazily
    after construction; confirm against the call sites.
    """
    return attrs.field(default=None, init=False, repr=False)


def get_cattrs_converter_ignoring_init_equals_false():
    """Build a cattrs converter that omits ``init=False`` fields when unstructuring.

    Returns:
        A fresh ``cattrs.Converter`` whose unstructure hook for attrs classes
        skips every attribute declared with ``init=False``.
    """
    converter = cattrs.Converter()

    def build_unstructure_fn(cl):
        # Mark every attribute that is not an __init__ argument as omitted.
        omitted = {
            attribute.name: cattrs.override(omit=True)
            for attribute in attrs.fields(cl)
            if not attribute.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **omitted)

    converter.register_unstructure_hook_factory(attrs.has, build_unstructure_fn)
    return converter


def is_path_type(path: Any):
    """Return True if ``path`` is a ``str`` or an ``os.PathLike`` instance."""
    return isinstance(path, (str, PathLike))  # type: ignore


def read_json_file(path: PathType):
    """Read the JSON file at ``path`` through ``iolite.read_json``.

    ``expandvars=True`` is forwarded to iolite (presumably so that environment
    variables inside ``path`` are expanded; verify against the iolite docs).
    """
    return io.read_json(path, expandvars=True)


def get_data_folder(file: PathType):
    """Resolve, create and return the data folder associated with ``file``.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder`` through the shell with
    ``$VKIT_ROOT``, ``$VKIT_DATA`` and ``file`` as arguments, and uses the
    stripped standard output as the folder path.

    Args:
        file: Path handed to the helper script as its third argument.

    Returns:
        The folder path printed by the helper script.

    Raises:
        RuntimeError: If the helper exits with a non-zero status or prints
            nothing on standard output.
    """
    # The shell is still needed to expand $VKIT_ROOT / $VKIT_DATA, so the
    # caller-provided path is quoted explicitly: spaces, quotes or shell
    # metacharacters in it can no longer break or alter the command.
    command = (
        '$VKIT_ROOT/.direnv/bin/pyproject-data-folder "$VKIT_ROOT" "$VKIT_DATA" '
        + shlex.quote(str(file))
    )
    proc = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        text=True,
    )
    # Explicit checks instead of asserts, which vanish under `python -O`.
    if proc.returncode != 0:
        raise RuntimeError(
            f'pyproject-data-folder failed with exit code {proc.returncode}: '
            f'{proc.stderr.strip()}'
        )

    data_folder = proc.stdout.strip()
    if not data_folder:
        raise RuntimeError('pyproject-data-folder printed an empty data folder path.')

    # NOTE(review): touch=True presumably creates the folder when missing.
    io.folder(data_folder, touch=True)

    return data_folder


_T_ITEM = TypeVar('_T_ITEM')


def rng_choice(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    probs: Optional[Sequence[float]] = None,
) -> _T_ITEM:
    """Draw a single element of ``items`` using ``rng``.

    Args:
        rng: Random generator used for sampling.
        items: Candidates to pick from.
        probs: Optional per-item probabilities forwarded to ``rng.choice`` as ``p``.

    Returns:
        The selected item.
    """
    idx = rng.choice(len(items), p=probs)
    return items[idx]


def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Draw ``size`` elements of ``items`` using ``rng``.

    Sampling is done WITH replacement by default (``replace=True``); pass
    ``replace=False`` to draw distinct positions.

    Args:
        rng: Random generator used for sampling.
        items: Candidates to pick from.
        size: Number of elements to draw.
        probs: Optional per-item probabilities forwarded to ``rng.choice`` as ``p``.
        replace: Forwarded to ``rng.choice``.

    Returns:
        A list with the selected items, in sampling order.
    """
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]


def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return a tuple holding the elements of ``items`` in shuffled order.

    ``items`` itself is not modified: a list of positions is shuffled instead.
    """
    indices = list(range(len(items)))
    rng.shuffle(indices)
    return tuple(items[idx] for idx in indices)


# Interpolation flags that sample_cv_resize_interpolation draws from.
_CV_INTER_FLAGS = cast(
    Sequence[int],
    (
        # NOTE: Keep the EXACT version.
        # cv.INTER_NEAREST,
        # NOTE: this one is Any.
        cv.INTER_NEAREST_EXACT,
        # NOTE: Keep the EXACT version.
        # cv.INTER_LINEAR,
        cv.INTER_LINEAR_EXACT,
        cv.INTER_CUBIC,
        cv.INTER_LANCZOS4,
    ),
)


def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Pick one OpenCV interpolation flag at random.

    Args:
        rng: Random generator used for sampling.
        include_cv_inter_area: When True, ``cv.INTER_AREA`` is added to the candidates.

    Returns:
        One of the flags in ``_CV_INTER_FLAGS`` (or ``cv.INTER_AREA`` if enabled).
    """
    flags = _CV_INTER_FLAGS
    if include_cv_inter_area:
        flags = (*_CV_INTER_FLAGS, cv.INTER_AREA)
    return rng_choice(rng, flags)


_T_TARGET = TypeVar('_T_TARGET')

# Converter used by dyn_structure; unknown keys are rejected.
_cattrs = cattrs.GenConverter(forbid_extra_keys=True)


def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Turn a loosely-typed object into an instance of ``target_cls``.

    Args:
        dyn_object: An instance of ``target_cls``, a mapping, a sequence, a path
            to a JSON file (when path support is enabled) or None (when None
            support is enabled).
        target_cls: The class, or a type annotation such as ``Sequence[int]``,
            to structure into.
        support_path_type: If True, a str / PathLike ``dyn_object`` is treated
            as the path of a JSON file and loaded first.
        force_path_type: If True, ``dyn_object`` must be a str / PathLike.
        support_none_type: If True, None yields ``target_cls()``.

    Returns:
        ``dyn_object`` itself if it already is an instance of ``target_cls``,
        otherwise the structured object.

    Raises:
        TypeError: If ``force_path_type`` is set and ``dyn_object`` is not a path.
        NotImplementedError: If ``dyn_object`` is neither an instance of
            ``target_cls``, a mapping nor a sequence.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        if is_path_type(dyn_object):
            dyn_object = read_json_file(dyn_object)
        elif force_path_type:
            # Explicit check instead of an assert, which vanishes under `python -O`.
            raise TypeError(
                'force_path_type is set but dyn_object is not a path: '
                f'{type(dyn_object).__name__}'
            )

    try:
        if isinstance(dyn_object, target_cls):
            # Already the requested type, nothing to do.
            return dyn_object
    except TypeError:
        # target_cls could be type annotation like Sequence[int].
        pass

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle Class with hierarchy structure,
            # in such case, fallback to manually initialization.
            return target_cls(**dyn_object)

    # NOTE: str is an abc.Sequence too, so a str that was not consumed as a
    # path above is passed to cattrs as-is.
    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError(
        f'Cannot structure object of type {type(dyn_object).__name__}.'
    )


def normalize_to_probs(weights: Sequence[float]):
    """Scale ``weights`` so that they sum to one.

    Returns:
        A list with ``weight / sum(weights)`` for every weight.

    Raises:
        ZeroDivisionError: For non-empty plain Python numbers summing to zero.
    """
    total = sum(weights)
    probs = [weight / total for weight in weights]
    return probs


_T_KEY = TypeVar('_T_KEY')


def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split key/weight pairs into keys and normalized probabilities.

    Args:
        key_weight_items: Either a sequence of ``(key, weight)`` pairs or a
            mapping from key to weight.

    Returns:
        ``(keys, probs)``: the keys in iteration order and the weights
        normalized by ``normalize_to_probs``.

    Raises:
        NotImplementedError: If the input is neither a sequence nor a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = key_weight_items
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = key_weight_items.items()
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]

    probs = normalize_to_probs(weights)
    return keys, probs


def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert ``CamelCase`` to ``camel_case``.

    An underscore is inserted before every uppercase letter that is not at the
    start of ``name``, then the result is lowercased. Consecutive capitals are
    therefore split one by one (``'ABc'`` becomes ``'a_bc'``).
    """
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()


def get_config_class_snake_case_name(class_name: str):
    """Return the snake_case form of ``class_name`` without a trailing ``_config``."""
    snake_case_name = convert_camel_case_name_to_snake_case_name(class_name)
    if snake_case_name.endswith('_config'):
        snake_case_name = snake_case_name[:-len('_config')]
    return snake_case_name


def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first entry in ``cls.__orig_bases__``."""
    return get_args(cls.__orig_bases__[0])  # type: ignore
def
attrs_lazy_field():
def
get_cattrs_converter_ignoring_init_equals_false():
def get_cattrs_converter_ignoring_init_equals_false():
    """Build a cattrs converter that omits ``init=False`` fields when unstructuring.

    Returns:
        A fresh ``cattrs.Converter`` whose unstructure hook for attrs classes
        skips every attribute declared with ``init=False``.
    """
    converter = cattrs.Converter()

    def build_unstructure_fn(cl):
        # Mark every attribute that is not an __init__ argument as omitted.
        omitted = {
            attribute.name: cattrs.override(omit=True)
            for attribute in attrs.fields(cl)
            if not attribute.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **omitted)

    converter.register_unstructure_hook_factory(attrs.has, build_unstructure_fn)
    return converter
def
is_path_type(path: Any):
def
read_json_file(path: Union[str, os.PathLike]):
def
get_data_folder(file: Union[str, os.PathLike]):
def get_data_folder(file: PathType):
    """Resolve, create and return the data folder associated with ``file``.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder`` through the shell with
    ``$VKIT_ROOT``, ``$VKIT_DATA`` and ``file`` as arguments, and uses the
    stripped standard output as the folder path.

    Args:
        file: Path handed to the helper script as its third argument.

    Returns:
        The folder path printed by the helper script.

    Raises:
        RuntimeError: If the helper exits with a non-zero status or prints
            nothing on standard output.
    """
    import shlex

    # The shell is still needed to expand $VKIT_ROOT / $VKIT_DATA, so the
    # caller-provided path is quoted explicitly: spaces, quotes or shell
    # metacharacters in it can no longer break or alter the command.
    command = (
        '$VKIT_ROOT/.direnv/bin/pyproject-data-folder "$VKIT_ROOT" "$VKIT_DATA" '
        + shlex.quote(str(file))
    )
    proc = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        text=True,
    )
    # Explicit checks instead of asserts, which vanish under `python -O`.
    if proc.returncode != 0:
        raise RuntimeError(
            f'pyproject-data-folder failed with exit code {proc.returncode}: '
            f'{proc.stderr.strip()}'
        )

    data_folder = proc.stdout.strip()
    if not data_folder:
        raise RuntimeError('pyproject-data-folder printed an empty data folder path.')

    # NOTE(review): touch=True presumably creates the folder when missing.
    io.folder(data_folder, touch=True)

    return data_folder
def
rng_choice( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM], probs: Union[Sequence[float], NoneType] = None) -> ~_T_ITEM:
def
rng_choice_with_size( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM], size: int, probs: Union[Sequence[float], NoneType] = None, replace: bool = True) -> Sequence[~_T_ITEM]:
def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Draw ``size`` elements of ``items`` using ``rng``.

    Sampling is done WITH replacement by default (``replace=True``); pass
    ``replace=False`` to draw distinct positions.

    Args:
        rng: Random generator used for sampling.
        items: Candidates to pick from.
        size: Number of elements to draw.
        probs: Optional per-item probabilities forwarded to ``rng.choice`` as ``p``.
        replace: Forwarded to ``rng.choice``.

    Returns:
        A list with the selected items, in sampling order.
    """
    num_items = len(items)
    # Sample positions first, then map them back onto the items.
    sampled_positions = rng.choice(
        num_items,
        size=size,
        replace=replace,
        p=probs,
    )
    return [items[position] for position in sampled_positions]
def
rng_shuffle( rng: numpy.random._generator.Generator, items: Sequence[~_T_ITEM]) -> Sequence[~_T_ITEM]:
def
sample_cv_resize_interpolation( rng: numpy.random._generator.Generator, include_cv_inter_area: bool = False):
def
dyn_structure( dyn_object: Any, target_cls: Type[~_T_TARGET], support_path_type: bool = False, force_path_type: bool = False, support_none_type: bool = False) -> ~_T_TARGET:
def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Turn a loosely-typed object into an instance of ``target_cls``.

    Args:
        dyn_object: An instance of ``target_cls``, a mapping, a sequence, a path
            to a JSON file (when path support is enabled) or None (when None
            support is enabled).
        target_cls: The class, or a type annotation such as ``Sequence[int]``,
            to structure into.
        support_path_type: If True, a str / PathLike ``dyn_object`` is treated
            as the path of a JSON file and loaded first.
        force_path_type: If True, ``dyn_object`` must be a str / PathLike.
        support_none_type: If True, None yields ``target_cls()``.

    Returns:
        ``dyn_object`` itself if it already is an instance of ``target_cls``,
        otherwise the structured object.

    Raises:
        TypeError: If ``force_path_type`` is set and ``dyn_object`` is not a path.
        NotImplementedError: If ``dyn_object`` is neither an instance of
            ``target_cls``, a mapping nor a sequence.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        if is_path_type(dyn_object):
            dyn_object = read_json_file(dyn_object)
        elif force_path_type:
            # Explicit check instead of an assert, which vanishes under `python -O`.
            raise TypeError(
                'force_path_type is set but dyn_object is not a path: '
                f'{type(dyn_object).__name__}'
            )

    try:
        if isinstance(dyn_object, target_cls):
            # Already the requested type, nothing to do.
            return dyn_object
    except TypeError:
        # target_cls could be type annotation like Sequence[int].
        pass

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle Class with hierarchy structure,
            # in such case, fallback to manually initialization.
            return target_cls(**dyn_object)

    # NOTE: str is an abc.Sequence too, so a str that was not consumed as a
    # path above is passed to cattrs as-is.
    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError(
        f'Cannot structure object of type {type(dyn_object).__name__}.'
    )
def
normalize_to_probs(weights: Sequence[float]):
def
normalize_to_keys_and_probs( key_weight_items: Union[Sequence[Tuple[~_T_KEY, float]], Mapping[~_T_KEY, float]]) -> Tuple[Sequence[~_T_KEY], Sequence[float]]:
def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split key/weight pairs into keys and normalized probabilities.

    Args:
        key_weight_items: Either a sequence of ``(key, weight)`` pairs or a
            mapping from key to weight.

    Returns:
        ``(keys, probs)``: the keys in iteration order and the weights
        normalized by ``normalize_to_probs``.

    Raises:
        NotImplementedError: If the input is neither a sequence nor a mapping.
    """
    # Reduce both accepted input shapes to one iterable of (key, weight) pairs.
    if isinstance(key_weight_items, abc.Sequence):
        pairs = key_weight_items
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = key_weight_items.items()
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]

    return keys, normalize_to_probs(weights)
def
convert_camel_case_name_to_snake_case_name(name: str):
def
get_config_class_snake_case_name(class_name: str):
def
get_generic_classes(cls: Type[Any]):