vkit.utility.pool
1# Copyright 2022 vkit-x Administrator. All Rights Reserved. 2# 3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses. 4# 5# The commercial license gives you the full rights to create and distribute software 6# on your own terms without any SSPL license obligations. For more information, 7# please see the "LICENSE_COMMERCIAL.txt" file. 8# 9# This project is also available under Server Side Public License (SSPL). 10# The SSPL licensing is ideal for use cases such as open source projects with 11# SSPL distribution, student/academic purposes, hobby projects, internal research 12# projects without external distribution, or other projects where all SSPL 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file. 14from typing import Optional, Protocol, TypeVar, Generic, Type, Any 15import multiprocessing 16import multiprocessing.util 17import multiprocessing.process 18import threading 19import logging 20import time 21import atexit 22 23import attrs 24from numpy.random import SeedSequence 25import psutil 26 27logger = logging.getLogger(__name__) 28 29_T_CONFIG = TypeVar('_T_CONFIG', contravariant=True) 30_T_OUTPUT = TypeVar('_T_OUTPUT', covariant=True) 31 32 33class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]): 34 35 def __init__( 36 self, 37 process_idx: int, 38 seed_sequence: SeedSequence, 39 logger: logging.Logger, 40 config: _T_CONFIG, 41 ) -> None: 42 ... 43 44 def run(self) -> _T_OUTPUT: 45 ... 

@attrs.define
class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]):
    """Configuration for :class:`Pool`."""

    # Target number of ready outputs to keep buffered (inventory_target below).
    inventory: int
    # Number of child processes in the multiprocessing pool.
    num_processes: int
    # Worker class instantiated once per child process.
    pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]]
    # Config object forwarded verbatim to each worker's __init__.
    pool_worker_config: _T_CONFIG
    # Lower bound of one scheduling batch = round(factor * num_processes).
    schedule_size_min_factor: float = 1.0
    # Root seed; each process gets one of its spawned SeedSequences.
    rng_seed: int = 133700
    logging_level: int = logging.INFO
    # 'PROCESS_IDX' is textually replaced with the child's index.
    logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s'
    logging_to_stderr: bool = False
    # Timeout (seconds) for fetching one output in Pool.run(); None = wait forever.
    timeout: Optional[int] = None


class PoolWorkerState:
    """Per-child-process module-level state.

    Written by ``pool_worker_initializer`` when the process starts and read
    back by ``pool_worker_runner`` on every dispatched task.
    """

    # The worker instance living in this child process.
    pool_worker: Optional[PoolWorkerProtocol] = None
    # The process-local logger configured by the initializer.
    logger: logging.Logger


def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any):
    """Child-process initializer: assign an index, set up logging, build the worker.

    :param pool_config: the pool-wide configuration.
    :param process_counter: shared ``multiprocessing.Value`` used to hand out
        a unique, zero-based ``process_idx`` to each child.
    """
    # Atomically claim the next process index.
    with process_counter.get_lock():
        process_idx: int = process_counter.value  # type: ignore
        process_counter.value += 1  # type: ignore

    # Overriding logger: use multiprocessing's logger with a per-process format
    # (the PROCESS_IDX placeholder is baked into the formatter string).
    logger = multiprocessing.get_logger()
    logger_stream_handler = logging.StreamHandler()
    logger_formatter = logging.Formatter(
        pool_config.logging_format.replace('PROCESS_IDX', str(process_idx))
    )
    logger_stream_handler.setFormatter(logger_formatter)
    logger.addHandler(logger_stream_handler)
    logger.setLevel(pool_config.logging_level)

    if pool_config.logging_to_stderr:
        # NOTE: pokes a private multiprocessing.util flag to mirror logs to stderr.
        multiprocessing.util._log_to_stderr = True  # type: ignore

    logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}')

    # Generate seed_sequence: every child spawns the same num_processes children
    # from the same root seed, then picks its own — deterministic per process_idx.
    seed_sequences = SeedSequence(pool_config.rng_seed)
    seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx]

    # Initialize pool worker.
    logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}')
    pool_worker = pool_config.pool_worker_class(
        process_idx=process_idx,
        seed_sequence=seed_sequence,
        config=pool_config.pool_worker_config,
        logger=logger,
    )
    # Stash worker and logger at module level for pool_worker_runner.
    PoolWorkerState.pool_worker = pool_worker
    PoolWorkerState.logger = logger
    logger.debug('Initialized.')


