vkit.utility.pool
1# Copyright 2022 vkit-x Administrator. All Rights Reserved. 2# 3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses. 4# 5# The commercial license gives you the full rights to create and distribute software 6# on your own terms without any SSPL license obligations. For more information, 7# please see the "LICENSE_COMMERCIAL.txt" file. 8# 9# This project is also available under Server Side Public License (SSPL). 10# The SSPL licensing is ideal for use cases such as open source projects with 11# SSPL distribution, student/academic purposes, hobby projects, internal research 12# projects without external distribution, or other projects where all SSPL 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file. 14from typing import Optional, Protocol, TypeVar, Generic, Type, Any 15import multiprocessing 16import threading 17import logging 18import time 19import atexit 20 21import attrs 22from numpy.random import SeedSequence 23import psutil 24 25logger = logging.getLogger(__name__) 26 27_T_CONFIG = TypeVar('_T_CONFIG', contravariant=True) 28_T_OUTPUT = TypeVar('_T_OUTPUT', covariant=True) 29 30 31class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]): 32 33 def __init__( 34 self, 35 process_idx: int, 36 seed_sequence: SeedSequence, 37 logger: logging.Logger, 38 config: _T_CONFIG, 39 ) -> None: 40 ... 41 42 def run(self) -> _T_OUTPUT: 43 ... 
@attrs.define
class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
    # Target number of pre-generated outputs to keep buffered ahead of consumption.
    inventory: int
    # Number of worker processes to spawn.
    num_processes: int
    # Class implementing PoolWorkerProtocol; instantiated once per process.
    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
    # Config object forwarded to each worker's __init__.
    pool_worker_config: _T_CONFIG
    # Minimum batch size per scheduling round is
    # round(schedule_size_min_factor * num_processes); see Pool.__init__.
    schedule_size_min_factor: float = 1.0
    # Root seed; each process gets SeedSequence(rng_seed).spawn(num_processes)[process_idx].
    rng_seed: int = 133700
    logging_level: int = logging.INFO
    # The literal substring 'PROCESS_IDX' is textually replaced by the worker's
    # process index in pool_worker_initializer.
    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
    logging_to_stderr: bool = False
    # Timeout in seconds for fetching one output in Pool.run(); None blocks forever.
    timeout: Optional[int] = None


class PoolWorkerState:
    # Per-process module-level slots, populated by pool_worker_initializer and
    # read back by pool_worker_runner inside the same process.
    pool_worker: Optional[PoolWorkerProtocol] = None
    logger: logging.Logger


def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
    """Pool-process initializer.

    Claims a unique process index from the shared counter, configures the
    multiprocessing logger, derives a per-process seed sequence, and builds
    the worker instance, stashing it in PoolWorkerState.
    """
    # Atomically claim the next process index from the shared counter.
    with process_counter.get_lock():
        process_idx: int = process_counter.value  # type: ignore
        process_counter.value += 1  # type: ignore

    # Overriding logger.
    logger = multiprocessing.get_logger()
    logger_stream_handler = logging.StreamHandler()
    logger_formatter = logging.Formatter(
        # Bake the process index into the format string.
        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
    )
    logger_stream_handler.setFormatter(logger_formatter)
    logger.addHandler(logger_stream_handler)
    logger.setLevel(pool_config.logging_level)

    if pool_config.logging_to_stderr:
        multiprocessing.log_to_stderr()

    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')

    # Generate seed_sequence. Every process spawns the full list from the same
    # root seed and picks its own entry, so sequences are distinct yet reproducible.
    seed_sequences = SeedSequence(pool_config.rng_seed)
    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]

    # Initialize pool worker.
    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
    pool_worker = pool_config.pool_worker_class(
        process_idx=process_idx,
        seed_sequence=seed_sequence,
        config=pool_config.pool_worker_config,
        logger=logger,
    )
    PoolWorkerState.pool_worker = pool_worker
    PoolWorkerState.logger = logger
    logger.debug('Initialized.')


def pool_worker_runner(_):
    """Task function executed in a pool process: run the per-process worker once.

    The single (ignored) argument is the None token yielded by trigger_generator.
    """
    logger = PoolWorkerState.logger
    logger.debug('Triggered.')

    pool_worker = PoolWorkerState.pool_worker
    assert pool_worker is not None
    result = pool_worker.run()
    logger.debug('Result generated.')

    return result


