vkit.utility.opt

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import (
 15    get_args,
 16    cast,
 17    Sequence,
 18    Optional,
 19    TypeVar,
 20    Any,
 21    Type,
 22    Union,
 23    Tuple,
 24    Mapping,
 25    List,
 26)
 27import subprocess
 28from os import PathLike
 29from collections import abc
 30import re
 31
 32from numpy.random import Generator as RandomGenerator
 33import cv2 as cv
 34import iolite as io
 35import attrs
 36import cattrs
 37from cattrs.errors import ClassValidationError
 38
 39from vkit.utility import PathType
 40
 41
 42def attrs_lazy_field():
 43    return attrs.field(default=None, init=False, repr=False)
 44
 45
 46def get_cattrs_converter_ignoring_init_equals_false():
 47    converter = cattrs.Converter()
 48    converter.register_unstructure_hook_factory(
 49        attrs.has,
 50        lambda cl: cattrs.gen.make_dict_unstructure_fn(
 51            cl,
 52            converter,
 53            **{a.name: cattrs.override(omit=True) for a in attrs.fields(cl) if not a.init},
 54        ),
 55    )
 56    return converter
 57
 58
 59def is_path_type(path: Any):
 60    return isinstance(path, (str, PathLike))  # type: ignore
 61
 62
 63def read_json_file(path: PathType):
 64    return io.read_json(path, expandvars=True)
 65
 66
 67def get_data_folder(file: PathType):
 68    proc = subprocess.run(
 69        f'$VKIT_ROOT/.direnv/bin/pyproject-data-folder "$VKIT_ROOT" "$VKIT_DATA" "{file}"',
 70        shell=True,
 71        capture_output=True,
 72        text=True,
 73    )
 74    assert proc.returncode == 0
 75
 76    data_folder = proc.stdout.strip()
 77    assert data_folder
 78
 79    io.folder(data_folder, touch=True)
 80
 81    return data_folder
 82
 83
 84_T_ITEM = TypeVar('_T_ITEM')
 85
 86
 87def rng_choice(
 88    rng: RandomGenerator,
 89    items: Sequence[_T_ITEM],
 90    probs: Optional[Sequence[float]] = None,
 91) -> _T_ITEM:
 92    idx = rng.choice(len(items), p=probs)
 93    return items[idx]
 94
 95
 96def rng_choice_with_size(
 97    rng: RandomGenerator,
 98    items: Sequence[_T_ITEM],
 99    size: int,
100    probs: Optional[Sequence[float]] = None,
101    replace: bool = True,
102) -> Sequence[_T_ITEM]:
103    # NOTE: Without replacement!
104    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
105    return [items[idx] for idx in indices]
106
107
def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return the elements of ``items`` as a tuple in an ``rng``-shuffled order.

    ``items`` is left untouched. A list of positions is shuffled in place with
    ``rng.shuffle`` and then used to index ``items``.
    """
    order = list(range(len(items)))
    rng.shuffle(order)
    return tuple(items[position] for position in order)
115
116
# Interpolation flags that `sample_cv_resize_interpolation` samples from.
# The *_EXACT variants are used in place of cv.INTER_NEAREST and
# cv.INTER_LINEAR, which are deliberately left out of the pool.
# NOTE: `cv.INTER_NEAREST_EXACT` is typed as `Any`, hence the `cast`.
_CV_INTER_FLAGS = cast(
    Sequence[int],
    (
        cv.INTER_NEAREST_EXACT,
        cv.INTER_LINEAR_EXACT,
        cv.INTER_CUBIC,
        cv.INTER_LANCZOS4,
    ),
)
131
132
def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Pick a random OpenCV interpolation flag for resizing.

    The pool is ``_CV_INTER_FLAGS``, with ``cv.INTER_AREA`` appended when
    ``include_cv_inter_area`` is true. The pick is made by ``rng_choice``
    without explicit probabilities.
    """
    candidate_flags = (
        (*_CV_INTER_FLAGS, cv.INTER_AREA) if include_cv_inter_area else _CV_INTER_FLAGS
    )
    return rng_choice(rng, candidate_flags)
