vkit.utility.opt

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import (
 15    get_args,
 16    cast,
 17    Sequence,
 18    Optional,
 19    TypeVar,
 20    Any,
 21    Type,
 22    Union,
 23    Tuple,
 24    Mapping,
 25    List,
 26)
 27import subprocess
 28from os import PathLike
 29from collections import abc
 30import re
 31
 32from numpy.random import Generator as RandomGenerator
 33import cv2 as cv
 34import iolite as io
 35import attrs
 36import cattrs
 37from cattrs.errors import ClassValidationError
 38
 39from .type import PathType
 40
 41
 42def attrs_lazy_field():
 43    return attrs.field(default=None, init=False, repr=False)
 44
 45
 46_T_FIELD = TypeVar('_T_FIELD')
 47
 48
 49def unwrap_optional_field(field: Optional[_T_FIELD]) -> _T_FIELD:
 50    assert field is not None
 51    return field
 52
 53
 54def get_cattrs_converter_ignoring_init_equals_false():
 55    converter = cattrs.Converter()
 56    converter.register_unstructure_hook_factory(
 57        attrs.has,
 58        lambda cl: cattrs.gen.make_dict_unstructure_fn(
 59            cl,
 60            converter,
 61            **{a.name: cattrs.override(omit=True) for a in attrs.fields(cl) if not a.init},
 62        ),
 63    )
 64    return converter
 65
 66
 67def is_path_type(path: Any):
 68    return isinstance(path, (str, PathLike))  # type: ignore
 69
 70
 71def read_json_file(path: PathType):
 72    return io.read_json(path, expandvars=True)
 73
 74
 75def get_data_folder(file: PathType):
 76    proc = subprocess.run(
 77        f'$VKIT_ROOT/.direnv/bin/pyproject-data-folder "$VKIT_ROOT" "$VKIT_DATA" "{file}"',
 78        shell=True,
 79        capture_output=True,
 80        text=True,
 81    )
 82    assert proc.returncode == 0
 83
 84    data_folder = proc.stdout.strip()
 85    assert data_folder
 86
 87    io.folder(data_folder, touch=True)
 88
 89    return data_folder
 90
 91
 92_T_ITEM = TypeVar('_T_ITEM')
 93
 94
 95def rng_choice(
 96    rng: RandomGenerator,
 97    items: Sequence[_T_ITEM],
 98    probs: Optional[Sequence[float]] = None,
 99) -> _T_ITEM:
100    idx = rng.choice(len(items), p=probs)
101    return items[idx]
102
103
def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Draw ``size`` elements from ``items``.

    Args:
        rng: numpy random generator used for the draw.
        items: candidates to pick from.
        size: number of elements to draw.
        probs: optional per-item probabilities (``p`` of ``rng.choice``); ``None`` is uniform.
        replace: sample WITH replacement when ``True`` (the default), so the same element
            may be returned more than once; pass ``False`` to get distinct elements, which
            requires ``size <= len(items)``.

    Returns:
        A list of the drawn elements, in draw order.
    """
    # NOTE: Sampling is WITH replacement unless ``replace=False`` is passed.
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]
114
115
def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return the elements of ``items`` in a random order, as a new tuple.

    ``items`` itself is not modified; ``rng.shuffle`` permutes a list of indices instead.
    """
    order = list(range(len(items)))
    rng.shuffle(order)
    return tuple(map(items.__getitem__, order))
123
124
# Interpolation flags sampled by ``sample_cv_resize_interpolation``.
# The ``*_EXACT`` variants are deliberately used instead of ``cv.INTER_NEAREST`` and
# ``cv.INTER_LINEAR``. ``cast`` is applied because the cv2 constants may be typed as
# ``Any`` by the stubs (per the original note — confirm).
_CV_INTER_FLAGS = cast(
    Sequence[int],
    (
        cv.INTER_NEAREST_EXACT,
        cv.INTER_LINEAR_EXACT,
        cv.INTER_CUBIC,
        cv.INTER_LANCZOS4,
    ),
)


