vkit.pipeline.interface
# Copyright 2022 vkit-x Administrator. All Rights Reserved.
#
# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
#
# The commercial license gives you the full rights to create and distribute software
# on your own terms without any SSPL license obligations. For more information,
# please see the "LICENSE_COMMERCIAL.txt" file.
#
# This project is also available under Server Side Public License (SSPL).
# The SSPL licensing is ideal for use cases such as open source projects with
# SSPL distribution, student/academic purposes, hobby projects, internal research
# projects without external distribution, or other projects where all SSPL
# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
from typing import (
    cast,
    Generic,
    TypeVar,
    Type,
    Dict,
    Mapping,
    Any,
    Sequence,
    Optional,
    Union,
    List,
)

import attrs
from numpy.random import Generator as RandomGenerator

from vkit.utility import (
    is_path_type,
    read_json_file,
    dyn_structure,
    convert_camel_case_name_to_snake_case_name,
    get_generic_classes,
    PathType,
)

_T_VALUE = TypeVar('_T_VALUE')
_T_CONFIG = TypeVar('_T_CONFIG')
_T_INPUT = TypeVar('_T_INPUT')
_T_OUTPUT = TypeVar('_T_OUTPUT')


@attrs.define
class PipelineState:
    """Key/value store with type-checked reads and guarded writes."""

    key_to_value: Dict[str, Any] = attrs.field(factory=dict)

    def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
        """Return the value stored under ``key`` after checking its type.

        Raises:
            KeyError: ``key`` is not stored.
            TypeError: the stored value is not an instance of ``value_cls``.
        """
        if key not in self.key_to_value:
            raise KeyError(f'key={key} not found.')
        value = self.key_to_value[key]
        if not isinstance(value, value_cls):
            raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')
        return value

    def set_value(self, key: str, value: Any, override: bool = False):
        """Store ``value`` under ``key``.

        Raises:
            KeyError: ``key`` is already stored and ``override`` is false.
        """
        if key in self.key_to_value and not override:
            raise KeyError(f'key={key} exists but override is not set.')
        self.key_to_value[key] = value


