vkit.utility.pool

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Optional, Protocol, TypeVar, Generic, Type, Any
 15import multiprocessing
 16import threading
 17import logging
 18import time
 19import atexit
 20
 21import attrs
 22from numpy.random import SeedSequence
 23import psutil
 24
 25logger = logging.getLogger(__name__)
 26
 27_T_CONFIG = TypeVar('_T_CONFIG', contravariant=True)
 28_T_OUTPUT = TypeVar('_T_OUTPUT', covariant=True)
 29
 30
class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]):
    """Structural interface that pool worker classes must satisfy.

    A concrete worker is constructed once per child process by
    ``pool_worker_initializer`` (which passes exactly these keyword
    arguments), and its ``run`` method is then invoked by
    ``pool_worker_runner`` for every scheduled task.
    """

    def __init__(
        self,
        process_idx: int,
        seed_sequence: SeedSequence,
        logger: logging.Logger,
        config: _T_CONFIG,
    ) -> None:
        """Initialize the worker inside its child process.

        :param process_idx: 0-based index of the worker process.
        :param seed_sequence: per-process child of the pool's root seed.
        :param logger: process-local logger configured by the initializer.
        :param config: the ``PoolConfig.pool_worker_config`` payload.
        """
        ...

    def run(self) -> _T_OUTPUT:
        """Produce one output item; called once per scheduled trigger."""
        ...
 44
 45
@attrs.define
class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
    """Configuration for :class:`Pool`."""

    # Target number of ready outputs to keep buffered (the inventory_target
    # handed to PoolInventoryState).
    inventory: int
    # Number of worker processes in the multiprocessing pool.
    num_processes: int
    # Worker class instantiated once per child process by
    # pool_worker_initializer.
    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
    # Config object forwarded to each worker's __init__ as `config=`.
    pool_worker_config: _T_CONFIG
    # Minimum scheduling batch is round(this factor * num_processes).
    schedule_size_min_factor: float = 1.0
    # Root seed for the per-process SeedSequence spawn.
    rng_seed: int = 133700
    # Logging level applied to each worker-process logger.
    logging_level: int = logging.INFO
    # The literal substring 'PROCESS_IDX' is replaced with the worker's
    # process index when the formatter is built.
    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
    # If True, also enable multiprocessing.log_to_stderr() in each worker.
    logging_to_stderr: bool = False
    # Timeout in seconds for fetching one output in Pool.run();
    # None blocks indefinitely.
    timeout: Optional[int] = None
 58
 59
class PoolWorkerState:
    """Process-local storage populated by ``pool_worker_initializer``.

    Class attributes are used (not instances) because each worker process
    gets its own copy of the module; ``pool_worker_runner`` reads them back.
    """

    # The worker instance for the current child process; None until
    # pool_worker_initializer has run.
    pool_worker: Optional[PoolWorkerProtocol] = None
    # The process-local logger configured during initialization.
    logger: logging.Logger
 63
 64
 65def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
 66    with process_counter.get_lock():
 67        process_idx: int = process_counter.value  # type: ignore
 68        process_counter.value += 1  # type: ignore
 69
 70    # Overriding logger.
 71    logger = multiprocessing.get_logger()
 72    logger_stream_handler = logging.StreamHandler()
 73    logger_formatter = logging.Formatter(
 74        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
 75    )
 76    logger_stream_handler.setFormatter(logger_formatter)
 77    logger.addHandler(logger_stream_handler)
 78    logger.setLevel(pool_config.logging_level)
 79
 80    if pool_config.logging_to_stderr:
 81        multiprocessing.log_to_stderr()
 82
 83    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')
 84
 85    # Generate seed_sequence.
 86    seed_sequences = SeedSequence(pool_config.rng_seed)
 87    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]
 88
 89    # Initialize pool worker.
 90    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
 91    pool_worker = pool_config.pool_worker_class(
 92        process_idx=process_idx,
 93        seed_sequence=seed_sequence,
 94        config=pool_config.pool_worker_config,
 95        logger=logger,
 96    )
 97    PoolWorkerState.pool_worker = pool_worker
 98    PoolWorkerState.logger = logger
 99    logger.debug('Initialized.')