@attrs.define
class PoolInventoryState:
    # Guards all mutable fields below; also used to wake trigger_generator.
    cond: threading.Condition
    # Number of finished outputs currently buffered in the imap iterator.
    inventory: int
    # Number of tasks scheduled but not yet observed as finished.
    num_scheduled: int
    # Desired inventory level (PoolConfig.inventory).
    inventory_target: int
    # Set True by Pool.cleanup() to terminate trigger_generator.
    abort: bool

    def predicate(self):
        # True when more work should be scheduled (or when aborting, so the
        # waiting generator wakes up and exits). Caller must hold self.cond.
        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)

    def __repr__(self):
        return (
            'PoolInventoryState('
            f'inventory={self.inventory}, '
            f'num_scheduled={self.num_scheduled}, '
            f'inventory_target={self.inventory_target}, '
            f'should_schedule={self.predicate()}'
            ')'
        )


def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
    """Infinite generator of scheduling tokens consumed by imap_unordered.

    Each yielded None triggers one pool_worker_runner task. Blocks (via
    cond.wait_for) until the inventory drops below target, then emits a batch
    of at least schedule_size_min tokens. Returns (ends the iterator) when
    state.abort is set.

    NOTE(review): the yields happen inside `with state.cond`, so the condition
    lock is held while the generator is suspended mid-batch; it is only
    released inside wait_for(). This appears intentional (the pool's task
    handler drains the batch quickly) — confirm before restructuring.
    """
    while True:
        with state.cond:
            state.cond.wait_for(state.predicate)
            if state.abort:
                return

            # Schedule enough to refill the target, but never less than the
            # configured minimum batch.
            schedule_size = max(
                schedule_size_min,
                state.inventory_target - state.inventory - state.num_scheduled,
            )
            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
            state.num_scheduled += schedule_size
            for _ in range(schedule_size):
                yield None


class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
    """Self-refilling process pool that keeps a buffer of worker outputs ready.

    Outputs are pulled one at a time via run(); a background trigger generator
    schedules new worker tasks whenever the buffered inventory falls below
    PoolConfig.inventory.
    """

    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
        self.config = config

        # Shared counter handing out unique process indices to initializers.
        process_counter = multiprocessing.Value('i')
        with process_counter.get_lock():
            process_counter.value = 0  # type: ignore

        self.mp_pool = multiprocessing.Pool(
            processes=self.config.num_processes,
            initializer=pool_worker_initializer,
            initargs=(self.config, process_counter),
        )

        self.state = PoolInventoryState(
            cond=threading.Condition(threading.Lock()),
            inventory=0,
            num_scheduled=0,
            inventory_target=self.config.inventory,
            abort=False,
        )

        # Lazy iterator: tasks are submitted only as trigger_generator yields
        # tokens, which in turn waits on self.state.cond.
        self.mp_pool_iter = self.mp_pool.imap_unordered(
            pool_worker_runner,
            trigger_generator(
                schedule_size_min=round(
                    self.config.schedule_size_min_factor * self.config.num_processes
                ),
                state=self.state,
            ),
        )

        self.cleanup_flag = False
        atexit.register(self.cleanup)

    def cleanup(self):
        """Idempotently tear down the pool: stop scheduling, terminate the
        process pool, and force-kill any worker process that survives."""
        if not self.cleanup_flag:
            self.cleanup_flag = True

            # Tell trigger_generator to exit on its next wakeup.
            with self.state.cond:
                self.state.abort = True
                self.state.cond.notify()

            self.mp_pool.close()
            self.mp_pool.terminate()
            # Grace period before checking for survivors.
            time.sleep(1)

            # Reaches into the private worker list; kill stragglers via psutil.
            for proc in self.mp_pool._pool:  # type: ignore
                if not psutil.pid_exists(proc.pid):
                    continue
                logger.warning(f'worker pid={proc.pid} still exists, killing...')
                proc = psutil.Process(proc.pid)
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except psutil.TimeoutExpired:
                    proc.kill()

            # For gc: drop references so the pool machinery can be collected.
            self.__setattr__('mp_pool_iter', None)
            self.__setattr__('state', None)
            self.__setattr__('mp_pool', None)
            self.__setattr__('config', None)

            atexit.unregister(self.cleanup)

    def run(self):
        """Block until one output is available, update the inventory
        accounting, wake the scheduler, and return the output."""
        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)

        # Update inventory. Lock order: state.cond first, then the iterator's
        # internal condition, so _items is read consistently with our state.
        with self.state.cond:
            with self.mp_pool_iter._cond:  # type: ignore
                # _items holds finished-but-unconsumed results (private API of
                # the imap iterator).
                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
                logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')

                # NOTE: One output was just consumed from _items above, so the
                # number of tasks completed since the last check is the
                # inventory change plus that one.
                num_scheduled_delta = new_inventory - self.state.inventory + 1
                logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
                assert num_scheduled_delta >= 0

                new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
                logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')

                self.state.inventory = new_inventory
                self.state.num_scheduled = new_num_scheduled

            # Wake up trigger_generator.
            self.state.cond.notify()

        return output
