vkit.utility.pool

  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  2#
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  4#
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  8#
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Optional, Protocol, TypeVar, Generic, Type, Any
 15import multiprocessing
 16import multiprocessing.util
 17import multiprocessing.process
 18import threading
 19import logging
 20import time
 21import atexit
 22
 23import attrs
 24from numpy.random import SeedSequence
 25import psutil
 26
 27logger = logging.getLogger(__name__)
 28
 29_T_CONFIG = TypeVar('_T_CONFIG', contravariant=True)
 30_T_OUTPUT = TypeVar('_T_OUTPUT', covariant=True)
 31
 32
class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]):
    """Structural interface for pool workers.

    A conforming class is instantiated once per worker process by
    ``pool_worker_initializer`` and its ``run()`` is invoked once per
    scheduled task by ``pool_worker_runner``.
    """

    def __init__(
        self,
        process_idx: int,
        seed_sequence: SeedSequence,
        logger: logging.Logger,
        config: _T_CONFIG,
    ) -> None:
        # process_idx: zero-based index of the owning worker process.
        # seed_sequence: per-process numpy SeedSequence for deterministic RNG.
        # logger: process-local logger configured by the initializer.
        # config: the pool-level worker config (PoolConfig.pool_worker_config).
        ...

    def run(self) -> _T_OUTPUT:
        """Produce one output item."""
        ...
 46
 47
@attrs.define
class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
    """Configuration for ``Pool``."""

    # Target number of pre-generated outputs to keep buffered.
    inventory: int
    # Number of worker processes in the multiprocessing pool.
    num_processes: int
    # Worker class, instantiated once per process (see PoolWorkerProtocol).
    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
    # Config object forwarded to each worker's __init__.
    pool_worker_config: _T_CONFIG
    # Minimum scheduling batch = round(factor * num_processes).
    schedule_size_min_factor: float = 1.0
    # Root seed from which each process's SeedSequence is spawned.
    rng_seed: int = 133700
    logging_level: int = logging.INFO
    # The literal 'PROCESS_IDX' is replaced with the worker's process index.
    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
    # If True, also enables multiprocessing's internal stderr logging.
    logging_to_stderr: bool = False
    # Seconds Pool.run waits for one output; None means wait forever.
    timeout: Optional[int] = None
 60
 61
class PoolWorkerState:
    """Process-local holder for the worker instance and its logger.

    The class attributes are assigned by ``pool_worker_initializer`` inside
    each worker process and read back by ``pool_worker_runner``.
    """

    pool_worker: Optional[PoolWorkerProtocol] = None
    logger: logging.Logger
 65
 66
 67def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
 68    with process_counter.get_lock():
 69        process_idx: int = process_counter.value  # type: ignore
 70        process_counter.value += 1  # type: ignore
 71
 72    # Overriding logger.
 73    logger = multiprocessing.get_logger()
 74    logger_stream_handler = logging.StreamHandler()
 75    logger_formatter = logging.Formatter(
 76        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
 77    )
 78    logger_stream_handler.setFormatter(logger_formatter)
 79    logger.addHandler(logger_stream_handler)
 80    logger.setLevel(pool_config.logging_level)
 81
 82    if pool_config.logging_to_stderr:
 83        multiprocessing.util._log_to_stderr = True  # type: ignore
 84
 85    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')
 86
 87    # Generate seed_sequence.
 88    seed_sequences = SeedSequence(pool_config.rng_seed)
 89    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]
 90
 91    # Initialize pool worker.
 92    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
 93    pool_worker = pool_config.pool_worker_class(
 94        process_idx=process_idx,
 95        seed_sequence=seed_sequence,
 96        config=pool_config.pool_worker_config,
 97        logger=logger,
 98    )
 99    PoolWorkerState.pool_worker = pool_worker
100    PoolWorkerState.logger = logger
101    logger.debug('Initialized.')
102
103
def pool_worker_runner(_):
    """Task entry point: delegate one unit of work to the per-process worker."""
    worker_logger = PoolWorkerState.logger
    worker_logger.debug('Triggered.')

    worker = PoolWorkerState.pool_worker
    assert worker is not None
    output = worker.run()
    worker_logger.debug('Result generated.')

    return output