def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Pick a random OpenCV resize interpolation flag.

    Args:
        rng: numpy random generator used for the draw.
        include_cv_inter_area: also allow ``cv.INTER_AREA`` as a candidate.

    Returns:
        One of the flags in ``_CV_INTER_FLAGS`` (plus ``cv.INTER_AREA`` when requested),
        chosen uniformly via ``rng_choice``.
    """
    candidates = (*_CV_INTER_FLAGS, cv.INTER_AREA) if include_cv_inter_area else _CV_INTER_FLAGS
    return rng_choice(rng, candidates)
149
150
_T_TARGET = TypeVar('_T_TARGET')

# Converter used by ``dyn_structure``; ``forbid_extra_keys=True`` makes cattrs reject
# mapping keys that do not match a field of the target class.
_cattrs = cattrs.GenConverter(forbid_extra_keys=True)


def _keep_config_field_as_is(value: Any, _type: Any) -> Any:
    # Structure hook that hands back the raw value unchanged.
    return value


# Do nothing for config field: values typed as
# Optional[Union[Mapping[str, Any], str, PathLike]] are passed through untouched.
_cattrs.register_structure_hook(
    Optional[Union[Mapping[str, Any], str, PathLike]],
    _keep_config_field_as_is,
)


def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Convert ``dyn_object`` into an instance of ``target_cls``.

    Resolution order:

    1. ``None`` with ``support_none_type`` -> ``target_cls()``.
    2. With ``support_path_type`` or ``force_path_type``, a ``str``/``PathLike`` value is
       first loaded through ``read_json_file`` (``force_path_type`` asserts it is one).
    3. A value that already is a ``target_cls`` instance is returned unchanged.
    4. A mapping is structured with ``_cattrs``; on ``ClassValidationError`` it falls
       back to ``target_cls(**mapping)``.
    5. Any other sequence (``str`` included) is structured with ``_cattrs``.

    Raises:
        NotImplementedError: if the (possibly loaded) value is neither a ``target_cls``
            instance, a mapping nor a sequence.
    """
    if dyn_object is None and support_none_type:
        return target_cls()

    if force_path_type or support_path_type:
        looks_like_path = is_path_type(dyn_object)
        if force_path_type:
            assert looks_like_path
        if looks_like_path:
            dyn_object = read_json_file(dyn_object)

    try:
        already_structured = isinstance(dyn_object, target_cls)
    except TypeError:
        # target_cls could be a type annotation like Sequence[int], which isinstance rejects.
        already_structured = False

    if already_structured:
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure;
            # fall back to calling the constructor directly.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError()
203
204
def normalize_to_probs(weights: Sequence[float]):
    """Scale ``weights`` so that they sum to 1 and return them as a new list.

    Every weight is divided by ``sum(weights)``. An empty input yields ``[]``; for plain
    Python numbers, a non-empty input summing to zero raises ``ZeroDivisionError``.
    """
    denominator = sum(weights)
    normalized = [each / denominator for each in weights]
    return normalized
209
210
_T_KEY = TypeVar('_T_KEY')


def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split ``(key, weight)`` pairs into keys and normalized probabilities.

    Args:
        key_weight_items: either a sequence of ``(key, weight)`` pairs or a mapping from
            key to weight. The sequence check comes first.

    Returns:
        ``(keys, probs)``, where ``probs`` comes from ``normalize_to_probs(weights)`` and
        lines up index-by-index with ``keys``.

    Raises:
        NotImplementedError: for inputs that are neither a sequence nor a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = list(key_weight_items)
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = list(key_weight_items.items())
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]
    return keys, normalize_to_probs(weights)