def pool_worker_runner(_):
    """Task function dispatched by imap_unordered.

    The argument is an ignored ``None`` sentinel from ``trigger_generator``;
    the real work is delegated to the process-local worker built by
    ``pool_worker_initializer``.
    """
    logger = PoolWorkerState.logger
    logger.debug('Triggered.')

    pool_worker = PoolWorkerState.pool_worker
    assert pool_worker is not None
    result = pool_worker.run()
    logger.debug('Result generated.')

    return result


@attrs.define
class PoolInventoryState:
    """Bookkeeping shared between Pool.run() and trigger_generator.

    All fields are guarded by ``cond``; the producer side waits on it and the
    consumer side notifies after updating the counters.
    """

    cond: threading.Condition
    # Number of ready-but-unconsumed outputs currently buffered.
    inventory: int
    # Number of tasks dispatched but not yet reflected in inventory.
    num_scheduled: int
    # Desired inventory level (PoolConfig.inventory).
    inventory_target: int
    # Set to True by Pool.cleanup() to stop the trigger generator.
    abort: bool

    def predicate(self):
        # Wake the generator when aborting, or when buffered + in-flight work
        # no longer covers the target.
        return self.abort or (self.inventory + self.num_scheduled < self.inventory_target)

    def __repr__(self):
        return (
            'PoolInventoryState('
            f'inventory={self.inventory}, '
            f'num_scheduled={self.num_scheduled}, '
            f'inventory_target={self.inventory_target}, '
            f'should_schedule={self.predicate()}'
            ')'
        )


def trigger_generator(schedule_size_min: int, state: PoolInventoryState):
    """Infinite generator feeding ``None`` sentinels to imap_unordered.

    Each yielded sentinel causes one ``pool_worker_runner`` task to be
    dispatched. Blocks on ``state.cond`` until more work is needed; exits when
    ``state.abort`` is set.
    """
    while True:
        with state.cond:
            state.cond.wait_for(state.predicate)
            if state.abort:
                return

            # Schedule at least schedule_size_min, or enough to refill the
            # target after accounting for buffered and in-flight items.
            schedule_size = max(
                schedule_size_min,
                state.inventory_target - state.inventory - state.num_scheduled,
            )
            logger.debug(f'state={state}, Need to schedule {schedule_size}.')
            state.num_scheduled += schedule_size

        # NOTE(review): the yields must sit OUTSIDE the `with state.cond` block —
        # the generator suspends at each yield, and holding the lock across a
        # suspension would deadlock Pool.run(), which acquires the same lock.
        for _ in range(schedule_size):
            yield None