100
101
def pool_worker_runner(_):
    """Run one task on the process-local worker and return its output.

    The single (ignored) argument is the ``None`` trigger yielded by
    ``trigger_generator`` via ``imap_unordered``.
    """
    worker_logger = PoolWorkerState.logger
    worker_logger.debug('Triggered.')

    worker = PoolWorkerState.pool_worker
    assert worker is not None
    output = worker.run()
    worker_logger.debug('Result generated.')

    return output
112
113
@attrs.define
class PoolInventoryState:
    """Bookkeeping shared between Pool.run and trigger_generator."""

    # Condition guarding all mutable fields below.
    cond: threading.Condition
    # Outputs currently buffered in the imap iterator.
    inventory: int
    # Triggers emitted but not yet reflected in inventory.
    num_scheduled: int
    # Desired number of buffered outputs.
    inventory_target: int
    # Set to True by Pool.cleanup to stop the trigger generator.
    abort: bool

    def predicate(self):
        """True when aborting, or when buffered + in-flight work falls
        short of the target and more should be scheduled."""
        if self.abort:
            return True
        return self.inventory + self.num_scheduled < self.inventory_target

    def __repr__(self):
        return (
            f'PoolInventoryState(inventory={self.inventory}, '
            f'num_scheduled={self.num_scheduled}, '
            f'inventory_target={self.inventory_target}, '
            f'should_schedule={self.predicate()})'
        )
134
135
def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
    """Yield one ``None`` per task to schedule, driven by inventory demand.

    Fed to ``Pool.imap_unordered`` as its input iterable: each yielded item
    triggers one ``pool_worker_runner`` call. Blocks on ``state.cond`` until
    more output is needed (or abort is set), then emits a batch of at least
    ``schedule_size_min`` triggers.

    :param schedule_size_min: lower bound on each scheduling batch.
    :param state: shared inventory bookkeeping; ``num_scheduled`` is bumped
        under the lock before the triggers are emitted.
    """
    while True:
        with state.cond:
            state.cond.wait_for(state.predicate)
            if state.abort:
                return

            schedule_size = max(
                schedule_size_min,
                state.inventory_target - state.inventory - state.num_scheduled,
            )
            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
            state.num_scheduled += schedule_size

        # NOTE: Yield OUTSIDE the condition lock. A generator suspends at
        # `yield`, so yielding inside the `with` block (as before) kept the
        # lock held for as long as the pool's task-feeding thread left the
        # generator suspended, blocking Pool.run() from updating inventory.
        # num_scheduled was already incremented under the lock, so the
        # predicate stays consistent for other threads.
        for _ in range(schedule_size):
            yield None
151
152
class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
    """Process pool that keeps a target inventory of precomputed outputs.

    Workers are initialized once per child process by
    ``pool_worker_initializer``; ``trigger_generator`` schedules work on
    demand, and :meth:`run` pops one output while updating the shared
    inventory bookkeeping.
    """

    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
        self.config = config

        # Shared counter used by the initializer to hand each worker
        # process a unique 0-based index.
        process_counter = multiprocessing.Value('i')
        with process_counter.get_lock():
            process_counter.value = 0  # type: ignore

        self.mp_pool = multiprocessing.Pool(
            processes=self.config.num_processes,
            initializer=pool_worker_initializer,
            initargs=(self.config, process_counter),
        )

        self.state = PoolInventoryState(
            cond=threading.Condition(threading.Lock()),
            inventory=0,
            num_scheduled=0,
            inventory_target=self.config.inventory,
            abort=False,
        )

        # Lazy unordered map: tasks are only created when trigger_generator
        # yields, so at most ~inventory_target results are in flight.
        self.mp_pool_iter = self.mp_pool.imap_unordered(
            pool_worker_runner,
            trigger_generator(
                schedule_size_min=round(
                    self.config.schedule_size_min_factor * self.config.num_processes
                ),
                state=self.state,
            ),
        )

        self.cleanup_flag = False
        atexit.register(self.cleanup)

    def cleanup(self):
        """Abort scheduling, tear down the pool, and force-kill stragglers.

        Idempotent (guarded by ``cleanup_flag``); registered with atexit
        and unregistered once it has run.
        """
        if not self.cleanup_flag:
            self.cleanup_flag = True

            # Wake trigger_generator so it observes abort and exits.
            with self.state.cond:
                self.state.abort = True
                self.state.cond.notify()

            self.mp_pool.close()
            self.mp_pool.terminate()
            # Give terminated workers a moment to exit before checking pids.
            time.sleep(1)

            # NOTE(review): relies on the private Pool._pool worker list —
            # a CPython implementation detail; verify on Python upgrades.
            for proc in self.mp_pool._pool:  # type: ignore
                if not psutil.pid_exists(proc.pid):
                    continue
                logger.warning(f'worker pid={proc.pid} still exists, killing...')
                proc = psutil.Process(proc.pid)
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except psutil.TimeoutExpired:
                    proc.kill()

            # Drop references so the pool machinery can be garbage-collected.
            self.mp_pool_iter = None
            self.state = None
            self.mp_pool = None
            self.config = None

            atexit.unregister(self.cleanup)

    def run(self):
        """Block until one worker output is available and return it.

        Raises ``multiprocessing.TimeoutError`` if ``config.timeout``
        elapses before an output arrives.
        """
        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)

        # Update inventory.
        with self.state.cond:
            # NOTE(review): reads IMapUnorderedIterator internals (_cond,
            # _items) to measure the buffered backlog — private CPython
            # details; verify on Python upgrades.
            with self.mp_pool_iter._cond:  # type: ignore
                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')

            # NOTE: We have just taken one output, hence need to add one back
            # when computing how many scheduled tasks completed.
            num_scheduled_delta = new_inventory - self.state.inventory + 1
            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
            assert num_scheduled_delta >= 0

            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')

            self.state.inventory = new_inventory
            self.state.num_scheduled = new_num_scheduled

            # Wake up trigger_generator.
            self.state.cond.notify()

        return output