141
142
_T_TARGET = TypeVar('_T_TARGET')

# Converter used by `dyn_structure`; mappings with unknown keys are rejected
# instead of having those keys silently dropped.
_cattrs = cattrs.GenConverter(forbid_extra_keys=True)


def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Coerce ``dyn_object`` into a ``target_cls`` value.

    Resolution order:

    1. ``None`` with ``support_none_type`` -> ``target_cls()``.
    2. With ``support_path_type`` or ``force_path_type``, a ``str`` /
       ``PathLike`` is first loaded as a JSON file (``force_path_type``
       asserts that it is one).
    3. An object that already is a ``target_cls`` instance is returned as is.
    4. A mapping is structured with ``_cattrs``; on ``ClassValidationError``
       it falls back to ``target_cls(**dyn_object)``.
    5. A sequence (note: a ``str`` is one) is structured with ``_cattrs``.

    Raises:
        NotImplementedError: for any other kind of ``dyn_object``.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        looks_like_path = is_path_type(dyn_object)
        if force_path_type:
            assert looks_like_path
        if looks_like_path:
            dyn_object = read_json_file(dyn_object)

    try:
        already_structured = isinstance(dyn_object, target_cls)
    except TypeError:
        # `target_cls` may be a typing construct such as Sequence[int],
        # which isinstance() rejects.
        already_structured = False

    if already_structured:
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure;
            # fall back to calling the constructor directly.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError()
189
190
def normalize_to_probs(weights: Sequence[float]):
    """Return a new list with each weight divided by ``sum(weights)``.

    An empty input yields ``[]``. For non-empty plain-Python weights summing
    to zero, the division raises ``ZeroDivisionError``.
    """
    weight_sum = sum(weights)
    return [w / weight_sum for w in weights]
195
196
_T_KEY = TypeVar('_T_KEY')


def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split weighted keys into a key list and a normalized probability list.

    ``key_weight_items`` is either a sequence of ``(key, weight)`` pairs or a
    mapping from key to weight. Input iteration order is preserved, and the
    weights are normalized with ``normalize_to_probs``.

    Raises:
        NotImplementedError: if the input is neither a sequence nor a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = key_weight_items
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = key_weight_items.items()
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = []
    weights: List[float] = []
    for key, weight in pairs:
        keys.append(key)
        weights.append(weight)

    return keys, normalize_to_probs(weights)
219
220
def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert a ``CamelCase`` name to ``snake_case``.

    An underscore is inserted before every ASCII uppercase letter except at
    position 0, and the result is lowercased. Consecutive capitals are split
    individually: ``'FooBar'`` -> ``'foo_bar'``, ``'ABC'`` -> ``'a_b_c'``.
    """
    boundary = re.compile(r'(?<!^)(?=[A-Z])')
    underscored = boundary.sub('_', name)
    return underscored.lower()
223
224
def get_config_class_snake_case_name(class_name: str):
    """Snake-case ``class_name`` and drop a trailing ``_config`` if present.

    E.g. ``'ResizeConfig'`` -> ``'resize'``.
    """
    suffix = '_config'
    snake_case_name = convert_camel_case_name_to_snake_case_name(class_name)
    if not snake_case_name.endswith(suffix):
        return snake_case_name
    return snake_case_name[:-len(suffix)]
230
231
def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first entry in ``cls.__orig_bases__``.

    E.g. for ``class IntList(List[int])`` this gives ``(int,)``.
    ``__orig_bases__`` is read through normal attribute lookup, so it may be
    resolved from a parent class, and ``AttributeError`` propagates when it
    is absent.
    """
    first_orig_base = cls.__orig_bases__[0]  # type: ignore
    return get_args(first_orig_base)
def attrs_lazy_field():
    """Declare an attrs field that is filled in later rather than at construction.

    The field is excluded from the generated ``__init__`` and ``__repr__`` and
    starts out as ``None``.
    """
    return attrs.field(init=False, repr=False, default=None)