class Pool(Generic[_T_CONFIG, _T_OUTPUT]):
    """Process pool that keeps a target inventory of worker outputs buffered.

    Outputs are produced eagerly by ``num_processes`` children and pulled one
    at a time via :meth:`run`.
    """

    def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]):
        self.config = config

        # Shared counter handing out process_idx values to child initializers.
        process_counter = multiprocessing.Value('i')
        with process_counter.get_lock():
            process_counter.value = 0  # type: ignore

        self.mp_pool = multiprocessing.Pool(
            processes=self.config.num_processes,
            initializer=pool_worker_initializer,
            initargs=(self.config, process_counter),
        )

        self.state = PoolInventoryState(
            cond=threading.Condition(threading.Lock()),
            inventory=0,
            num_scheduled=0,
            inventory_target=self.config.inventory,
            abort=False,
        )

        # Lazy iterator: pulling results drives trigger_generator, which in
        # turn dispatches pool_worker_runner tasks on demand.
        self.mp_pool_iter = self.mp_pool.imap_unordered(
            pool_worker_runner,
            trigger_generator(
                schedule_size_min=round(
                    self.config.schedule_size_min_factor * self.config.num_processes
                ),
                state=self.state,
            ),
        )

        self.cleanup_flag = False
        # Make sure children are torn down even if the owner forgets cleanup().
        atexit.register(self.cleanup)

    def cleanup(self):
        """Stop the generator, terminate children, and drop references. Idempotent."""
        if not self.cleanup_flag:
            self.cleanup_flag = True

            # Unblock trigger_generator so it can observe abort and return.
            with self.state.cond:
                self.state.abort = True
                self.state.cond.notify()

            self.mp_pool.close()
            self.mp_pool.terminate()
            # Give terminate() a moment before force-killing stragglers.
            time.sleep(1)

            # NOTE: relies on the private Pool._pool list of worker processes.
            for proc in self.mp_pool._pool:  # type: ignore
                if not psutil.pid_exists(proc.pid):
                    continue
                logger.warning(f'worker pid={proc.pid} still exists, killing...')
                proc = psutil.Process(proc.pid)
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except psutil.TimeoutExpired:
                    # SIGTERM ignored; escalate to SIGKILL.
                    proc.kill()

            # For gc.
            self.__setattr__('mp_pool_iter', None)
            self.__setattr__('state', None)
            self.__setattr__('mp_pool', None)
            self.__setattr__('config', None)

            atexit.unregister(self.cleanup)

    def run(self):
        """Block until one output is available, update inventory, and return it.

        :raises multiprocessing.TimeoutError: if ``config.timeout`` elapses.
        """
        output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout)

        # Update inventory.
        with self.state.cond:
            # NOTE: reads IMapUnorderedIterator internals (_cond/_items) —
            # _items is presumably its buffer of ready results; verify against
            # the running Python's multiprocessing.pool implementation.
            with self.mp_pool_iter._cond:  # type: ignore
                new_inventory = len(self.mp_pool_iter._items)  # type: ignore
            logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}')

            # NOTE: the output just consumed is no longer buffered, so add one
            # back when counting how many scheduled tasks have completed.
            num_scheduled_delta = new_inventory - self.state.inventory + 1
            logger.debug(f'num_scheduled_delta: {num_scheduled_delta}')
            assert num_scheduled_delta >= 0

            new_num_scheduled = self.state.num_scheduled - num_scheduled_delta
            logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}')

            self.state.inventory = new_inventory
            self.state.num_scheduled = new_num_scheduled

            # Wake up trigger_generator.
            self.state.cond.notify()

        return output