class PoolWorkerProtocol(typing.Protocol[-_T_CONFIG, +_T_OUTPUT]):
32class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]):
33
34    def __init__(
35        self,
36        process_idx: int,
37        seed_sequence: SeedSequence,
38        logger: logging.Logger,
39        config: _T_CONFIG,
40    ) -> None:
41        ...
42
43    def run(self) -> _T_OUTPUT:
44        ...

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
PoolWorkerProtocol(*args, **kwargs)
981def _no_init(self, *args, **kwargs):
982    if type(self)._is_protocol:
983        raise TypeError('Protocols cannot be instantiated')
def run(self) -> +_T_OUTPUT:
43    def run(self) -> _T_OUTPUT:
44        ...
class PoolConfig(typing.Generic[-_T_CONFIG, +_T_OUTPUT]):
48class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
49    inventory: int
50    num_processes: int
51    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
52    pool_worker_config: _T_CONFIG
53    schedule_size_min_factor: float = 1.0
54    rng_seed: int = 133700
55    logging_level: int = logging.INFO
56    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
57    logging_to_stderr: bool = False
58    timeout: Optional[int] = None

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
        # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

PoolConfig( inventory: int, num_processes: int, pool_worker_class: Type[vkit.utility.pool.PoolWorkerProtocol[-_T_CONFIG, +_T_OUTPUT]], pool_worker_config: -_T_CONFIG, schedule_size_min_factor: float = 1.0, rng_seed: int = 133700, logging_level: int = 20, logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s', logging_to_stderr: bool = False, timeout: Union[int, NoneType] = None)
 2def __init__(self, inventory, num_processes, pool_worker_class, pool_worker_config, schedule_size_min_factor=attr_dict['schedule_size_min_factor'].default, rng_seed=attr_dict['rng_seed'].default, logging_level=attr_dict['logging_level'].default, logging_format=attr_dict['logging_format'].default, logging_to_stderr=attr_dict['logging_to_stderr'].default, timeout=attr_dict['timeout'].default):
 3    self.inventory = inventory
 4    self.num_processes = num_processes
 5    self.pool_worker_class = pool_worker_class
 6    self.pool_worker_config = pool_worker_config
 7    self.schedule_size_min_factor = schedule_size_min_factor
 8    self.rng_seed = rng_seed
 9    self.logging_level = logging_level
10    self.logging_format = logging_format
11    self.logging_to_stderr = logging_to_stderr
12    self.timeout = timeout

Method generated by attrs for class PoolConfig.

class PoolWorkerState:
61class PoolWorkerState:
62    pool_worker: Optional[PoolWorkerProtocol] = None
63    logger: logging.Logger
def pool_worker_initializer(pool_config: vkit.utility.pool.PoolConfig, process_counter: Any):
 66def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
 67    with process_counter.get_lock():
 68        process_idx: int = process_counter.value  # type: ignore
 69        process_counter.value += 1  # type: ignore
 70
 71    # Overriding logger.
 72    logger = multiprocessing.get_logger()
 73    logger_stream_handler = logging.StreamHandler()
 74    logger_formatter = logging.Formatter(
 75        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
 76    )
 77    logger_stream_handler.setFormatter(logger_formatter)
 78    logger.addHandler(logger_stream_handler)
 79    logger.setLevel(pool_config.logging_level)
 80
 81    if pool_config.logging_to_stderr:
 82        multiprocessing.log_to_stderr()
 83
 84    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')
 85
 86    # Generate seed_sequence.
 87    seed_sequences = SeedSequence(pool_config.rng_seed)
 88    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]
 89
 90    # Initialize pool worker.
 91    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
 92    pool_worker = pool_config.pool_worker_class(
 93        process_idx=process_idx,
 94        seed_sequence=seed_sequence,
 95        config=pool_config.pool_worker_config,
 96        logger=logger,
 97    )
 98    PoolWorkerState.pool_worker = pool_worker
 99    PoolWorkerState.logger = logger
