vkit.engine.char_sampler.func_collate

 1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
 2#
 3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
 4#
 5# The commercial license gives you the full rights to create and distribute software
 6# on your own terms without any SSPL license obligations. For more information,
 7# please see the "LICENSE_COMMERCIAL.txt" file.
 8#
 9# This project is also available under Server Side Public License (SSPL).
10# The SSPL licensing is ideal for use cases such as open source projects with
11# SSPL distribution, student/academic purposes, hobby projects, internal research
12# projects without external distribution, or other projects where all SSPL
13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
14from typing import Sequence, List
15
16from numpy.random import Generator as RandomGenerator
17
18from vkit.engine.interface import EngineExecutorAggregatorSelector
19from .type import CharSamplerEngineRunConfig
20
21
def char_sampler_func_collate(
    selector: EngineExecutorAggregatorSelector[
        CharSamplerEngineRunConfig,
        Sequence[str],
    ],
    run_config: CharSamplerEngineRunConfig,
    rng: RandomGenerator,
):  # yapf: disable
    """Sample characters, optionally stitching several executor runs together.

    When ``run_config.enable_aggregator_mode`` is falsy, the result of a single
    executor run is returned unchanged.

    Otherwise executors are run repeatedly until at least
    ``run_config.num_chars`` characters have been collected. Whenever some
    characters have already been collected, a single space is inserted before
    the next run if ``rng.random() < 0.5``. The collected characters are then
    cut down to exactly ``num_chars``; if that cut leaves a whitespace character
    at the end, it is replaced by the first non-whitespace character among the
    characters that were cut off.

    NOTE: when the collected characters add up to exactly ``num_chars``, nothing
    is trimmed and the last character is left as sampled.

    Args:
        selector: Provides the executor for each run through
            ``select_engine_executor(rng)``.
        run_config: Run configuration, forwarded to every executor run.
        rng: Random generator, used for the separator draw and forwarded to the
            selector and the executors.

    Returns:
        In aggregator mode, a list of exactly ``num_chars`` characters.
        Otherwise, whatever the executor run returns.

    Raises:
        ValueError: If the trimmed result ends with whitespace and the surplus
            characters contain no non-whitespace replacement.
    """
    if not run_config.enable_aggregator_mode:
        return selector.select_engine_executor(rng).run(run_config, rng)

    num_chars = run_config.num_chars

    chars: List[str] = []
    while len(chars) < num_chars:
        # Randomly separate consecutive runs with a space.
        if chars and rng.random() < 0.5:
            chars.append(' ')
        chars.extend(selector.select_engine_executor(rng).run(run_config, rng))

    # Trim and make sure the last char is not space.
    if len(chars) > num_chars:
        rest = chars[num_chars:]
        chars = chars[:num_chars]
        if chars[-1].isspace():
            # Validate explicitly instead of asserting: the surplus comes from the
            # executors, and an assert would vanish under `python -O`.
            replacement = next((char for char in rest if not char.isspace()), None)
            if replacement is None:
                raise ValueError(
                    'Cannot replace the trailing whitespace: '
                    'no non-whitespace char left after trimming.'
                )
            chars[-1] = replacement

    # Post-condition of the loop and the trim above.
    assert len(chars) == num_chars
    return chars
def char_sampler_func_collate( selector: vkit.engine.interface.EngineExecutorAggregatorSelector[vkit.engine.char_sampler.type.CharSamplerEngineRunConfig, typing.Sequence[str]], run_config: vkit.engine.char_sampler.type.CharSamplerEngineRunConfig, rng: numpy.random._generator.Generator):
def char_sampler_func_collate(
    selector: EngineExecutorAggregatorSelector[
        CharSamplerEngineRunConfig,
        Sequence[str],
    ],
    run_config: CharSamplerEngineRunConfig,
    rng: RandomGenerator,
):  # yapf: disable
    """Sample characters, optionally stitching several executor runs together.

    When ``run_config.enable_aggregator_mode`` is falsy, the result of a single
    executor run is returned unchanged.

    Otherwise executors are run repeatedly until at least
    ``run_config.num_chars`` characters have been collected. Whenever some
    characters have already been collected, a single space is inserted before
    the next run if ``rng.random() < 0.5``. The collected characters are then
    cut down to exactly ``num_chars``; if that cut leaves a whitespace character
    at the end, it is replaced by the first non-whitespace character among the
    characters that were cut off.

    NOTE: when the collected characters add up to exactly ``num_chars``, nothing
    is trimmed and the last character is left as sampled.

    Args:
        selector: Provides the executor for each run through
            ``select_engine_executor(rng)``.
        run_config: Run configuration, forwarded to every executor run.
        rng: Random generator, used for the separator draw and forwarded to the
            selector and the executors.

    Returns:
        In aggregator mode, a list of exactly ``num_chars`` characters.
        Otherwise, whatever the executor run returns.

    Raises:
        ValueError: If the trimmed result ends with whitespace and the surplus
            characters contain no non-whitespace replacement.
    """
    if not run_config.enable_aggregator_mode:
        return selector.select_engine_executor(rng).run(run_config, rng)

    num_chars = run_config.num_chars

    chars: List[str] = []
    while len(chars) < num_chars:
        # Randomly separate consecutive runs with a space.
        if chars and rng.random() < 0.5:
            chars.append(' ')
        chars.extend(selector.select_engine_executor(rng).run(run_config, rng))

    # Trim and make sure the last char is not space.
    if len(chars) > num_chars:
        rest = chars[num_chars:]
        chars = chars[:num_chars]
        if chars[-1].isspace():
            # Validate explicitly instead of asserting: the surplus comes from the
            # executors, and an assert would vanish under `python -O`.
            replacement = next((char for char in rest if not char.isspace()), None)
            if replacement is None:
                raise ValueError(
                    'Cannot replace the trailing whitespace: '
                    'no non-whitespace char left after trimming.'
                )
            chars[-1] = replacement

    # Post-condition of the loop and the trim above.
    assert len(chars) == num_chars
    return chars