114
115
@attrs.define
class PoolInventoryState:
    """Bookkeeping shared between ``Pool.run`` and the trigger generator."""

    # cond guards all of the fields below.
    cond: threading.Condition
    inventory: int
    num_scheduled: int
    inventory_target: int
    abort: bool

    def predicate(self):
        """Return True when scheduling should proceed (or the pool is aborting)."""
        if self.abort:
            return True
        return self.inventory + self.num_scheduled < self.inventory_target

    def __repr__(self):
        return (
            f'PoolInventoryState(inventory={self.inventory}, '
            f'num_scheduled={self.num_scheduled}, '
            f'inventory_target={self.inventory_target}, '
            f'should_schedule={self.predicate()})'
        )
136
137
def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
    """Yield one ``None`` per pool task to schedule, driven by inventory demand.

    Blocks on ``state.cond`` until more output is needed (or abort is set),
    reserves ``schedule_size`` slots in ``state.num_scheduled`` under the lock,
    then yields the triggers *outside* the lock. The original implementation
    yielded inside the ``with state.cond`` block, which left the condition lock
    held while the generator was suspended mid-yield; if the consumer (the
    pool's task handler) stalled between pulls, ``Pool.run`` could block on
    ``state.cond`` indefinitely.
    """
    while True:
        with state.cond:
            state.cond.wait_for(state.predicate)
            if state.abort:
                return

            # Schedule at least schedule_size_min, or enough to refill the target.
            schedule_size = max(
                schedule_size_min,
                state.inventory_target - state.inventory - state.num_scheduled,
            )
            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
            state.num_scheduled += schedule_size

        # NOTE: yield with the lock released; the generator may be suspended
        # here for a long time while the task handler consumes the triggers.
        for _ in range(schedule_size):
            yield None
153
154
class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
    """Self-refilling multiprocessing pool.

    Keeps roughly ``config.inventory`` outputs buffered: a trigger generator
    (consumed lazily by ``imap_unordered``) schedules new tasks whenever
    inventory + scheduled falls below the target, and ``run()`` pops one
    buffered output at a time.
    """

    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
        self.config = config

        # Shared counter from which workers claim unique process indices.
        process_counter = multiprocessing.Value('i')
        with process_counter.get_lock():
            process_counter.value = 0  # type: ignore

        self.mp_pool = multiprocessing.Pool(
            processes=self.config.num_processes,
            initializer=pool_worker_initializer,
            initargs=(self.config, process_counter),
        )

        # Inventory bookkeeping shared with trigger_generator via the condition.
        self.state = PoolInventoryState(
            cond=threading.Condition(threading.Lock()),
            inventory=0,
            num_scheduled=0,
            inventory_target=self.config.inventory,
            abort=False,
        )

        # Lazy task stream: each trigger yielded by the generator becomes one task.
        self.mp_pool_iter = self.mp_pool.imap_unordered(
            pool_worker_runner,
            trigger_generator(
                schedule_size_min=round(
                    self.config.schedule_size_min_factor * self.config.num_processes
                ),
                state=self.state,
            ),
        )

        self.cleanup_flag = False
        atexit.register(self.cleanup)

    def cleanup(self):
        """Abort scheduling, tear down the pool, and force-kill stragglers.

        Idempotent (guarded by ``cleanup_flag``); also registered with atexit.
        """
        if not self.cleanup_flag:
            self.cleanup_flag = True

            # Unblock trigger_generator's wait_for so it can observe abort and exit.
            with self.state.cond:
                self.state.abort = True
                self.state.cond.notify()

            self.mp_pool.close()
            self.mp_pool.terminate()
            # Give worker processes a moment to exit before force-killing.
            time.sleep(1)

            # NOTE(review): relies on the private multiprocessing.Pool._pool list.
            for proc in self.mp_pool._pool:  # type: ignore
                if not psutil.pid_exists(proc.pid):
                    continue
                logger.warning(f'worker pid={proc.pid} still exists, killing...')
                proc = psutil.Process(proc.pid)
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except psutil.TimeoutExpired:
                    # SIGTERM ignored within 3 seconds; escalate to SIGKILL.
                    proc.kill()

            # Drop references so the pool machinery can be garbage-collected.
            self.__setattr__('mp_pool_iter', None)
            self.__setattr__('state', None)
            self.__setattr__('mp_pool', None)
            self.__setattr__('config', None)

            atexit.unregister(self.cleanup)

    def run(self):
        """Block until one worker output is available and return it.

        Raises multiprocessing.TimeoutError if ``config.timeout`` elapses.
        """
        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)

        # Update inventory bookkeeping, then wake the trigger generator.
        with self.state.cond:
            # NOTE(review): peeks at IMapIterator internals (_cond, _items) to
            # count the results currently buffered in the iterator.
            with self.mp_pool_iter._cond:  # type: ignore
                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')

            # NOTE: We have just taken one output out of the buffer, hence the
            # extra one when counting how many scheduled tasks completed.
            num_scheduled_delta = new_inventory - self.state.inventory + 1
            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
            assert num_scheduled_delta >= 0

            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')

            self.state.inventory = new_inventory
            self.state.num_scheduled = new_num_scheduled

            # Wake up trigger_generator.
            self.state.cond.notify()

        return output