100    logger.debug('Initialized.')
def pool_worker_runner(_):
103def pool_worker_runner(_):
104    logger = PoolWorkerState.logger
105    logger.debug('Triggered.')
106
107    pool_worker = PoolWorkerState.pool_worker
108    assert pool_worker is not None
109    result = pool_worker.run()
110    logger.debug('Result generated.')
111
112    return result
class PoolInventoryState:
116class PoolInventoryState:
117    cond: threading.Condition
118    inventory: int
119    num_scheduled: int
120    inventory_target: int
121    abort: bool
122
123    def predicate(self):
124        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)
125
126    def __repr__(self):
127        return (
128            'PoolInventoryState('
129            f'inventory={self.inventory}, '
130            f'num_scheduled={self.num_scheduled}, '
131            f'inventory_target={self.inventory_target}, '
132            f'should_schedule={self.predicate()}'
133            ')'
134        )
PoolInventoryState( cond: threading.Condition, inventory: int, num_scheduled: int, inventory_target: int, abort: bool)
2def __init__(self, cond, inventory, num_scheduled, inventory_target, abort):
3    self.cond = cond
4    self.inventory = inventory
5    self.num_scheduled = num_scheduled
6    self.inventory_target = inventory_target
7    self.abort = abort

Method generated by attrs for class PoolInventoryState.