233
234
def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert ``CamelCase`` to ``snake_case``.

    An underscore is inserted before every ASCII capital letter ``A``-``Z`` except at
    position 0, then the whole string is lower-cased. Consecutive capitals are split
    individually, e.g. ``'HTTPServer'`` -> ``'h_t_t_p_server'``.
    """
    pieces = []
    for index, char in enumerate(name):
        if index > 0 and 'A' <= char <= 'Z':
            pieces.append('_')
        pieces.append(char)
    return ''.join(pieces).lower()
237
238
def get_config_class_snake_case_name(class_name: str):
    """Return the snake_case form of ``class_name`` with a trailing ``_config`` removed.

    e.g. ``'FooBarConfig'`` -> ``'foo_bar'``; names without the suffix are only converted.
    """
    suffix = '_config'
    converted = convert_camel_case_name_to_snake_case_name(class_name)
    return converted[:-len(suffix)] if converted.endswith(suffix) else converted
244
245
def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first generic base that ``cls`` was declared with.

    e.g. for ``class Foo(Base[int, str])`` this returns ``(int, str)``. Only
    ``cls.__orig_bases__[0]`` is inspected; ``AttributeError`` is raised if ``cls`` has no
    ``__orig_bases__``.
    """
    first_base = cls.__orig_bases__[0]  # type: ignore
    return get_args(first_base)
def attrs_lazy_field():
    """Return an attrs field that defaults to ``None``, is excluded from ``__init__``
    and is hidden from ``repr``.

    Presumably used for attributes that are filled in after construction — confirm at
    the call sites.
    """
    return attrs.field(init=False, repr=False, default=None)
def unwrap_optional_field(field: Optional[_T_FIELD]) -> _T_FIELD:
    """Narrow an ``Optional`` value to its non-``None`` type and return it unchanged.

    Raises ``AssertionError`` (while assertions are enabled) when ``field`` is ``None``.
    """
    value = field
    assert value is not None
    return value
def get_cattrs_converter_ignoring_init_equals_false():
    """Build a ``cattrs.Converter`` whose unstructure hook for attrs classes omits every
    field declared with ``init=False``.
    """
    converter = cattrs.Converter()

    def _make_unstructure_fn(cl):
        # Fields that cannot be passed to __init__ are dropped from the unstructured dict.
        omitted_fields = {
            attribute.name: cattrs.override(omit=True)
            for attribute in attrs.fields(cl)
            if not attribute.init
        }
        return cattrs.gen.make_dict_unstructure_fn(cl, converter, **omitted_fields)

    converter.register_unstructure_hook_factory(attrs.has, _make_unstructure_fn)
    return converter
def is_path_type(path: Any):
    """Return ``True`` if ``path`` is a ``str`` or an ``os.PathLike`` object."""
    if isinstance(path, str):
        return True
    return isinstance(path, PathLike)
def read_json_file(path: PathType):
    """Load the JSON document stored at ``path``.

    Delegates to ``io.read_json`` with ``expandvars=True`` (presumably environment
    variables are expanded — confirm against iolite's documentation).
    """
    json_data = io.read_json(path, expandvars=True)
    return json_data
def get_data_folder(file: PathType):
    """Resolve the data folder associated with ``file`` and touch it.

    Runs ``$VKIT_ROOT/.direnv/bin/pyproject-data-folder "$VKIT_ROOT" "$VKIT_DATA" <file>``
    through the shell, takes its stripped stdout as the folder path and passes it to
    ``io.folder(..., touch=True)`` (presumably creating it if missing — confirm in iolite).

    Args:
        file: path forwarded to the helper script.

    Returns:
        The data folder path printed by the helper script.

    Raises:
        RuntimeError: if the helper exits with a non-zero status or prints nothing.
    """
    import shlex

    # The environment variables are still expanded by the shell (inside double quotes),
    # while ``file`` is shell-quoted so that spaces, quotes, ``$`` or backticks in the
    # path cannot break or alter the command.
    command = (
        '"$VKIT_ROOT/.direnv/bin/pyproject-data-folder" "$VKIT_ROOT" "$VKIT_DATA" '
        + shlex.quote(str(file))
    )
    proc = subprocess.run(
        command,
        shell=True,
        capture_output=True,
        text=True,
    )
    # Explicit checks instead of ``assert`` so they are not stripped under ``python -O``.
    if proc.returncode != 0:
        raise RuntimeError(
            f'pyproject-data-folder failed with exit code {proc.returncode}: '
            f'{proc.stderr.strip()}'
        )

    data_folder = proc.stdout.strip()
    if not data_folder:
        raise RuntimeError('pyproject-data-folder returned an empty data folder.')

    io.folder(data_folder, touch=True)

    return data_folder