34class PoolWorkerProtocol(Protocol[_T_CONFIG, _T_OUTPUT]): 35 36 def __init__( 37 self, 38 process_idx: int, 39 seed_sequence: SeedSequence, 40 logger: logging.Logger, 41 config: _T_CONFIG, 42 ) -> None: 43 ... 44 45 def run(self) -> _T_OUTPUT: 46 ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
50class PoolConfig(Generic[_T_CONFIG, _T_OUTPUT]): 51 inventory: int 52 num_processes: int 53 pool_worker_class: Type[PoolWorkerProtocol[_T_CONFIG, _T_OUTPUT]] 54 pool_worker_config: _T_CONFIG 55 schedule_size_min_factor: float = 1.0 56 rng_seed: int = 133700 57 logging_level: int = logging.INFO 58 logging_format: str = '[%(levelname)s/PROCESS_IDX] %(message)s' 59 logging_to_stderr: bool = False 60 timeout: Optional[int] = None
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
2def __init__(self, inventory, num_processes, pool_worker_class, pool_worker_config, schedule_size_min_factor=attr_dict['schedule_size_min_factor'].default, rng_seed=attr_dict['rng_seed'].default, logging_level=attr_dict['logging_level'].default, logging_format=attr_dict['logging_format'].default, logging_to_stderr=attr_dict['logging_to_stderr'].default, timeout=attr_dict['timeout'].default): 3 self.inventory = inventory 4 self.num_processes = num_processes 5 self.pool_worker_class = pool_worker_class 6 self.pool_worker_config = pool_worker_config 7 self.schedule_size_min_factor = schedule_size_min_factor 8 self.rng_seed = rng_seed 9 self.logging_level = logging_level 10 self.logging_format = logging_format 11 self.logging_to_stderr = logging_to_stderr 12 self.timeout = timeout
Method generated by attrs for class PoolConfig.
68def pool_worker_initializer(pool_config: PoolConfig, process_counter: Any): 69 with process_counter.get_lock(): 70 process_idx: int = process_counter.value # type: ignore 71 process_counter.value += 1 # type: ignore 72 73 # Overriding logger. 74 logger = multiprocessing.get_logger() 75 logger_stream_handler = logging.StreamHandler() 76 logger_formatter = logging.Formatter( 77 pool_config.logging_format.replace('PROCESS_IDX', str(process_idx)) 78 ) 79 logger_stream_handler.setFormatter(logger_formatter) 80 logger.addHandler(logger_stream_handler) 81 logger.setLevel(pool_config.logging_level) 82 83 if pool_config.logging_to_stderr: 84 multiprocessing.util._log_to_stderr = True # type: ignore 85 86 logger.debug(f'process_idx={process_idx}, num_processes={pool_config.num_processes}') 87 88 # Generate seed_sequence. 89 seed_sequences = SeedSequence(pool_config.rng_seed) 90 seed_sequence = seed_sequences.spawn(pool_config.num_processes)[process_idx] 91 92 # Initialize pool worker. 93 logger.debug(f'Initializing process_idx={process_idx} with seed_sequence={seed_sequence}') 94 pool_worker = pool_config.pool_worker_class( 95 process_idx=process_idx, 96 seed_sequence=seed_sequence, 97 config=pool_config.pool_worker_config, 98 logger=logger, 99 ) 100 PoolWorkerState.pool_worker = pool_worker 101 PoolWorkerState.logger = logger 102 logger.debug('Initialized.')
118class PoolInventoryState: 119 cond: threading.Condition 120 inventory: int 121 num_scheduled: int 122 inventory_target: int 123 abort: bool 124 125 def predicate(self): 126 return self.abort or (self.inventory + self.num_scheduled < self.inventory_target) 127 128 def __repr__(self): 129 return ( 130 'PoolInventoryState(' 131 f'inventory={self.inventory}, ' 132 f'num_scheduled={self.num_scheduled}, ' 133 f'inventory_target={self.inventory_target}, ' 134 f'should_schedule={self.predicate()}' 135 ')' 136 )
2def __init__(self, cond, inventory, num_scheduled, inventory_target, abort): 3 self.cond = cond 4 self.inventory = inventory 5 self.num_scheduled = num_scheduled 6 self.inventory_target = inventory_target 7 self.abort = abort
Method generated by attrs for class PoolInventoryState.
139def trigger_generator(schedule_size_min: int, state: PoolInventoryState): 140 while True: 141 with state.cond: 142 state.cond.wait_for(state.predicate) 143 if state.abort: 144 return 145 146 schedule_size = max( 147 schedule_size_min, 148 state.inventory_target - state.inventory - state.num_scheduled, 149 ) 150 logger.debug(f'state={state}, Need to schedule {schedule_size}.') 151 state.num_scheduled += schedule_size 152 for _ in range(schedule_size): 153 yield None
156class Pool(Generic[_T_CONFIG, _T_OUTPUT]): 157 158 def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]): 159 self.config = config 160 161 process_counter = multiprocessing.Value('i') 162 with process_counter.get_lock(): 163 process_counter.value = 0 # type: ignore 164 165 self.mp_pool = multiprocessing.Pool( 166 processes=self.config.num_processes, 167 initializer=pool_worker_initializer, 168 initargs=(self.config, process_counter), 169 ) 170 171 self.state = PoolInventoryState( 172 cond=threading.Condition(threading.Lock()), 173 inventory=0, 174 num_scheduled=0, 175 inventory_target=self.config.inventory, 176 abort=False, 177 ) 178 179 self.mp_pool_iter = self.mp_pool.imap_unordered( 180 pool_worker_runner, 181 trigger_generator( 182 schedule_size_min=round( 183 self.config.schedule_size_min_factor * self.config.num_processes 184 ), 185 state=self.state, 186 ), 187 ) 188 189 self.cleanup_flag = False 190 atexit.register(self.cleanup) 191 192 def cleanup(self): 193 if not self.cleanup_flag: 194 self.cleanup_flag = True 195 196 with self.state.cond: 197 self.state.abort = True 198 self.state.cond.notify() 199 200 self.mp_pool.close() 201 self.mp_pool.terminate() 202 time.sleep(1) 203 204 for proc in self.mp_pool._pool: # type: ignore 205 if not psutil.pid_exists(proc.pid): 206 continue 207 logger.warning(f'worker pid={proc.pid} still exists, killing...') 208 proc = psutil.Process(proc.pid) 209 proc.terminate() 210 try: 211 proc.wait(timeout=3) 212 except psutil.TimeoutExpired: 213 proc.kill() 214 215 # For gc. 216 self.__setattr__('mp_pool_iter', None) 217 self.__setattr__('state', None) 218 self.__setattr__('mp_pool', None) 219 self.__setattr__('config', None) 220 221 atexit.unregister(self.cleanup) 222 223 def run(self): 224 output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout) 225 226 # Update inventory. 
227 with self.state.cond: 228 with self.mp_pool_iter._cond: # type: ignore 229 new_inventory = len(self.mp_pool_iter._items) # type: ignore 230 logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}') 231 232 # NOTE: We have just get one output, hence need to minus one. 233 num_scheduled_delta = new_inventory - self.state.inventory + 1 234 logger.debug(f'num_scheduled_delta: {num_scheduled_delta}') 235 assert num_scheduled_delta >= 0 236 237 new_num_scheduled = self.state.num_scheduled - num_scheduled_delta 238 logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}') 239 240 self.state.inventory = new_inventory 241 self.state.num_scheduled = new_num_scheduled 242 243 # Wake up trigger_generator. 244 self.state.cond.notify() 245 246 return output
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
158 def __init__(self, config: PoolConfig[_T_CONFIG, _T_OUTPUT]): 159 self.config = config 160 161 process_counter = multiprocessing.Value('i') 162 with process_counter.get_lock(): 163 process_counter.value = 0 # type: ignore 164 165 self.mp_pool = multiprocessing.Pool( 166 processes=self.config.num_processes, 167 initializer=pool_worker_initializer, 168 initargs=(self.config, process_counter), 169 ) 170 171 self.state = PoolInventoryState( 172 cond=threading.Condition(threading.Lock()), 173 inventory=0, 174 num_scheduled=0, 175 inventory_target=self.config.inventory, 176 abort=False, 177 ) 178 179 self.mp_pool_iter = self.mp_pool.imap_unordered( 180 pool_worker_runner, 181 trigger_generator( 182 schedule_size_min=round( 183 self.config.schedule_size_min_factor * self.config.num_processes 184 ), 185 state=self.state, 186 ), 187 ) 188 189 self.cleanup_flag = False 190 atexit.register(self.cleanup)
192 def cleanup(self): 193 if not self.cleanup_flag: 194 self.cleanup_flag = True 195 196 with self.state.cond: 197 self.state.abort = True 198 self.state.cond.notify() 199 200 self.mp_pool.close() 201 self.mp_pool.terminate() 202 time.sleep(1) 203 204 for proc in self.mp_pool._pool: # type: ignore 205 if not psutil.pid_exists(proc.pid): 206 continue 207 logger.warning(f'worker pid={proc.pid} still exists, killing...') 208 proc = psutil.Process(proc.pid) 209 proc.terminate() 210 try: 211 proc.wait(timeout=3) 212 except psutil.TimeoutExpired: 213 proc.kill() 214 215 # For gc. 216 self.__setattr__('mp_pool_iter', None) 217 self.__setattr__('state', None) 218 self.__setattr__('mp_pool', None) 219 self.__setattr__('config', None) 220 221 atexit.unregister(self.cleanup)
223 def run(self): 224 output: _T_OUTPUT = self.mp_pool_iter.next(timeout=self.config.timeout) 225 226 # Update inventory. 227 with self.state.cond: 228 with self.mp_pool_iter._cond: # type: ignore 229 new_inventory = len(self.mp_pool_iter._items) # type: ignore 230 logger.debug(f'inventory: {self.state.inventory} -> {new_inventory}') 231 232 # NOTE: We have just get one output, hence need to minus one. 233 num_scheduled_delta = new_inventory - self.state.inventory + 1 234 logger.debug(f'num_scheduled_delta: {num_scheduled_delta}') 235 assert num_scheduled_delta >= 0 236 237 new_num_scheduled = self.state.num_scheduled - num_scheduled_delta 238 logger.debug(f'num_scheduled: {self.state.num_scheduled} -> {new_num_scheduled}') 239 240 self.state.inventory = new_inventory 241 self.state.num_scheduled = new_num_scheduled 242 243 # Wake up trigger_generator. 244 self.state.cond.notify() 245 246 return output