class PoolWorkerProtocol(typing.Protocol[-_T_CONFIG, +_T_OUTPUT]):
34class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]):
35
36    def __init__(
37        self,
38        process_idx: int,
39        seed_sequence: SeedSequence,
40        logger: logging.Logger,
41        config: _T_CONFIG,
42    ) -> None:
43        ...
44
45    def run(self) -> _T_OUTPUT:
46        ...

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
PoolWorkerProtocol(*args, **kwargs)
981def _no_init(self, *args, **kwargs):
982    if type(self)._is_protocol:
983        raise TypeError('Protocols cannot be instantiated')
def run(self) -> +_T_OUTPUT:
45    def run(self) -> _T_OUTPUT:
46        ...
class PoolConfig(typing.Generic[-_T_CONFIG, +_T_OUTPUT]):
50class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
51    inventory: int
52    num_processes: int
53    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
54    pool_worker_config: _T_CONFIG
55    schedule_size_min_factor: float = 1.0
56    rng_seed: int = 133700
57    logging_level: int = logging.INFO
58    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
59    logging_to_stderr: bool = False
60    timeout: Optional[int] = None

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

PoolConfig( inventory: int, num_processes: int, pool_worker_class: Type[vkit.utility.pool.PoolWorkerProtocol[-_T_CONFIG, +_T_OUTPUT]], pool_worker_config: -_T_CONFIG, schedule_size_min_factor: float = 1.0, rng_seed: int = 133700, logging_level: int = 20, logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s', logging_to_stderr: bool = False, timeout: Union[int, NoneType] = None)
 2def __init__(self, inventory, num_processes, pool_worker_class, pool_worker_config, schedule_size_min_factor=attr_dict['schedule_size_min_factor'].default, rng_seed=attr_dict['rng_seed'].default, logging_level=attr_dict['logging_level'].default, logging_format=attr_dict['logging_format'].default, logging_to_stderr=attr_dict['logging_to_stderr'].default, timeout=attr_dict['timeout'].default):
 3    self.inventory = inventory
 4    self.num_processes = num_processes
 5    self.pool_worker_class = pool_worker_class
 6    self.pool_worker_config = pool_worker_config
 7    self.schedule_size_min_factor = schedule_size_min_factor
 8    self.rng_seed = rng_seed
 9    self.logging_level = logging_level
10    self.logging_format = logging_format
11    self.logging_to_stderr = logging_to_stderr
12    self.timeout = timeout

Method generated by attrs for class PoolConfig.

class PoolWorkerState:
63class PoolWorkerState:
64    pool_worker: Optional[PoolWorkerProtocol] = None
65    logger: logging.Logger
PoolWorkerState()
def pool_worker_initializer(pool_config: vkit.utility.pool.PoolConfig, process_counter: Any):
 68def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
 69    with process_counter.get_lock():
 70        process_idx: int = process_counter.value  # type: ignore
 71        process_counter.value += 1  # type: ignore
 72
 73    # Overriding logger.
 74    logger = multiprocessing.get_logger()
 75    logger_stream_handler = logging.StreamHandler()
 76    logger_formatter = logging.Formatter(
 77        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
 78    )
 79    logger_stream_handler.setFormatter(logger_formatter)
 80    logger.addHandler(logger_stream_handler)
 81    logger.setLevel(pool_config.logging_level)
 82
 83    if pool_config.logging_to_stderr:
 84        multiprocessing.util._log_to_stderr = True  # type: ignore
 85
 86    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')
 87
 88    # Generate seed_sequence.
 89    seed_sequences = SeedSequence(pool_config.rng_seed)
 90    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]
 91
 92    # Initialize pool worker.
 93    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
 94    pool_worker = pool_config.pool_worker_class(
 95        process_idx=process_idx,
 96        seed_sequence=seed_sequence,
 97        config=pool_config.pool_worker_config,
 98        logger=logger,
 99    )