def predicate(self):
123    def predicate(self):
124        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)
def trigger_generator(schedule_size_min: int, state: vkit.utility.pool.PoolInventoryState):
137def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
138    while True:
139        with state.cond:
140            state.cond.wait_for(state.predicate)
141            if state.abort:
142                return
143
144            schedule_size = max(
145                schedule_size_min,
146                state.inventory_target - state.inventory - state.num_scheduled,
147            )
148            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
149            state.num_scheduled += schedule_size
150            for _ in range(schedule_size):
151                yield None
class Pool(typing.Generic[-_T_CONFIG, +_T_OUTPUT]):
154class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
155
156    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
157        self.config = config
158
159        process_counter = multiprocessing.Value('i')
160        with process_counter.get_lock():
161            process_counter.value = 0  # type: ignore
162
163        self.mp_pool = multiprocessing.Pool(
164            processes=self.config.num_processes,
165            initializer=pool_worker_initializer,
166            initargs=(self.config, process_counter),
167        )
168
169        self.state = PoolInventoryState(
170            cond=threading.Condition(threading.Lock()),
171            inventory=0,
172            num_scheduled=0,
173            inventory_target=self.config.inventory,
174            abort=False,
175        )
176
177        self.mp_pool_iter = self.mp_pool.imap_unordered(
178            pool_worker_runner,
179            trigger_generator(
180                schedule_size_min=round(
181                    self.config.schedule_size_min_factor * self.config.num_processes
182                ),
183                state=self.state,
184            ),
185        )
186
187        self.cleanup_flag = False
188        atexit.register(self.cleanup)
189
190    def cleanup(self):
191        if not self.cleanup_flag:
192            self.cleanup_flag = True
193
194            with self.state.cond:
195                self.state.abort = True
196                self.state.cond.notify()
197
198            self.mp_pool.close()
199            self.mp_pool.terminate()
200            time.sleep(1)
201
202            for proc in self.mp_pool._pool:  # type: ignore
203                if not psutil.pid_exists(proc.pid):
204                    continue
205                logger.warning(f'worker pid={proc.pid} still exists, killing...')
206                proc = psutil.Process(proc.pid)
207                proc.terminate()
208                try:
209                    proc.wait(timeout=3)
210                except psutil.TimeoutExpired:
211                    proc.kill()
212
213            # For gc.
214            self.__setattr__('mp_pool_iter', None)
215            self.__setattr__('state', None)
216            self.__setattr__('mp_pool', None)
217            self.__setattr__('config', None)
218
219            atexit.unregister(self.cleanup)
220
221    def run(self):
222        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)
223
224        # Update inventory.
225        with self.state.cond:
226            with self.mp_pool_iter._cond:  # type: ignore
227                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
228            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')
229
230            # NOTE: We have just get one output, hence need to minus one.
231            num_scheduled_delta = new_inventory - self.state.inventory + 1
232            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
233            assert num_scheduled_delta >= 0
234
235            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
236            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')
237
238            self.state.inventory = new_inventory
239            self.state.num_scheduled = new_num_scheduled
240
241            # Wake up trigger_generator.
242            self.state.cond.notify()
243
244        return output

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
        # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

Pool(config: vkit.utility.pool.PoolConfig[-_T_CONFIG, +_T_OUTPUT])
156    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
157        self.config = config
158
159        process_counter = multiprocessing.Value('i')
160        with process_counter.get_lock():
161            process_counter.value = 0  # type: ignore
162
163        self.mp_pool = multiprocessing.Pool(
164            processes=self.config.num_processes,
165            initializer=pool_worker_initializer,
166            initargs=(self.config, process_counter),
167        )
168
169        self.state = PoolInventoryState(
170            cond=threading.Condition(threading.Lock()),
171            inventory=0,
172            num_scheduled=0,
173            inventory_target=self.config.inventory,
174            abort=False,
175        )
176
177        self.mp_pool_iter = self.mp_pool.imap_unordered(
178            pool_worker_runner,
179            trigger_generator(
180                schedule_size_min=round(
181                    self.config.schedule_size_min_factor * self.config.num_processes
182                ),
183                state=self.state,
184            ),
185        )
186
187        self.cleanup_flag = False
188        atexit.register(self.cleanup)
def cleanup(self):
190    def cleanup(self):
191        if not self.cleanup_flag:
192            self.cleanup_flag = True
193
194            with self.state.cond:
195                self.state.abort = True
196                self.state.cond.notify()
197
198            self.mp_pool.close()
199            self.mp_pool.terminate()
200            time.sleep(1)
201
202            for proc in self.mp_pool._pool:  # type: ignore
203                if not psutil.pid_exists(proc.pid):
204                    continue
205                logger.warning(f'worker pid={proc.pid} still exists, killing...')
206                proc = psutil.Process(proc.pid)
207                proc.terminate()
208                try:
209                    proc.wait(timeout=3)
210                except psutil.TimeoutExpired:
211                    proc.kill()
212
213            # For gc.
214            self.__setattr__('mp_pool_iter', None)
215            self.__setattr__('state', None)
216            self.__setattr__('mp_pool', None)
217            self.__setattr__('config', None)
218
219            atexit.unregister(self.cleanup)
def run(self):
221    def run(self):
222        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)
223
224        # Update inventory.
225        with self.state.cond:
226            with self.mp_pool_iter._cond:  # type: ignore
227                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
228            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')
229
230            # NOTE: We have just get one output, hence need to minus one.
231            num_scheduled_delta = new_inventory - self.state.inventory + 1
232            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
233            assert num_scheduled_delta >= 0
234
235            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
236            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')
237
238            self.state.inventory = new_inventory
239            self.state.num_scheduled = new_num_scheduled
240
241            # Wake up trigger_generator.
242            self.state.cond.notify()
243
244        return output