def get_cattrs_converter_ignoring_init_equals_false():
    """Build a cattrs converter whose unstructuring skips ``init=False`` fields.

    For every attrs class, the registered hook factory generates a dict
    unstructure function in which each field declared with ``init=False`` is
    overridden with ``omit=True``.
    """
    converter = cattrs.Converter()

    def _make_unstructure_fn(cl):
        # Only fields that cannot be passed to __init__ are dropped.
        omitted_fields = {
            field.name: cattrs.override(omit=True)
            for field in attrs.fields(cl)
            if not field.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **omitted_fields)

    converter.register_unstructure_hook_factory(attrs.has, _make_unstructure_fn)
    return converter
def is_path_type(path: Any):
    """Return ``True`` if ``path`` is a ``str`` or an ``os.PathLike`` object."""
    if isinstance(path, str):
        return True
    return isinstance(path, PathLike)  # type: ignore
def read_json_file(path: PathType):
    """Load and return the JSON document stored at ``path`` using ``io.read_json``.

    ``expandvars=True`` is forwarded; presumably this expands environment
    variables (TODO confirm against iolite's ``read_json`` documentation).
    """
    loaded = io.read_json(path, expandvars=True)
    return loaded
def get_data_folder(file: PathType):
    """Resolve, create and return the data folder associated with ``file``.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder`` with the values of
    the ``VKIT_ROOT`` and ``VKIT_DATA`` environment variables and ``file`` as
    arguments. The helper's stripped stdout is taken as the folder path, which
    is then created via ``io.folder(..., touch=True)``.

    Raises:
        RuntimeError: if the helper cannot be launched, exits with a non-zero
            code, or prints an empty path.
    """
    import os

    # Expand the environment variables in Python and pass an argument list
    # instead of a shell string. This avoids word-splitting of VKIT_ROOT and
    # shell interpretation of characters such as `"`, `$` or backticks in
    # `file`. Unset variables become empty strings, as they did in the shell.
    vkit_root = os.environ.get('VKIT_ROOT', '')
    vkit_data = os.environ.get('VKIT_DATA', '')
    executable = f'{vkit_root}/.direnv/bin/pyproject-data-folder'
    command = [executable, vkit_root, vkit_data, os.fspath(file)]

    try:
        proc = subprocess.run(command, capture_output=True, text=True)
    except OSError as err:
        raise RuntimeError(f'Failed to run {executable!r}: {err}') from err

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if proc.returncode != 0:
        raise RuntimeError(
            f'{executable!r} exited with code {proc.returncode}: {proc.stderr.strip()}'
        )

    data_folder = proc.stdout.strip()
    if not data_folder:
        raise RuntimeError(f'{executable!r} printed an empty data folder path.')

    io.folder(data_folder, touch=True)

    return data_folder
def rng_choice(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    probs: Optional[Sequence[float]] = None,
) -> _T_ITEM:
    """Return one element of ``items`` picked with ``rng``.

    ``probs``, when given, is forwarded to ``rng.choice`` as ``p``. Otherwise
    every position is equally likely.
    """
    num_items = len(items)
    picked_index = rng.choice(num_items, p=probs)
    return items[picked_index]
def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Draw ``size`` elements of ``items`` using ``rng``.

    Args:
        rng: Generator used for sampling.
        items: Population to sample from.
        size: Number of elements to draw.
        probs: Optional per-item probabilities, forwarded to ``rng.choice``
            as ``p``.
        replace: Forwarded to ``rng.choice``. The default ``True`` samples
            WITH replacement, so the same item may be drawn several times.
            Pass ``False`` to draw distinct positions of ``items``; numpy
            then raises ``ValueError`` if ``size`` exceeds ``len(items)``.

    Returns:
        A new list holding the drawn elements in draw order.
    """
    # NOTE: Sampling is WITH replacement unless `replace=False` is passed.
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]
def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return the elements of ``items`` as a tuple in an ``rng``-shuffled order.

    ``items`` is left untouched. A list of positions is shuffled in place with
    ``rng.shuffle`` and then used to index ``items``.
    """
    order = list(range(len(items)))
    rng.shuffle(order)
    return tuple(items[position] for position in order)