100    PoolWorkerState.pool_worker = pool_worker
101    PoolWorkerState.logger = logger
102    logger.debug('Initialized.')
def pool_worker_runner(_):
105def pool_worker_runner(_):
106    logger = PoolWorkerState.logger
107    logger.debug('Triggered.')
108
109    pool_worker = PoolWorkerState.pool_worker
110    assert pool_worker is not None
111    result = pool_worker.run()
112    logger.debug('Result generated.')
113
114    return result
class PoolInventoryState:
118class PoolInventoryState:
119    cond: threading.Condition
120    inventory: int
121    num_scheduled: int
122    inventory_target: int
123    abort: bool
124
125    def predicate(self):
126        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)
127
128    def __repr__(self):
129        return (
130            'PoolInventoryState('
131            f'inventory={self.inventory}, '
132            f'num_scheduled={self.num_scheduled}, '
133            f'inventory_target={self.inventory_target}, '
134            f'should_schedule={self.predicate()}'
135            ')'
136        )
PoolInventoryState( cond: threading.Condition, inventory: int, num_scheduled: int, inventory_target: int, abort: bool)
2def __init__(self, cond, inventory, num_scheduled, inventory_target, abort):
3    self.cond = cond
4    self.inventory = inventory
5    self.num_scheduled = num_scheduled
6    self.inventory_target = inventory_target
7    self.abort = abort

Method generated by attrs for class PoolInventoryState.

def predicate(self):
125    def predicate(self):
126        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)
def trigger_generator(schedule_size_min: int, state: vkit.utility.pool.PoolInventoryState):
139def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
140    while True:
141        with state.cond:
142            state.cond.wait_for(state.predicate)
143            if state.abort:
144                return
145
146            schedule_size = max(
147                schedule_size_min,
148                state.inventory_target - state.inventory - state.num_scheduled,
149            )
150            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
151            state.num_scheduled += schedule_size
152            for _ in range(schedule_size):
153                yield None
class Pool(typing.Generic[-_T_CONFIG, +_T_OUTPUT]):
156class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
157
158    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
159        self.config = config
160
161        process_counter = multiprocessing.Value('i')
162        with process_counter.get_lock():
163            process_counter.value = 0  # type: ignore
164
165        self.mp_pool = multiprocessing.Pool(
166            processes=self.config.num_processes,
167            initializer=pool_worker_initializer,
168            initargs=(self.config, process_counter),
169        )
170
171        self.state = PoolInventoryState(
172            cond=threading.Condition(threading.Lock()),
173            inventory=0,
174            num_scheduled=0,
175            inventory_target=self.config.inventory,
176            abort=False,
177        )
178
179        self.mp_pool_iter = self.mp_pool.imap_unordered(
180            pool_worker_runner,
181            trigger_generator(
182                schedule_size_min=round(
183                    self.config.schedule_size_min_factor * self.config.num_processes
184                ),
185                state=self.state,
186            ),
187        )
188
189        self.cleanup_flag = False
190        atexit.register(self.cleanup)
191
192    def cleanup(self):
193        if not self.cleanup_flag:
194            self.cleanup_flag = True
195
196            with self.state.cond:
197                self.state.abort = True
198                self.state.cond.notify()
199
200            self.mp_pool.close()
201            self.mp_pool.terminate()
202            time.sleep(1)
203
204            for proc in self.mp_pool._pool:  # type: ignore
205                if not psutil.pid_exists(proc.pid):
206                    continue
207                logger.warning(f'worker pid={proc.pid} still exists, killing...')
208                proc = psutil.Process(proc.pid)
209                proc.terminate()
210                try:
211                    proc.wait(timeout=3)
212                except psutil.TimeoutExpired:
213                    proc.kill()
214
215            # For gc.
216            self.__setattr__('mp_pool_iter', None)
217            self.__setattr__('state', None)
218            self.__setattr__('mp_pool', None)
219            self.__setattr__('config', None)
220
221            atexit.unregister(self.cleanup)
222
223    def run(self):
224        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)
225
226        # Update inventory.
227        with self.state.cond:
228            with self.mp_pool_iter._cond:  # type: ignore
229                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
230            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')
231
232            # NOTE: We have just get one output, hence need to minus one.
233            num_scheduled_delta = new_inventory - self.state.inventory + 1
234            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
235            assert num_scheduled_delta >= 0
236
237            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
238            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')
239
240            self.state.inventory = new_inventory
241            self.state.num_scheduled = new_num_scheduled
242
243            # Wake up trigger_generator.
244            self.state.cond.notify()
245
246        return output

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