class PipelineStep(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Base class of a pipeline step, generic over config, input and output.

    Subclasses are expected to override ``run``; the base implementation
    raises ``NotImplementedError``.
    """

    # NOTE(review): the three accessors below index the result of
    # get_generic_classes(cls); presumably positions 0/1/2 are the classes
    # bound to _T_CONFIG/_T_INPUT/_T_OUTPUT. Confirm against vkit.utility.
    @classmethod
    def get_config_cls(cls) -> Type[_T_CONFIG]:
        """Return element 0 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[0]  # type: ignore

    @classmethod
    def get_input_cls(cls) -> Type[_T_INPUT]:
        """Return element 1 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[1]  # type: ignore

    @classmethod
    def get_output_cls(cls) -> Type[_T_OUTPUT]:
        """Return element 2 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[2]  # type: ignore

    _cached_name: str = ''

    @classmethod
    def get_name(cls):
        """Return the snake_case form of this class's name, cached per class."""
        # Read the cache from this class's own namespace only. A plain
        # ``cls._cached_name`` lookup would also find a value cached on a
        # parent step class and wrongly return the parent's name.
        cached_name = cls.__dict__.get('_cached_name')
        if not cached_name:
            cached_name = convert_camel_case_name_to_snake_case_name(cls.__name__)
            cls._cached_name = cached_name
        return cached_name

    def __init__(self, config: _T_CONFIG):
        self.config = config

    def run(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
        """Produce the step output; must be overridden."""
        raise NotImplementedError()


class PipelineStepFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Creates instances of one ``PipelineStep`` class from a config."""

    def __init__(self, pipeline_step_cls: Type[PipelineStep[_T_CONFIG, _T_INPUT, _T_OUTPUT]]):
        self.pipeline_step_cls = pipeline_step_cls

    @property
    def name(self):
        """Name of the wrapped step class, as given by its ``get_name()``."""
        return self.pipeline_step_cls.get_name()

    def get_config_cls(self):
        """Return the config class reported by the wrapped step class."""
        return self.pipeline_step_cls.get_config_cls()

    def create(
        self,
        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
    ):
        """Instantiate the step class.

        ``config`` is handed to ``dyn_structure`` together with the config
        class, with path and ``None`` support enabled; the result is passed
        to the step constructor.
        """
        config = dyn_structure(
            config,
            self.get_config_cls(),
            support_path_type=True,
            support_none_type=True,
        )
        return self.pipeline_step_cls(config)


class PipelineStepCollectionFactory:
    """Registry of step factories keyed by ``'<namespace>.<step name>'``."""

    def __init__(self):
        self.name_to_step_factory: Dict[str, PipelineStepFactory] = {}

    def register_step_factories(
        self,
        namespace: str,
        step_factories: Sequence[PipelineStepFactory],
    ):
        """Register every factory under ``f'{namespace}.{factory.name}'``.

        Registration is all-or-nothing: nothing is added if any name clashes.

        Raises:
            KeyError: a name is already registered or repeats within the batch.
        """
        # Collect first and commit at the end so that a clash leaves the
        # registry untouched. An explicit raise is used instead of ``assert``
        # because asserts are removed when Python runs with -O.
        pending: Dict[str, PipelineStepFactory] = {}
        for step_factory in step_factories:
            name = f'{namespace}.{step_factory.name}'
            if name in self.name_to_step_factory or name in pending:
                raise KeyError(f'name={name} already registered.')
            pending[name] = step_factory
        self.name_to_step_factory.update(pending)

    def create(
        self,
        step_configs: Union[Sequence[Mapping[str, Any]], PathType],
    ):
        """Build the list of steps described by ``step_configs``.

        ``step_configs`` is either a sequence of mappings or a path that is
        read with ``read_json_file``. Each mapping must have a ``'name'`` key
        and may have a ``'config'`` key.

        Raises:
            KeyError: a step name has no registered factory.
        """
        if is_path_type(step_configs):
            step_configs = read_json_file(step_configs)  # type: ignore
        step_configs = cast(Sequence[Mapping[str, Any]], step_configs)

        steps: List[PipelineStep] = []
        for step_config in step_configs:
            name = step_config['name']
            if name not in self.name_to_step_factory:
                raise KeyError(f'name={name} not found.')
            step_factory = self.name_to_step_factory[name]
            steps.append(step_factory.create(step_config.get('config')))
        return steps


class PipelinePostProcessor(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Base class of the final pipeline stage; override ``generate_output``."""

    def __init__(self, config: _T_CONFIG):
        self.config = config

    @classmethod
    def get_input_cls(cls) -> Type[_T_INPUT]:
        """Return element 1 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[1]  # type: ignore

    def generate_output(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
        """Produce the pipeline output; must be overridden."""
        raise NotImplementedError()


class PipelinePostProcessorFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Creates instances of one ``PipelinePostProcessor`` class from a config."""

    def __init__(
        self,
        pipeline_post_processor_cls: Type[PipelinePostProcessor[_T_CONFIG, _T_INPUT, _T_OUTPUT]],
    ):
        self.pipeline_post_processor_cls = pipeline_post_processor_cls

    def get_config_cls(self) -> Type[_T_CONFIG]:
        """Return element 0 of ``get_generic_classes`` for the wrapped class."""
        return get_generic_classes(self.pipeline_post_processor_cls)[0]  # type: ignore

    def create(
        self,
        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
    ):
        """Instantiate the post processor from ``config`` via ``dyn_structure``."""
        config = dyn_structure(
            config,
            self.get_config_cls(),
            support_path_type=True,
            support_none_type=True,
        )
        return self.pipeline_post_processor_cls(config)


@attrs.define
class PipelineRunRngStateOutput:
    """Holds the ``rng.bit_generator.state`` captured when a run starts."""

    rng_state: Mapping[str, Any]


class Pipeline(Generic[_T_OUTPUT]):
    """Runs steps in order over a shared state, then a post processor."""

    def __init__(
        self,
        steps: Sequence[PipelineStep],
        post_processor: PipelinePostProcessor[Any, Any, _T_OUTPUT],
    ):
        self.steps = steps
        self.post_processor = post_processor

    @classmethod
    def build_input(cls, state: PipelineState, input_cls: Any):
        """Instantiate ``input_cls`` from values held in ``state``.

        For each attrs field of ``input_cls``, the value is looked up in
        ``state`` under the snake_case name of the field's *type* (not the
        field name) and passed to the constructor under the field name.
        """
        assert attrs.has(input_cls)

        input_kwargs = {}
        for key, key_field in attrs.fields_dict(input_cls).items():
            assert key_field.type
            assert attrs.has(key_field.type)
            value = state.get_value(
                convert_camel_case_name_to_snake_case_name(key_field.type.__name__),
                key_field.type,
            )
            input_kwargs[key] = value

        return input_cls(**input_kwargs)

    def run(
        self,
        rng: RandomGenerator,
        state: Optional[PipelineState] = None,
    ) -> _T_OUTPUT:
        """Execute all steps and return the post processor's output.

        A fresh ``PipelineState`` is created when ``state`` is ``None``.
        Values are stored without ``override``, so ``PipelineState.set_value``
        raises ``KeyError`` if a key is already present.
        """
        if state is None:
            state = PipelineState()

        # Save the rng state.
        state.set_value(
            convert_camel_case_name_to_snake_case_name(PipelineRunRngStateOutput.__name__),
            PipelineRunRngStateOutput(rng.bit_generator.state),
        )

        # Run steps.
        for step in self.steps:
            # Build input.
            step_input = self.build_input(state, step.get_input_cls())

            # Generate output.
            step_output = step.run(step_input, rng)

            # Update state.
            step_output_cls = step.get_output_cls()
            assert isinstance(step_output, step_output_cls)
            assert attrs.has(step_output_cls)
            state.set_value(
                convert_camel_case_name_to_snake_case_name(step_output_cls.__name__),
                step_output,
            )

        # Post processing.
        return self.post_processor.generate_output(
            self.build_input(state, self.post_processor.get_input_cls()),
            rng,
        )
class PipelineState:
    """Key/value store with type-checked reads and guarded writes."""

    # The constructor for this field is generated by attrs.
    key_to_value: Dict[str, Any] = attrs.field(factory=dict)

    def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
        """Return the value stored under ``key`` after checking its type.

        Raises:
            KeyError: ``key`` is not stored.
            TypeError: the stored value is not an instance of ``value_cls``.
        """
        store = self.key_to_value
        if key not in store:
            raise KeyError(f'key={key} not found.')

        value = store[key]
        if isinstance(value, value_cls):
            return value
        raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')

    def set_value(self, key: str, value: Any, override: bool = False):
        """Store ``value`` under ``key``.

        Raises:
            KeyError: ``key`` is already stored and ``override`` is false.
        """
        store = self.key_to_value
        if not override and key in store:
            raise KeyError(f'key={key} exists but override is not set.')
        store[key] = value
def __init__(self, key_to_value=NOTHING):
    """Set ``key_to_value`` from the argument, or from the factory if omitted."""
    # NOTHING is the "argument not supplied" sentinel used by this generated code.
    self.key_to_value = (
        __attr_factory_key_to_value() if key_to_value is NOTHING else key_to_value
    )
Method generated by attrs for class PipelineState.
def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
    """Return the value stored under ``key`` after checking its type.

    Raises:
        KeyError: ``key`` is not present in ``self.key_to_value``.
        TypeError: the stored value is not an instance of ``value_cls``.
    """
    store = self.key_to_value
    if key not in store:
        raise KeyError(f'key={key} not found.')

    value = store[key]
    if isinstance(value, value_cls):
        return value
    raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')
class PipelineStep(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Base class of a pipeline step, generic over config, input and output.

    Subclasses are expected to override ``run``; the base implementation
    raises ``NotImplementedError``.
    """

    # NOTE(review): the three accessors below index the result of
    # get_generic_classes(cls); presumably positions 0/1/2 are the classes
    # bound to _T_CONFIG/_T_INPUT/_T_OUTPUT. Confirm against vkit.utility.
    @classmethod
    def get_config_cls(cls) -> Type[_T_CONFIG]:
        """Return element 0 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[0]  # type: ignore

    @classmethod
    def get_input_cls(cls) -> Type[_T_INPUT]:
        """Return element 1 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[1]  # type: ignore

    @classmethod
    def get_output_cls(cls) -> Type[_T_OUTPUT]:
        """Return element 2 of ``get_generic_classes(cls)``."""
        return get_generic_classes(cls)[2]  # type: ignore

    _cached_name: str = ''

    @classmethod
    def get_name(cls):
        """Return the snake_case form of this class's name, cached per class."""
        # Read the cache from this class's own namespace only. A plain
        # ``cls._cached_name`` lookup would also find a value cached on a
        # parent step class and wrongly return the parent's name.
        cached_name = cls.__dict__.get('_cached_name')
        if not cached_name:
            cached_name = convert_camel_case_name_to_snake_case_name(cls.__name__)
            cls._cached_name = cached_name
        return cached_name

    def __init__(self, config: _T_CONFIG):
        self.config = config

    def run(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
        """Produce the step output; must be overridden."""
        raise NotImplementedError()
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
class PipelineStepFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Creates instances of one ``PipelineStep`` class from a config."""

    def __init__(self, pipeline_step_cls: Type[PipelineStep[_T_CONFIG, _T_INPUT, _T_OUTPUT]]):
        self.pipeline_step_cls = pipeline_step_cls

    @property
    def name(self):
        """Name of the wrapped step class, as given by its ``get_name()``."""
        step_cls = self.pipeline_step_cls
        return step_cls.get_name()

    def get_config_cls(self):
        """Return the config class reported by the wrapped step class."""
        step_cls = self.pipeline_step_cls
        return step_cls.get_config_cls()

    def create(
        self,
        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
    ):
        """Instantiate the step class.

        ``config`` is handed to ``dyn_structure`` together with the config
        class, with path and ``None`` support enabled; the result is passed
        to the step constructor.
        """
        config_cls = self.get_config_cls()
        structured_config = dyn_structure(
            config,
            config_cls,
            support_path_type=True,
            support_none_type=True,
        )
        return self.pipeline_step_cls(structured_config)
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
class PipelineStepCollectionFactory:
    """Registry of step factories keyed by ``'<namespace>.<step name>'``."""

    def __init__(self):
        self.name_to_step_factory: Dict[str, PipelineStepFactory] = {}

    def register_step_factories(
        self,
        namespace: str,
        step_factories: Sequence[PipelineStepFactory],
    ):
        """Register every factory under ``f'{namespace}.{factory.name}'``.

        Registration is all-or-nothing: nothing is added if any name clashes.

        Raises:
            KeyError: a name is already registered or repeats within the batch.
        """
        # Collect first and commit at the end so that a clash leaves the
        # registry untouched. An explicit raise is used instead of ``assert``
        # because asserts are removed when Python runs with -O.
        pending: Dict[str, PipelineStepFactory] = {}
        for step_factory in step_factories:
            name = f'{namespace}.{step_factory.name}'
            if name in self.name_to_step_factory or name in pending:
                raise KeyError(f'name={name} already registered.')
            pending[name] = step_factory
        self.name_to_step_factory.update(pending)

    def create(
        self,
        step_configs: Union[Sequence[Mapping[str, Any]], PathType],
    ):
        """Build the list of steps described by ``step_configs``.

        ``step_configs`` is either a sequence of mappings or a path that is
        read with ``read_json_file``. Each mapping must have a ``'name'`` key
        and may have a ``'config'`` key.

        Raises:
            KeyError: a step name has no registered factory.
        """
        if is_path_type(step_configs):
            step_configs = read_json_file(step_configs)  # type: ignore
        step_configs = cast(Sequence[Mapping[str, Any]], step_configs)

        steps: List[PipelineStep] = []
        for step_config in step_configs:
            name = step_config['name']
            if name not in self.name_to_step_factory:
                raise KeyError(f'name={name} not found.')
            step_factory = self.name_to_step_factory[name]
            steps.append(step_factory.create(step_config.get('config')))
        return steps
def register_step_factories(
    self,
    namespace: str,
    step_factories: Sequence[PipelineStepFactory],
):
    """Register every factory under ``f'{namespace}.{factory.name}'``.

    Registration is all-or-nothing: nothing is added if any name clashes.

    Raises:
        KeyError: a name is already registered or repeats within the batch.
    """
    # Collect first and commit at the end so that a clash leaves the
    # registry untouched. An explicit raise is used instead of ``assert``
    # because asserts are removed when Python runs with -O.
    pending: Dict[str, PipelineStepFactory] = {}
    for step_factory in step_factories:
        name = f'{namespace}.{step_factory.name}'
        if name in self.name_to_step_factory or name in pending:
            raise KeyError(f'name={name} already registered.')
        pending[name] = step_factory
    self.name_to_step_factory.update(pending)
def create(
    self,
    step_configs: Union[Sequence[Mapping[str, Any]], PathType],
):
    """Build the list of steps described by ``step_configs``.

    ``step_configs`` is either a sequence of mappings or a path that is
    read with ``read_json_file``. Each mapping must have a ``'name'`` key
    and may have a ``'config'`` key.

    Raises:
        KeyError: a step name has no registered factory.
    """
    if is_path_type(step_configs):
        step_configs = read_json_file(step_configs)  # type: ignore
    step_configs = cast(Sequence[Mapping[str, Any]], step_configs)

    registry = self.name_to_step_factory
    created_steps: List[PipelineStep] = []
    for entry in step_configs:
        name = entry['name']
        if name not in registry:
            raise KeyError(f'name={name} not found.')
        raw_config = entry.get('config')
        created_steps.append(registry[name].create(raw_config))
    return created_steps
class PipelinePostProcessor(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Base class of the final pipeline stage; override ``generate_output``."""

    def __init__(self, config: _T_CONFIG):
        self.config = config

    @classmethod
    def get_input_cls(cls) -> Type[_T_INPUT]:
        """Return element 1 of ``get_generic_classes(cls)``."""
        generic_classes = get_generic_classes(cls)
        return generic_classes[1]  # type: ignore

    def generate_output(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
        """Produce the pipeline output; the base implementation always raises."""
        raise NotImplementedError()
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
class PipelinePostProcessorFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Creates instances of one ``PipelinePostProcessor`` class from a config."""

    def __init__(
        self,
        pipeline_post_processor_cls: Type[PipelinePostProcessor[_T_CONFIG, _T_INPUT, _T_OUTPUT]],
    ):
        self.pipeline_post_processor_cls = pipeline_post_processor_cls

    def get_config_cls(self) -> Type[_T_CONFIG]:
        """Return element 0 of ``get_generic_classes`` for the wrapped class."""
        generic_classes = get_generic_classes(self.pipeline_post_processor_cls)
        return generic_classes[0]  # type: ignore

    def create(
        self,
        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
    ):
        """Instantiate the post processor.

        ``config`` is handed to ``dyn_structure`` together with the config
        class, with path and ``None`` support enabled; the result is passed
        to the post processor constructor.
        """
        config_cls = self.get_config_cls()
        structured_config = dyn_structure(
            config,
            config_cls,
            support_path_type=True,
            support_none_type=True,
        )
        return self.pipeline_post_processor_cls(structured_config)
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
class Pipeline(Generic[_T_OUTPUT]):
    """Runs steps in order over a shared state, then a post processor."""

    def __init__(
        self,
        steps: Sequence[PipelineStep],
        post_processor: PipelinePostProcessor[Any, Any, _T_OUTPUT],
    ):
        self.steps = steps
        self.post_processor = post_processor

    @classmethod
    def build_input(cls, state: PipelineState, input_cls: Any):
        """Instantiate ``input_cls`` from values held in ``state``.

        For each attrs field of ``input_cls``, the value is looked up in
        ``state`` under the snake_case name of the field's *type* (not the
        field name) and passed to the constructor under the field name.
        """
        assert attrs.has(input_cls)

        field_name_to_value = {}
        for field_name, field in attrs.fields_dict(input_cls).items():
            assert field.type
            assert attrs.has(field.type)
            state_key = convert_camel_case_name_to_snake_case_name(field.type.__name__)
            field_name_to_value[field_name] = state.get_value(state_key, field.type)

        return input_cls(**field_name_to_value)

    def run(
        self,
        rng: RandomGenerator,
        state: Optional[PipelineState] = None,
    ) -> _T_OUTPUT:
        """Execute all steps and return the post processor's output.

        A fresh ``PipelineState`` is created when ``state`` is ``None``.
        Values are stored without ``override``, so ``PipelineState.set_value``
        raises ``KeyError`` if a key is already present.
        """
        if state is None:
            state = PipelineState()

        # Record the generator state as it is before any step runs.
        rng_state_key = convert_camel_case_name_to_snake_case_name(
            PipelineRunRngStateOutput.__name__
        )
        state.set_value(rng_state_key, PipelineRunRngStateOutput(rng.bit_generator.state))

        for step in self.steps:
            # Each step reads its input from the state and writes its output back.
            step_input = self.build_input(state, step.get_input_cls())
            step_output = step.run(step_input, rng)

            output_cls = step.get_output_cls()
            assert isinstance(step_output, output_cls)
            assert attrs.has(output_cls)
            output_key = convert_camel_case_name_to_snake_case_name(output_cls.__name__)
            state.set_value(output_key, step_output)

        # Post processing.
        post_processor = self.post_processor
        post_processor_input = self.build_input(state, post_processor.get_input_cls())
        return post_processor.generate_output(post_processor_input, rng)
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
@classmethod
def build_input(cls, state: PipelineState, input_cls: Any):
    """Instantiate ``input_cls`` from values held in ``state``.

    For each attrs field of ``input_cls``, the value is looked up in
    ``state`` under the snake_case name of the field's *type* (not the
    field name) and passed to the constructor under the field name.
    """
    assert attrs.has(input_cls)

    field_name_to_value = {}
    for field_name, field in attrs.fields_dict(input_cls).items():
        assert field.type
        assert attrs.has(field.type)
        state_key = convert_camel_case_name_to_snake_case_name(field.type.__name__)
        field_name_to_value[field_name] = state.get_value(state_key, field.type)

    return input_cls(**field_name_to_value)
def run(
    self,
    rng: RandomGenerator,
    state: Optional[PipelineState] = None,
) -> _T_OUTPUT:
    """Execute all steps and return the post processor's output.

    A fresh ``PipelineState`` is created when ``state`` is ``None``.
    Values are stored without ``override``, so ``set_value`` is expected
    to reject keys that already exist in a supplied ``state``.
    """
    if state is None:
        state = PipelineState()

    # Record the generator state as it is before any step runs.
    rng_state_key = convert_camel_case_name_to_snake_case_name(
        PipelineRunRngStateOutput.__name__
    )
    state.set_value(rng_state_key, PipelineRunRngStateOutput(rng.bit_generator.state))

    for step in self.steps:
        # Each step reads its input from the state and writes its output back.
        step_input = self.build_input(state, step.get_input_cls())
        step_output = step.run(step_input, rng)

        output_cls = step.get_output_cls()
        assert isinstance(step_output, output_cls)
        assert attrs.has(output_cls)
        output_key = convert_camel_case_name_to_snake_case_name(output_cls.__name__)
        state.set_value(output_key, step_output)

    # Post processing.
    post_processor = self.post_processor
    post_processor_input = self.build_input(state, post_processor.get_input_cls())
    return post_processor.generate_output(post_processor_input, rng)