def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Pick a random OpenCV interpolation flag for resizing.

    The pool is ``_CV_INTER_FLAGS``, with ``cv.INTER_AREA`` appended when
    ``include_cv_inter_area`` is true. The pick is made by ``rng_choice``
    without explicit probabilities.
    """
    candidate_flags = (
        (*_CV_INTER_FLAGS, cv.INTER_AREA) if include_cv_inter_area else _CV_INTER_FLAGS
    )
    return rng_choice(rng, candidate_flags)
def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Coerce ``dyn_object`` into a ``target_cls`` value.

    Resolution order:

    1. ``None`` with ``support_none_type`` -> ``target_cls()``.
    2. With ``support_path_type`` or ``force_path_type``, a ``str`` /
       ``PathLike`` is first loaded as a JSON file (``force_path_type``
       asserts that it is one).
    3. An object that already is a ``target_cls`` instance is returned as is.
    4. A mapping is structured with ``_cattrs``; on ``ClassValidationError``
       it falls back to ``target_cls(**dyn_object)``.
    5. A sequence (note: a ``str`` is one) is structured with ``_cattrs``.

    Raises:
        NotImplementedError: for any other kind of ``dyn_object``.
    """
    if support_none_type and dyn_object is None:
        return target_cls()

    if support_path_type or force_path_type:
        looks_like_path = is_path_type(dyn_object)
        if force_path_type:
            assert looks_like_path
        if looks_like_path:
            dyn_object = read_json_file(dyn_object)

    try:
        already_structured = isinstance(dyn_object, target_cls)
    except TypeError:
        # `target_cls` may be a typing construct such as Sequence[int],
        # which isinstance() rejects.
        already_structured = False

    if already_structured:
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure;
            # fall back to calling the constructor directly.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError()
def normalize_to_probs(weights: Sequence[float]):
    """Return a new list with each weight divided by ``sum(weights)``.

    An empty input yields ``[]``. For non-empty plain-Python weights summing
    to zero, the division raises ``ZeroDivisionError``.
    """
    weight_sum = sum(weights)
    return [w / weight_sum for w in weights]
def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split weighted keys into a key list and a normalized probability list.

    ``key_weight_items`` is either a sequence of ``(key, weight)`` pairs or a
    mapping from key to weight. Input iteration order is preserved, and the
    weights are normalized with ``normalize_to_probs``.

    Raises:
        NotImplementedError: if the input is neither a sequence nor a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = key_weight_items
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = key_weight_items.items()
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = []
    weights: List[float] = []
    for key, weight in pairs:
        keys.append(key)
        weights.append(weight)

    return keys, normalize_to_probs(weights)
def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert a ``CamelCase`` name to ``snake_case``.

    An underscore is inserted before every ASCII uppercase letter except at
    position 0, and the result is lowercased. Consecutive capitals are split
    individually: ``'FooBar'`` -> ``'foo_bar'``, ``'ABC'`` -> ``'a_b_c'``.
    """
    boundary = re.compile(r'(?<!^)(?=[A-Z])')
    underscored = boundary.sub('_', name)
    return underscored.lower()
def get_config_class_snake_case_name(class_name: str):
    """Snake-case ``class_name`` and drop a trailing ``_config`` if present.

    E.g. ``'ResizeConfig'`` -> ``'resize'``.
    """
    suffix = '_config'
    snake_case_name = convert_camel_case_name_to_snake_case_name(class_name)
    if not snake_case_name.endswith(suffix):
        return snake_case_name
    return snake_case_name[:-len(suffix)]
def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first entry in ``cls.__orig_bases__``.

    E.g. for ``class IntList(List[int])`` this gives ``(int,)``.
    ``__orig_bases__`` is read through normal attribute lookup, so it may be
    resolved from a parent class, and ``AttributeError`` propagates when it
    is absent.
    """
    first_orig_base = cls.__orig_bases__[0]  # type: ignore
    return get_args(first_orig_base)