def rng_choice(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    probs: Optional[Sequence[float]] = None,
) -> _T_ITEM:
    """Pick one element of ``items`` at random.

    Args:
        rng: numpy random generator used for the draw.
        items: candidates to pick from.
        probs: optional per-item probabilities, forwarded to ``rng.choice`` as ``p``;
            ``None`` means uniform.

    Returns:
        The selected element.
    """
    return items[rng.choice(len(items), p=probs)]
def rng_choice_with_size(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
    size: int,
    probs: Optional[Sequence[float]] = None,
    replace: bool = True,
) -> Sequence[_T_ITEM]:
    """Draw ``size`` elements from ``items``.

    Args:
        rng: numpy random generator used for the draw.
        items: candidates to pick from.
        size: number of elements to draw.
        probs: optional per-item probabilities (``p`` of ``rng.choice``); ``None`` is uniform.
        replace: sample WITH replacement when ``True`` (the default), so the same element
            may be returned more than once; pass ``False`` to get distinct elements, which
            requires ``size <= len(items)``.

    Returns:
        A list of the drawn elements, in draw order.
    """
    # NOTE: Sampling is WITH replacement unless ``replace=False`` is passed.
    indices = rng.choice(len(items), p=probs, size=size, replace=replace)
    return [items[idx] for idx in indices]
def rng_shuffle(
    rng: RandomGenerator,
    items: Sequence[_T_ITEM],
) -> Sequence[_T_ITEM]:
    """Return the elements of ``items`` in a random order, as a new tuple.

    ``items`` itself is not modified; ``rng.shuffle`` permutes a list of indices instead.
    """
    order = list(range(len(items)))
    rng.shuffle(order)
    return tuple(map(items.__getitem__, order))
def sample_cv_resize_interpolation(
    rng: RandomGenerator,
    include_cv_inter_area: bool = False,
):
    """Pick a random OpenCV resize interpolation flag.

    Args:
        rng: numpy random generator used for the draw.
        include_cv_inter_area: also allow ``cv.INTER_AREA`` as a candidate.

    Returns:
        One of the flags in ``_CV_INTER_FLAGS`` (plus ``cv.INTER_AREA`` when requested),
        chosen uniformly via ``rng_choice``.
    """
    candidates = (*_CV_INTER_FLAGS, cv.INTER_AREA) if include_cv_inter_area else _CV_INTER_FLAGS
    return rng_choice(rng, candidates)
def dyn_structure(
    dyn_object: Any,
    target_cls: Type[_T_TARGET],
    support_path_type: bool = False,
    force_path_type: bool = False,
    support_none_type: bool = False,
) -> _T_TARGET:
    """Convert ``dyn_object`` into an instance of ``target_cls``.

    Resolution order:

    1. ``None`` with ``support_none_type`` -> ``target_cls()``.
    2. With ``support_path_type`` or ``force_path_type``, a ``str``/``PathLike`` value is
       first loaded through ``read_json_file`` (``force_path_type`` asserts it is one).
    3. A value that already is a ``target_cls`` instance is returned unchanged.
    4. A mapping is structured with ``_cattrs``; on ``ClassValidationError`` it falls
       back to ``target_cls(**mapping)``.
    5. Any other sequence (``str`` included) is structured with ``_cattrs``.

    Raises:
        NotImplementedError: if the (possibly loaded) value is neither a ``target_cls``
            instance, a mapping nor a sequence.
    """
    if dyn_object is None and support_none_type:
        return target_cls()

    if force_path_type or support_path_type:
        looks_like_path = is_path_type(dyn_object)
        if force_path_type:
            assert looks_like_path
        if looks_like_path:
            dyn_object = read_json_file(dyn_object)

    try:
        already_structured = isinstance(dyn_object, target_cls)
    except TypeError:
        # target_cls could be a type annotation like Sequence[int], which isinstance rejects.
        already_structured = False

    if already_structured:
        return dyn_object

    if isinstance(dyn_object, abc.Mapping):
        try:
            return _cattrs.structure(dyn_object, target_cls)
        except ClassValidationError:
            # cattrs cannot handle classes with a hierarchy structure;
            # fall back to calling the constructor directly.
            return target_cls(**dyn_object)

    if isinstance(dyn_object, abc.Sequence):
        return _cattrs.structure(dyn_object, target_cls)

    raise NotImplementedError()