32class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]): 33 34 def __init__( 35 self, 36 process_idx: int, 37 seed_sequence: SeedSequence, 38 logger: logging.Logger, 39 config: _T_CONFIG, 40 ) -> None: 41 ... 42 43 def run(self) -> _T_OUTPUT: 44 ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
48class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]): 49 inventory: int 50 num_processes: int 51 pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]] 52 pool_worker_config: _T_CONFIG 53 schedule_size_min_factor: float = 1.0 54 rng_seed: int = 133700 55 logging_level: int = logging.INFO 56 logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s' 57 logging_to_stderr: bool = False 58 timeout: Optional[int] = None
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
2def __init__(self, inventory, num_processes, pool_worker_class, pool_worker_config, schedule_size_min_factor=attr_dict['schedule_size_min_factor'].default, rng_seed=attr_dict['rng_seed'].default, logging_level=attr_dict['logging_level'].default, logging_format=attr_dict['logging_format'].default, logging_to_stderr=attr_dict['logging_to_stderr'].default, timeout=attr_dict['timeout'].default): 3 self.inventory = inventory 4 self.num_processes = num_processes 5 self.pool_worker_class = pool_worker_class 6 self.pool_worker_config = pool_worker_config 7 self.schedule_size_min_factor = schedule_size_min_factor 8 self.rng_seed = rng_seed 9 self.logging_level = logging_level 10 self.logging_format = logging_format 11 self.logging_to_stderr = logging_to_stderr 12 self.timeout = timeout
Method generated by attrs for class PoolConfig.
66def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any): 67 with process_counter.get_lock(): 68 process_idx: int = process_counter.value # type: ignore 69 process_counter.value += 1 # type: ignore 70 71 # Overriding logger. 72 logger = multiprocessing.get_logger() 73 logger_stream_handler = logging.StreamHandler() 74 logger_formatter = logging.Formatter( 75 pool_config.logging_format.replace('PROCESS_IDX', str(process_idx)) 76 ) 77 logger_stream_handler.setFormatter(logger_formatter) 78 logger.addHandler(logger_stream_handler) 79 logger.setLevel(pool_config.logging_level) 80 81 if pool_config.logging_to_stderr: 82 multiprocessing.log_to_stderr() 83 84 logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}') 85 86 # Generate seed_sequence. 87 seed_sequences = SeedSequence(pool_config.rng_seed) 88 seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx] 89 90 # Initialize pool worker. 91 logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}') 92 pool_worker = pool_config.pool_worker_class( 93 process_idx=process_idx, 94 seed_sequence=seed_sequence, 95 config=pool_config.pool_worker_config, 96 logger=logger, 97 ) 98 PoolWorkerState.pool_worker = pool_worker 99 PoolWorkerState.logger = logger 100 logger.debug('Initialized.')
116class PoolInventoryState: 117 cond: threading.Condition 118 inventory: int 119 num_scheduled: int 120 inventory_target: int 121 abort: bool 122 123 def predicate(self): 124 return self.abort or (self.inventory + self.num_scheduled < self.inventory_target) 125 126 def __repr__(self): 127 return ( 128 'PoolInventoryState(' 129 f'inventory={self.inventory}, ' 130 f'num_scheduled={self.num_scheduled}, ' 131 f'inventory_target={self.inventory_target}, ' 132 f'should_schedule={self.predicate()}' 133 ')' 134 )
2def __init__(self, cond, inventory, num_scheduled, inventory_target, abort): 3 self.cond = cond 4 self.inventory = inventory 5 self.num_scheduled = num_scheduled 6 self.inventory_target = inventory_target 7 self.abort = abort
Method generated by attrs for class PoolInventoryState.
137def trigger_generator(schedule_size_min: int, state: PoolInventoryState): 138 while True: 139 with state.cond: 140 state.cond.wait_for(state.predicate) 141 if state.abort: 142 return 143 144 schedule_size = max( 145 schedule_size_min, 146 state.inventory_target - state.inventory - state.num_scheduled, 147 ) 148 logger.debug(f'state={state}, Need to schedule {schedule_size}.') 149 state.num_scheduled += schedule_size 150 for _ in range(schedule_size): 151 yield None
154class Pool(Generic[_T_CONFIG, _T_OUTPUT]): 155 156 def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]): 157 self.config = config 158 159 process_counter = multiprocessing.Value('i') 160 with process_counter.get_lock(): 161 process_counter.value = 0 # type: ignore 162 163 self.mp_pool = multiprocessing.Pool( 164 processes=self.config.num_processes, 165 initializer=pool_worker_initializer, 166 initargs=(self.config, process_counter), 167 ) 168 169 self.state = PoolInventoryState( 170 cond=threading.Condition(threading.Lock()), 171 inventory=0, 172 num_scheduled=0, 173 inventory_target=self.config.inventory, 174 abort=False, 175 ) 176 177 self.mp_pool_iter = self.mp_pool.imap_unordered( 178 pool_worker_runner, 179 trigger_generator( 180 schedule_size_min=round( 181 self.config.schedule_size_min_factor * self.config.num_processes 182 ), 183 state=self.state, 184 ), 185 ) 186 187 self.cleanup_flag = False 188 atexit.register(self.cleanup) 189 190 def cleanup(self): 191 if not self.cleanup_flag: 192 self.cleanup_flag = True 193 194 with self.state.cond: 195 self.state.abort = True 196 self.state.cond.notify() 197 198 self.mp_pool.close() 199 self.mp_pool.terminate() 200 time.sleep(1) 201 202 for proc in self.mp_pool._pool: # type: ignore 203 if not psutil.pid_exists(proc.pid): 204 continue 205 logger.warning(f'worker pid={proc.pid} still exists, killing...') 206 proc = psutil.Process(proc.pid) 207 proc.terminate() 208 try: 209 proc.wait(timeout=3) 210 except psutil.TimeoutExpired: 211 proc.kill() 212 213 # For gc. 214 self.__setattr__('mp_pool_iter', None) 215 self.__setattr__('state', None) 216 self.__setattr__('mp_pool', None) 217 self.__setattr__('config', None) 218 219 atexit.unregister(self.cleanup) 220 221 def run(self): 222 output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout) 223 224 # Update inventory. 
225 with self.state.cond: 226 with self.mp_pool_iter._cond: # type: ignore 227 new_inventory = len(self.mp_pool_iter._items) # type: ignore 228 logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}') 229 230 # NOTE: We have just get one output, hence need to minus one. 231 num_scheduled_delta = new_inventory - self.state.inventory + 1 232 logger.debug(f'num_scheduled_delta: {num_scheduled_delta}') 233 assert num_scheduled_delta >= 0 234 235 new_num_scheduled = self.state.num_scheduled - num_scheduled_delta 236 logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}') 237 238 self.state.inventory = new_inventory 239 self.state.num_scheduled = new_num_scheduled 240 241 # Wake up trigger_generator. 242 self.state.cond.notify() 243 244 return output
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
156 def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]): 157 self.config = config 158 159 process_counter = multiprocessing.Value('i') 160 with process_counter.get_lock(): 161 process_counter.value = 0 # type: ignore 162 163 self.mp_pool = multiprocessing.Pool( 164 processes=self.config.num_processes, 165 initializer=pool_worker_initializer, 166 initargs=(self.config, process_counter), 167 ) 168 169 self.state = PoolInventoryState( 170 cond=threading.Condition(threading.Lock()), 171 inventory=0, 172 num_scheduled=0, 173 inventory_target=self.config.inventory, 174 abort=False, 175 ) 176 177 self.mp_pool_iter = self.mp_pool.imap_unordered( 178 pool_worker_runner, 179 trigger_generator( 180 schedule_size_min=round( 181 self.config.schedule_size_min_factor * self.config.num_processes 182 ), 183 state=self.state, 184 ), 185 ) 186 187 self.cleanup_flag = False 188 atexit.register(self.cleanup)
190 def cleanup(self): 191 if not self.cleanup_flag: 192 self.cleanup_flag = True 193 194 with self.state.cond: 195 self.state.abort = True 196 self.state.cond.notify() 197 198 self.mp_pool.close() 199 self.mp_pool.terminate() 200 time.sleep(1) 201 202 for proc in self.mp_pool._pool: # type: ignore 203 if not psutil.pid_exists(proc.pid): 204 continue 205 logger.warning(f'worker pid={proc.pid} still exists, killing...') 206 proc = psutil.Process(proc.pid) 207 proc.terminate() 208 try: 209 proc.wait(timeout=3) 210 except psutil.TimeoutExpired: 211 proc.kill() 212 213 # For gc. 214 self.__setattr__('mp_pool_iter', None) 215 self.__setattr__('state', None) 216 self.__setattr__('mp_pool', None) 217 self.__setattr__('config', None) 218 219 atexit.unregister(self.cleanup)
221 def run(self): 222 output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout) 223 224 # Update inventory. 225 with self.state.cond: 226 with self.mp_pool_iter._cond: # type: ignore 227 new_inventory = len(self.mp_pool_iter._items) # type: ignore 228 logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}') 229 230 # NOTE: We have just get one output, hence need to minus one. 231 num_scheduled_delta = new_inventory - self.state.inventory + 1 232 logger.debug(f'num_scheduled_delta: {num_scheduled_delta}') 233 assert num_scheduled_delta >= 0 234 235 new_num_scheduled = self.state.num_scheduled - num_scheduled_delta 236 logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}') 237 238 self.state.inventory = new_inventory 239 self.state.num_scheduled = new_num_scheduled 240 241 # Wake up trigger_generator. 242 self.state.cond.notify() 243 244 return output