vkit.pipeline.interface

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import (
 15    cast,
 16    Generic,
 17    TypeVar,
 18    Type,
 19    Dict,
 20    Mapping,
 21    Any,
 22    Sequence,
 23    Optional,
 24    Union,
 25    List,
 26)
 27
 28import attrs
 29from numpy.random import Generator as RandomGenerator
 30
 31from vkit.utility import (
 32    is_path_type,
 33    read_json_file,
 34    dyn_structure,
 35    convert_camel_case_name_to_snake_case_name,
 36    get_generic_classes,
 37    PathType,
 38)
 39
 40_T_VALUE = TypeVar('_T_VALUE')
 41_T_CONFIG = TypeVar('_T_CONFIG')
 42_T_INPUT = TypeVar('_T_INPUT')
 43_T_OUTPUT = TypeVar('_T_OUTPUT')
 44
 45
 46@attrs.define
 47class PipelineState:
 48    key_to_value: Dict[str, Any] = attrs.field(factory=dict)
 49
 50    def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
 51        if key not in self.key_to_value:
 52            raise KeyError(f'key={key} not found.')
 53        value = self.key_to_value[key]
 54        if not isinstance(value, value_cls):
 55            raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')
 56        return value
 57
 58    def set_value(self, key: str, value: Any, override: bool = False):
 59        if key in self.key_to_value and not override:
 60            raise KeyError(f'key={key} exists but override is not set.')
 61        self.key_to_value[key] = value
 62
 63
 64class PipelineStep(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
 65
 66    @classmethod
 67    def get_config_cls(cls) -> Type[_T_CONFIG]:
 68        return get_generic_classes(cls)[0]  # type: ignore
 69
 70    @classmethod
 71    def get_input_cls(cls) -> Type[_T_INPUT]:
 72        return get_generic_classes(cls)[1]  # type: ignore
 73
 74    @classmethod
 75    def get_output_cls(cls) -> Type[_T_OUTPUT]:
 76        return get_generic_classes(cls)[2]  # type: ignore
 77
 78    _cached_name: str = ''
 79
 80    @classmethod
 81    def get_name(cls):
 82        if not cls._cached_name:
 83            cls._cached_name = convert_camel_case_name_to_snake_case_name(cls.__name__)
 84        return cls._cached_name
 85
 86    def __init__(self, config: _T_CONFIG):
 87        self.config = config
 88
 89    def run(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
 90        raise NotImplementedError()
 91
 92
 93class PipelineStepFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
 94
 95    def __init__(self, pipeline_step_cls: Type[PipelineStep[_T_CONFIG, _T_INPUT, _T_OUTPUT]]):
 96        self.pipeline_step_cls = pipeline_step_cls
 97
 98    @property
 99    def name(self):
100        return self.pipeline_step_cls.get_name()
101
102    def get_config_cls(self):
103        return self.pipeline_step_cls.get_config_cls()
104
105    def create(
106        self,
107        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
108    ):
109        config = dyn_structure(
110            config,
111            self.get_config_cls(),
112            support_path_type=True,
113            support_none_type=True,
114        )
115        return self.pipeline_step_cls(config)
116
117
class PipelineStepCollectionFactory:
    """Registry of step factories addressable as ``'<namespace>.<step name>'``.

    ``create`` turns a sequence of ``{'name': ..., 'config': ...}`` mappings into
    instantiated ``PipelineStep`` objects. The sequence can also be a path, which is
    loaded with ``read_json_file``.
    """

    def __init__(self):
        self.name_to_step_factory: Dict[str, PipelineStepFactory] = {}

    def register_step_factories(
        self,
        namespace: str,
        step_factories: Sequence[PipelineStepFactory],
    ):
        """Register each factory as ``f'{namespace}.{factory.name}'``.

        Raises:
            KeyError: a resulting name is already registered, or occurs twice in
                ``step_factories``. Nothing is registered in that case.
        """
        # Validate the whole batch before mutating, so a conflict cannot leave the
        # registry half-updated. An explicit raise (instead of ``assert``) keeps the
        # check active under ``python -O``, where the assert would be stripped and a
        # duplicate would silently overwrite the existing factory.
        pending: Dict[str, PipelineStepFactory] = {}
        for step_factory in step_factories:
            name = f'{namespace}.{step_factory.name}'
            if name in self.name_to_step_factory or name in pending:
                raise KeyError(f'name={name} already registered.')
            pending[name] = step_factory
        self.name_to_step_factory.update(pending)

    def create(
        self,
        step_configs: Union[Sequence[Mapping[str, Any]], PathType],
    ):
        """Build one step per entry of ``step_configs``, in order.

        Each entry must have a ``'name'`` key naming a registered factory. Its optional
        ``'config'`` value is passed to that factory's ``create``.

        Raises:
            KeyError: an entry lacks ``'name'`` or names an unregistered factory.
        """
        if is_path_type(step_configs):
            step_configs = read_json_file(step_configs)  # type: ignore
        step_configs = cast(Sequence[Mapping[str, Any]], step_configs)

        steps: List[PipelineStep] = []
        for step_config in step_configs:
            name = step_config['name']
            if name not in self.name_to_step_factory:
                raise KeyError(f'name={name} not found.')
            step_factory = self.name_to_step_factory[name]
            steps.append(step_factory.create(step_config.get('config')))
        return steps
149
150
class PipelinePostProcessor(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Base class for the final stage of a pipeline.

    Subclasses are parameterized as ``PipelinePostProcessor[ConfigCls, InputCls, OutputCls]``
    and implement ``generate_output``.
    """

    def __init__(self, config: _T_CONFIG):
        self.config = config

    @classmethod
    def get_input_cls(cls) -> Type[_T_INPUT]:
        """Return the input class (second generic argument)."""
        generic_classes = get_generic_classes(cls)
        return generic_classes[1]  # type: ignore

    def generate_output(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
        """Turn the collected ``input`` into the pipeline result; subclasses must override."""
        raise NotImplementedError()
162
163
class PipelinePostProcessorFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
    """Creates instances of one ``PipelinePostProcessor`` subclass from loosely-typed configs."""

    def __init__(
        self,
        pipeline_post_processor_cls: Type[PipelinePostProcessor[_T_CONFIG, _T_INPUT, _T_OUTPUT]],
    ):
        self.pipeline_post_processor_cls = pipeline_post_processor_cls

    def get_config_cls(self) -> Type[_T_CONFIG]:
        """Return the config class (first generic argument) of the wrapped post-processor."""
        generic_classes = get_generic_classes(self.pipeline_post_processor_cls)
        return generic_classes[0]  # type: ignore

    def create(
        self,
        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
    ):
        """Instantiate the wrapped post-processor.

        ``config`` may be a mapping, a path, an instance of the config class, or ``None``.
        ``dyn_structure`` converts it into the config class before construction.
        """
        structured_config = dyn_structure(
            config,
            self.get_config_cls(),
            support_path_type=True,
            support_none_type=True,
        )
        post_processor_cls = self.pipeline_post_processor_cls
        return post_processor_cls(structured_config)
186
187
@attrs.define
class PipelineRunRngStateOutput:
    """Snapshot of the random generator's state, recorded by ``Pipeline.run`` before any step."""

    # The mapping returned by ``rng.bit_generator.state`` at the start of the run.
    rng_state: Mapping[str, Any] = attrs.field()
191
192
class Pipeline(Generic[_T_OUTPUT]):
    """Runs steps over a shared ``PipelineState``, then a post-processor.

    Data is routed between steps by type. Each step output is stored under the snake_case
    name of its output class. Each input field is fetched from the state under the
    snake_case name of the field's declared class.
    """

    def __init__(
        self,
        steps: Sequence[PipelineStep],
        post_processor: PipelinePostProcessor[Any, Any, _T_OUTPUT],
    ):
        self.steps = steps
        self.post_processor = post_processor

    @classmethod
    def build_input(cls, state: PipelineState, input_cls: Any):
        """Instantiate attrs class ``input_cls`` with every field taken from ``state``.

        Every field must be annotated with an attrs class ``T``. Its value is fetched via
        ``state.get_value(snake_case(T.__name__), T)``, which raises ``KeyError`` when the
        key is missing and ``TypeError`` when the stored value has the wrong type.
        """
        assert attrs.has(input_cls)

        def fetch(field_type: Any) -> Any:
            assert field_type
            assert attrs.has(field_type)
            state_key = convert_camel_case_name_to_snake_case_name(field_type.__name__)
            return state.get_value(state_key, field_type)

        field_items = attrs.fields_dict(input_cls).items()
        return input_cls(**{field_name: fetch(field.type) for field_name, field in field_items})

    def run(
        self,
        rng: RandomGenerator,
        state: Optional[PipelineState] = None,
    ) -> _T_OUTPUT:
        """Run every step in order, then return the post-processor's output.

        ``state`` defaults to a new ``PipelineState``. Values are written with
        ``set_value`` without ``override``, so ``KeyError`` is raised if the RNG-state key
        or a step's output key is already present in ``state``.
        """
        state = PipelineState() if state is None else state

        # Record the generator state before any step draws from it.
        rng_state_key = convert_camel_case_name_to_snake_case_name(PipelineRunRngStateOutput.__name__)
        state.set_value(rng_state_key, PipelineRunRngStateOutput(rng.bit_generator.state))

        for step in self.steps:
            step_output = step.run(self.build_input(state, step.get_input_cls()), rng)

            # Outputs are keyed by class name, so they must be attrs instances of the
            # declared output class.
            output_cls = step.get_output_cls()
            assert isinstance(step_output, output_cls)
            assert attrs.has(output_cls)
            state.set_value(
                convert_camel_case_name_to_snake_case_name(output_cls.__name__),
                step_output,
            )

        post_input = self.build_input(state, self.post_processor.get_input_cls())
        return self.post_processor.generate_output(post_input, rng)
class PipelineState:
48class PipelineState:
49    key_to_value: Dict[str, Any] = attrs.field(factory=dict)
50
51    def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
52        if key not in self.key_to_value:
53            raise KeyError(f'key={key} not found.')
54        value = self.key_to_value[key]
55        if not isinstance(value, value_cls):
56            raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')
57        return value
58
59    def set_value(self, key: str, value: Any, override: bool = False):
60        if key in self.key_to_value and not override:
61            raise KeyError(f'key={key} exists but override is not set.')
62        self.key_to_value[key] = value
PipelineState(key_to_value: Dict[str, Any] = NOTHING)
2def __init__(self, key_to_value=NOTHING):
3    if key_to_value is not NOTHING:
4        self.key_to_value = key_to_value
5    else:
6        self.key_to_value = __attr_factory_key_to_value()

Method generated by attrs for class PipelineState.

def get_value(self, key: str, value_cls: Type[~_T_VALUE]) -> ~_T_VALUE:
51    def get_value(self, key: str, value_cls: Type[_T_VALUE]) -> _T_VALUE:
52        if key not in self.key_to_value:
53            raise KeyError(f'key={key} not found.')
54        value = self.key_to_value[key]
55        if not isinstance(value, value_cls):
56            raise TypeError(f'key={key}, value type={type(value)} is not instance of {value_cls}')
57        return value
def set_value(self, key: str, value: Any, override: bool = False):
59    def set_value(self, key: str, value: Any, override: bool = False):
60        if key in self.key_to_value and not override:
61            raise KeyError(f'key={key} exists but override is not set.')
62        self.key_to_value[key] = value
class PipelineStep(typing.Generic[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]):
65class PipelineStep(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
66
67    @classmethod
68    def get_config_cls(cls) -> Type[_T_CONFIG]:
69        return get_generic_classes(cls)[0]  # type: ignore
70
71    @classmethod
72    def get_input_cls(cls) -> Type[_T_INPUT]:
73        return get_generic_classes(cls)[1]  # type: ignore
74
75    @classmethod
76    def get_output_cls(cls) -> Type[_T_OUTPUT]:
77        return get_generic_classes(cls)[2]  # type: ignore
78
79    _cached_name: str = ''
80
81    @classmethod
82    def get_name(cls):
83        if not cls._cached_name:
84            cls._cached_name = convert_camel_case_name_to_snake_case_name(cls.__name__)
85        return cls._cached_name
86
87    def __init__(self, config: _T_CONFIG):
88        self.config = config
89
90    def run(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
91        raise NotImplementedError()

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

PipelineStep(config: ~_T_CONFIG)
87    def __init__(self, config: _T_CONFIG):
88        self.config = config
@classmethod
def get_config_cls(cls) -> Type[~_T_CONFIG]:
67    @classmethod
68    def get_config_cls(cls) -> Type[_T_CONFIG]:
69        return get_generic_classes(cls)[0]  # type: ignore
@classmethod
def get_input_cls(cls) -> Type[~_T_INPUT]:
71    @classmethod
72    def get_input_cls(cls) -> Type[_T_INPUT]:
73        return get_generic_classes(cls)[1]  # type: ignore
@classmethod
def get_output_cls(cls) -> Type[~_T_OUTPUT]:
75    @classmethod
76    def get_output_cls(cls) -> Type[_T_OUTPUT]:
77        return get_generic_classes(cls)[2]  # type: ignore
@classmethod
def get_name(cls):
81    @classmethod
82    def get_name(cls):
83        if not cls._cached_name:
84            cls._cached_name = convert_camel_case_name_to_snake_case_name(cls.__name__)
85        return cls._cached_name
def run( self, input: ~_T_INPUT, rng: numpy.random._generator.Generator) -> ~_T_OUTPUT:
90    def run(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
91        raise NotImplementedError()
class PipelineStepFactory(typing.Generic[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]):
 94class PipelineStepFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
 95
 96    def __init__(self, pipeline_step_cls: Type[PipelineStep[_T_CONFIG, _T_INPUT, _T_OUTPUT]]):
 97        self.pipeline_step_cls = pipeline_step_cls
 98
 99    @property
100    def name(self):
101        return self.pipeline_step_cls.get_name()
102
103    def get_config_cls(self):
104        return self.pipeline_step_cls.get_config_cls()
105
106    def create(
107        self,
108        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
109    ):
110        config = dyn_structure(
111            config,
112            self.get_config_cls(),
113            support_path_type=True,
114            support_none_type=True,
115        )
116        return self.pipeline_step_cls(config)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

PipelineStepFactory( pipeline_step_cls: Type[vkit.pipeline.interface.PipelineStep[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]])
96    def __init__(self, pipeline_step_cls: Type[PipelineStep[_T_CONFIG, _T_INPUT, _T_OUTPUT]]):
97        self.pipeline_step_cls = pipeline_step_cls
def get_config_cls(self):
103    def get_config_cls(self):
104        return self.pipeline_step_cls.get_config_cls()
def create( self, config: Union[Mapping[str, Any], str, os.PathLike, ~_T_CONFIG, NoneType] = None):
106    def create(
107        self,
108        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
109    ):
110        config = dyn_structure(
111            config,
112            self.get_config_cls(),
113            support_path_type=True,
114            support_none_type=True,
115        )
116        return self.pipeline_step_cls(config)
class PipelineStepCollectionFactory:
119class PipelineStepCollectionFactory:
120
121    def __init__(self):
122        self.name_to_step_factory: Dict[str, PipelineStepFactory] = {}
123
124    def register_step_factories(
125        self,
126        namespace: str,
127        step_factories: Sequence[PipelineStepFactory],
128    ):
129        for step_factory in step_factories:
130            name = f'{namespace}.{step_factory.name}'
131            assert name not in self.name_to_step_factory
132            self.name_to_step_factory[name] = step_factory
133
134    def create(
135        self,
136        step_configs: Union[Sequence[Mapping[str, Any]], PathType],
137    ):
138        if is_path_type(step_configs):
139            step_configs = read_json_file(step_configs)  # type: ignore
140        step_configs = cast(Sequence[Mapping[str, Any]], step_configs)
141
142        steps: List[PipelineStep] = []
143        for step_config in step_configs:
144            name = step_config['name']
145            if name not in self.name_to_step_factory:
146                raise KeyError(f'name={name} not found.')
147            step_factory = self.name_to_step_factory[name]
148            steps.append(step_factory.create(step_config.get('config')))
149        return steps
PipelineStepCollectionFactory()
121    def __init__(self):
122        self.name_to_step_factory: Dict[str, PipelineStepFactory] = {}
def register_step_factories( self, namespace: str, step_factories: Sequence[vkit.pipeline.interface.PipelineStepFactory]):
124    def register_step_factories(
125        self,
126        namespace: str,
127        step_factories: Sequence[PipelineStepFactory],
128    ):
129        for step_factory in step_factories:
130            name = f'{namespace}.{step_factory.name}'
131            assert name not in self.name_to_step_factory
132            self.name_to_step_factory[name] = step_factory
def create( self, step_configs: Union[Sequence[Mapping[str, Any]], str, os.PathLike]):
134    def create(
135        self,
136        step_configs: Union[Sequence[Mapping[str, Any]], PathType],
137    ):
138        if is_path_type(step_configs):
139            step_configs = read_json_file(step_configs)  # type: ignore
140        step_configs = cast(Sequence[Mapping[str, Any]], step_configs)
141
142        steps: List[PipelineStep] = []
143        for step_config in step_configs:
144            name = step_config['name']
145            if name not in self.name_to_step_factory:
146                raise KeyError(f'name={name} not found.')
147            step_factory = self.name_to_step_factory[name]
148            steps.append(step_factory.create(step_config.get('config')))
149        return steps
class PipelinePostProcessor(typing.Generic[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]):
152class PipelinePostProcessor(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
153
154    def __init__(self, config: _T_CONFIG):
155        self.config = config
156
157    @classmethod
158    def get_input_cls(cls) -> Type[_T_INPUT]:
159        return get_generic_classes(cls)[1]  # type: ignore
160
161    def generate_output(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
162        raise NotImplementedError()

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

PipelinePostProcessor(config: ~_T_CONFIG)
154    def __init__(self, config: _T_CONFIG):
155        self.config = config
@classmethod
def get_input_cls(cls) -> Type[~_T_INPUT]:
157    @classmethod
158    def get_input_cls(cls) -> Type[_T_INPUT]:
159        return get_generic_classes(cls)[1]  # type: ignore
def generate_output( self, input: ~_T_INPUT, rng: numpy.random._generator.Generator) -> ~_T_OUTPUT:
161    def generate_output(self, input: _T_INPUT, rng: RandomGenerator) -> _T_OUTPUT:
162        raise NotImplementedError()
class PipelinePostProcessorFactory(typing.Generic[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]):
165class PipelinePostProcessorFactory(Generic[_T_CONFIG, _T_INPUT, _T_OUTPUT]):
166
167    def __init__(
168        self,
169        pipeline_post_processor_cls: Type[PipelinePostProcessor[_T_CONFIG, _T_INPUT, _T_OUTPUT]],
170    ):
171        self.pipeline_post_processor_cls = pipeline_post_processor_cls
172
173    def get_config_cls(self) -> Type[_T_CONFIG]:
174        return get_generic_classes(self.pipeline_post_processor_cls)[0]  # type: ignore
175
176    def create(
177        self,
178        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
179    ):
180        config = dyn_structure(
181            config,
182            self.get_config_cls(),
183            support_path_type=True,
184            support_none_type=True,
185        )
186        return self.pipeline_post_processor_cls(config)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

PipelinePostProcessorFactory( pipeline_post_processor_cls: Type[vkit.pipeline.interface.PipelinePostProcessor[~_T_CONFIG, ~_T_INPUT, ~_T_OUTPUT]])
167    def __init__(
168        self,
169        pipeline_post_processor_cls: Type[PipelinePostProcessor[_T_CONFIG, _T_INPUT, _T_OUTPUT]],
170    ):
171        self.pipeline_post_processor_cls = pipeline_post_processor_cls
def get_config_cls(self) -> Type[~_T_CONFIG]:
173    def get_config_cls(self) -> Type[_T_CONFIG]:
174        return get_generic_classes(self.pipeline_post_processor_cls)[0]  # type: ignore
def create( self, config: Union[Mapping[str, Any], str, os.PathLike, ~_T_CONFIG, NoneType] = None):
176    def create(
177        self,
178        config: Optional[Union[Mapping[str, Any], PathType, _T_CONFIG]] = None,
179    ):
180        config = dyn_structure(
181            config,
182            self.get_config_cls(),
183            support_path_type=True,
184            support_none_type=True,
185        )
186        return self.pipeline_post_processor_cls(config)
class PipelineRunRngStateOutput:
190class PipelineRunRngStateOutput:
191    rng_state: Mapping[str, Any]
PipelineRunRngStateOutput(rng_state: Mapping[str, Any])
2def __init__(self, rng_state):
3    self.rng_state = rng_state

Method generated by attrs for class PipelineRunRngStateOutput.

class Pipeline(typing.Generic[~_T_OUTPUT]):
194class Pipeline(Generic[_T_OUTPUT]):
195
196    def __init__(
197        self,
198        steps: Sequence[PipelineStep],
199        post_processor: PipelinePostProcessor[Any, Any, _T_OUTPUT],
200    ):
201        self.steps = steps
202        self.post_processor = post_processor
203
204    @classmethod
205    def build_input(cls, state: PipelineState, input_cls: Any):
206        assert attrs.has(input_cls)
207
208        input_kwargs = {}
209        for key, key_field in attrs.fields_dict(input_cls).items():
210            assert key_field.type
211            assert attrs.has(key_field.type)
212            value = state.get_value(
213                convert_camel_case_name_to_snake_case_name(key_field.type.__name__),
214                key_field.type,
215            )
216            input_kwargs[key] = value
217
218        return input_cls(**input_kwargs)
219
220    def run(
221        self,
222        rng: RandomGenerator,
223        state: Optional[PipelineState] = None,
224    ) -> _T_OUTPUT:
225        if state is None:
226            state = PipelineState()
227
228        # Save the rng state.
229        state.set_value(
230            convert_camel_case_name_to_snake_case_name(PipelineRunRngStateOutput.__name__),
231            PipelineRunRngStateOutput(rng.bit_generator.state),
232        )
233
234        # Run steps.
235        for step in self.steps:
236            # Build input.
237            step_input = self.build_input(state, step.get_input_cls())
238
239            # Generate output.
240            step_output = step.run(step_input, rng)
241
242            # Update state.
243            step_output_cls = step.get_output_cls()
244            assert isinstance(step_output, step_output_cls)
245            assert attrs.has(step_output_cls)
246            state.set_value(
247                convert_camel_case_name_to_snake_case_name(step_output_cls.__name__),
248                step_output,
249            )
250
251        # Post processing.
252        return self.post_processor.generate_output(
253            self.build_input(state, self.post_processor.get_input_cls()),
254            rng,
255        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

Pipeline( steps: Sequence[vkit.pipeline.interface.PipelineStep], post_processor: vkit.pipeline.interface.PipelinePostProcessor[typing.Any, typing.Any, ~_T_OUTPUT])
196    def __init__(
197        self,
198        steps: Sequence[PipelineStep],
199        post_processor: PipelinePostProcessor[Any, Any, _T_OUTPUT],
200    ):
201        self.steps = steps
202        self.post_processor = post_processor
@classmethod
def build_input(cls, state: vkit.pipeline.interface.PipelineState, input_cls: Any):
204    @classmethod
205    def build_input(cls, state: PipelineState, input_cls: Any):
206        assert attrs.has(input_cls)
207
208        input_kwargs = {}
209        for key, key_field in attrs.fields_dict(input_cls).items():
210            assert key_field.type
211            assert attrs.has(key_field.type)
212            value = state.get_value(
213                convert_camel_case_name_to_snake_case_name(key_field.type.__name__),
214                key_field.type,
215            )
216            input_kwargs[key] = value
217
218        return input_cls(**input_kwargs)
def run( self, rng: numpy.random._generator.Generator, state: Union[vkit.pipeline.interface.PipelineState, NoneType] = None) -> ~_T_OUTPUT:
220    def run(
221        self,
222        rng: RandomGenerator,
223        state: Optional[PipelineState] = None,
224    ) -> _T_OUTPUT:
225        if state is None:
226            state = PipelineState()
227
228        # Save the rng state.
229        state.set_value(
230            convert_camel_case_name_to_snake_case_name(PipelineRunRngStateOutput.__name__),
231            PipelineRunRngStateOutput(rng.bit_generator.state),
232        )
233
234        # Run steps.
235        for step in self.steps:
236            # Build input.
237            step_input = self.build_input(state, step.get_input_cls())
238
239            # Generate output.
240            step_output = step.run(step_input, rng)
241
242            # Update state.
243            step_output_cls = step.get_output_cls()
244            assert isinstance(step_output, step_output_cls)
245            assert attrs.has(step_output_cls)
246            state.set_value(
247                convert_camel_case_name_to_snake_case_name(step_output_cls.__name__),
248                step_output,
249            )
250
251        # Post processing.
252        return self.post_processor.generate_output(
253            self.build_input(state, self.post_processor.get_input_cls()),
254            rng,
255        )