def normalize_to_probs(weights: Sequence[float]):
    """Scale ``weights`` so that they sum to 1 and return them as a new list.

    Every weight is divided by ``sum(weights)``. An empty input yields ``[]``; for plain
    Python numbers, a non-empty input summing to zero raises ``ZeroDivisionError``.
    """
    denominator = sum(weights)
    normalized = [each / denominator for each in weights]
    return normalized
def normalize_to_keys_and_probs(
    key_weight_items: Union[Sequence[Tuple[_T_KEY, float]], Mapping[_T_KEY, float]]
) -> Tuple[Sequence[_T_KEY], Sequence[float]]:
    """Split ``(key, weight)`` pairs into keys and normalized probabilities.

    Args:
        key_weight_items: either a sequence of ``(key, weight)`` pairs or a mapping from
            key to weight. The sequence check comes first.

    Returns:
        ``(keys, probs)``, where ``probs`` comes from ``normalize_to_probs(weights)`` and
        lines up index-by-index with ``keys``.

    Raises:
        NotImplementedError: for inputs that are neither a sequence nor a mapping.
    """
    if isinstance(key_weight_items, abc.Sequence):
        pairs = list(key_weight_items)
    elif isinstance(key_weight_items, abc.Mapping):  # type: ignore
        pairs = list(key_weight_items.items())
    else:
        raise NotImplementedError()

    keys: List[_T_KEY] = [key for key, _ in pairs]
    weights: List[float] = [weight for _, weight in pairs]
    return keys, normalize_to_probs(weights)
def convert_camel_case_name_to_snake_case_name(name: str):
    """Convert ``CamelCase`` to ``snake_case``.

    An underscore is inserted before every ASCII capital letter ``A``-``Z`` except at
    position 0, then the whole string is lower-cased. Consecutive capitals are split
    individually, e.g. ``'HTTPServer'`` -> ``'h_t_t_p_server'``.
    """
    pieces = []
    for index, char in enumerate(name):
        if index > 0 and 'A' <= char <= 'Z':
            pieces.append('_')
        pieces.append(char)
    return ''.join(pieces).lower()
def get_config_class_snake_case_name(class_name: str):
    """Return the snake_case form of ``class_name`` with a trailing ``_config`` removed.

    e.g. ``'FooBarConfig'`` -> ``'foo_bar'``; names without the suffix are only converted.
    """
    suffix = '_config'
    converted = convert_camel_case_name_to_snake_case_name(class_name)
    return converted[:-len(suffix)] if converted.endswith(suffix) else converted
def get_generic_classes(cls: Type[Any]):
    """Return the type arguments of the first generic base that ``cls`` was declared with.

    e.g. for ``class Foo(Base[int, str])`` this returns ``(int, str)``. Only
    ``cls.__orig_bases__[0]`` is inspected; ``AttributeError`` is raised if ``cls`` has no
    ``__orig_bases__``.
    """
    first_base = cls.__orig_bases__[0]  # type: ignore
    return get_args(first_base)