vkit.engine.char_sampler.func_collate
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import Sequence, List

from numpy.random import Generator as RandomGenerator

from ..interface import EngineExecutorAggregatorSelector
from .type import CharSamplerEngineRunConfig


def char_sampler_func_collate(
    selector: EngineExecutorAggregatorSelector[
        CharSamplerEngineRunConfig,
        Sequence[str],
    ],
    run_config: CharSamplerEngineRunConfig,
    rng: RandomGenerator,
):  # yapf: disable
    """Collect sampled chars from one engine run, or aggregate several runs.

    When ``run_config.enable_aggregator_mode`` is false, a single engine
    executor is selected and its output is returned unchanged.

    Otherwise engine executors are selected and run repeatedly and their
    outputs are concatenated until at least ``run_config.num_chars`` chars
    have been collected. Between two consecutive outputs a single space is
    inserted with probability 0.5. The result is then cut to exactly
    ``num_chars`` chars, and a trailing whitespace char is replaced by the
    next non-whitespace char that was (or gets) sampled.

    Args:
        selector: Provides ``select_engine_executor(rng)``, whose result is
            run with ``run(run_config, rng)`` to produce a sequence of chars.
        run_config: Read for ``enable_aggregator_mode`` and ``num_chars``, and
            forwarded to every engine run.
        rng: Random generator used for the separator decision and forwarded
            to the selector and the engines.

    Returns:
        In aggregator mode, a list of exactly ``num_chars`` chars whose last
        element is not whitespace. Otherwise, whatever the selected engine
        returned.

    Raises:
        ValueError: If aggregator mode is enabled and ``num_chars`` is
            negative.
    """
    if not run_config.enable_aggregator_mode:
        return selector.select_engine_executor(rng).run(run_config, rng)

    num_chars = run_config.num_chars
    if num_chars < 0:
        raise ValueError(f'num_chars must be non-negative, got {num_chars}.')

    chars: List[str] = []
    while len(chars) < num_chars:
        # Separate consecutive engine outputs with a space half of the time.
        if chars and rng.random() < 0.5:
            chars.append(' ')
        chars.extend(selector.select_engine_executor(rng).run(run_config, rng))

    # Trim to the requested length; the surplus supplies replacement candidates.
    overflow = chars[num_chars:]
    del chars[num_chars:]

    # Make sure the last char is not space.
    if chars and chars[-1].isspace():
        candidates = iter(overflow)
        while True:
            replacement = next(
                (char for char in candidates if not char.isspace()),
                None,
            )
            if replacement is not None:
                break
            # The surplus had no usable char, so sample another engine output.
            # NOTE(review): like the collection loop above, this assumes the
            # engines eventually produce a non-whitespace char.
            candidates = iter(selector.select_engine_executor(rng).run(run_config, rng))
        chars[-1] = replacement

    assert len(chars) == num_chars
    return chars
def
char_sampler_func_collate( selector: vkit.engine.interface.EngineExecutorAggregatorSelector[vkit.engine.char_sampler.type.CharSamplerEngineRunConfig, typing.Sequence[str]], run_config: vkit.engine.char_sampler.type.CharSamplerEngineRunConfig, rng: numpy.random._generator.Generator):
def char_sampler_func_collate(
    selector: EngineExecutorAggregatorSelector[
        CharSamplerEngineRunConfig,
        Sequence[str],
    ],
    run_config: CharSamplerEngineRunConfig,
    rng: RandomGenerator,
):  # yapf: disable
    """Collect sampled chars from one engine run, or aggregate several runs.

    When ``run_config.enable_aggregator_mode`` is false, a single engine
    executor is selected and its output is returned unchanged.

    Otherwise engine executors are selected and run repeatedly and their
    outputs are concatenated until at least ``run_config.num_chars`` chars
    have been collected. Between two consecutive outputs a single space is
    inserted with probability 0.5. The result is then cut to exactly
    ``num_chars`` chars, and a trailing whitespace char is replaced by the
    next non-whitespace char that was (or gets) sampled.

    Args:
        selector: Provides ``select_engine_executor(rng)``, whose result is
            run with ``run(run_config, rng)`` to produce a sequence of chars.
        run_config: Read for ``enable_aggregator_mode`` and ``num_chars``, and
            forwarded to every engine run.
        rng: Random generator used for the separator decision and forwarded
            to the selector and the engines.

    Returns:
        In aggregator mode, a list of exactly ``num_chars`` chars whose last
        element is not whitespace. Otherwise, whatever the selected engine
        returned.

    Raises:
        ValueError: If aggregator mode is enabled and ``num_chars`` is
            negative.
    """
    if not run_config.enable_aggregator_mode:
        return selector.select_engine_executor(rng).run(run_config, rng)

    num_chars = run_config.num_chars
    if num_chars < 0:
        raise ValueError(f'num_chars must be non-negative, got {num_chars}.')

    chars: List[str] = []
    while len(chars) < num_chars:
        # Separate consecutive engine outputs with a space half of the time.
        if chars and rng.random() < 0.5:
            chars.append(' ')
        chars.extend(selector.select_engine_executor(rng).run(run_config, rng))

    # Trim to the requested length; the surplus supplies replacement candidates.
    overflow = chars[num_chars:]
    del chars[num_chars:]

    # Make sure the last char is not space.
    if chars and chars[-1].isspace():
        candidates = iter(overflow)
        while True:
            replacement = next(
                (char for char in candidates if not char.isspace()),
                None,
            )
            if replacement is not None:
                break
            # The surplus had no usable char, so sample another engine output.
            # NOTE(review): like the collection loop above, this assumes the
            # engines eventually produce a non-whitespace char.
            candidates = iter(selector.select_engine_executor(rng).run(run_config, rng))
        chars[-1] = replacement

    assert len(chars) == num_chars
    return chars