Pool(config: vkit.utility.pool.PoolConfig[-_T_CONFIG, +_T_OUTPUT])
158    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
159        self.config = config
160
161        process_counter = multiprocessing.Value('i')
162        with process_counter.get_lock():
163            process_counter.value = 0  # type: ignore
164
165        self.mp_pool = multiprocessing.Pool(
166            processes=self.config.num_processes,
167            initializer=pool_worker_initializer,
168            initargs=(self.config, process_counter),
169        )
170
171        self.state = PoolInventoryState(
172            cond=threading.Condition(threading.Lock()),
173            inventory=0,
174            num_scheduled=0,
175            inventory_target=self.config.inventory,
176            abort=False,
177        )
178
179        self.mp_pool_iter = self.mp_pool.imap_unordered(
180            pool_worker_runner,
181            trigger_generator(
182                schedule_size_min=round(
183                    self.config.schedule_size_min_factor * self.config.num_processes
184                ),
185                state=self.state,
186            ),
187        )
188
189        self.cleanup_flag = False
190        atexit.register(self.cleanup)
def cleanup(self):
192    def cleanup(self):
193        if not self.cleanup_flag:
194            self.cleanup_flag = True
195
196            with self.state.cond:
197                self.state.abort = True
198                self.state.cond.notify()
199
200            self.mp_pool.close()
201            self.mp_pool.terminate()
202            time.sleep(1)
203
204            for proc in self.mp_pool._pool:  # type: ignore
205                if not psutil.pid_exists(proc.pid):
206                    continue
207                logger.warning(f'worker pid={proc.pid} still exists, killing...')
208                proc = psutil.Process(proc.pid)
209                proc.terminate()
210                try:
211                    proc.wait(timeout=3)
212                except psutil.TimeoutExpired:
213                    proc.kill()
214
215            # For gc.
216            self.__setattr__('mp_pool_iter', None)
217            self.__setattr__('state', None)
218            self.__setattr__('mp_pool', None)
219            self.__setattr__('config', None)
220
221            atexit.unregister(self.cleanup)
def run(self):
223    def run(self):
224        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)
225
226        # Update inventory.
227        with self.state.cond:
228            with self.mp_pool_iter._cond:  # type: ignore
229                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
230            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')
231
232            # NOTE: We have just get one output, hence need to minus one.
233            num_scheduled_delta = new_inventory - self.state.inventory + 1
234            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
235            assert num_scheduled_delta >= 0
236
237            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
238            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')
239
240            self.state.inventory = new_inventory
241            self.state.num_scheduled = new_num_scheduled
242
243            # Wake up trigger_generator.
244            self.state.cond.notify()
245
246        return output