diff --git a/utils/ipc_data_generator.py b/utils/ipc_data_generator.py
new file mode 100644
index 0000000..a821c9f
--- /dev/null
+++ b/utils/ipc_data_generator.py
@@ -0,0 +1,164 @@
+import multiprocessing as mp
+from multiprocessing import shared_memory
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+
+
+class IPCBatchGenerator:
+    """Double-buffered batch generator: a background process prefetches the
+    next batch and hands it over through shared memory."""
+
+    def __init__(self, ipc_processor: Callable,
+                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 pipeline: Optional[Callable] = None,
+                 prefetch: bool = True, flip_data: bool = False):
+        self.flip_data = flip_data
+        self.pipeline = pipeline
+        self.prefetch = prefetch
+
+        self.ipc_processor = ipc_processor
+        self.data_processor = data_processor
+        self.label_processor = label_processor
+
+        self.global_step = 0
+
+        # Run one batch eagerly to learn the shapes and dtypes of the buffers.
+        self.data, self.label = ipc_processor()
+        first_data = [data_processor(entry) for entry in self.data] if data_processor else self.data
+        first_label = [label_processor(entry) for entry in self.label] if label_processor else self.label
+        if self.pipeline is not None:
+            for data_index, sample_data in enumerate(first_data):
+                first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
+        first_data = np.asarray(first_data)
+        first_label = np.asarray(first_label)
+        self.batch_data = first_data
+        self.batch_label = first_label
+
+        self.process_id = 'NA'
+        if self.prefetch:
+            # Two shared-memory buffers per stream: the worker fills one while
+            # the main process reads the other.
+            self.cache_memory_data = [
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
+            self.cache_data = [
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
+            self.cache_memory_label = [
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
+            self.cache_label = [
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
+            self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
+            # Single shared byte acting as a stop flag for the worker loop.
+            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_stop.buf[0] = 0
+            self.prefetch_process = mp.Process(target=self._prefetch_worker)
+            self.prefetch_process.start()
+        else:
+            self.cache_data = [first_data]
+            self.cache_label = [first_label]
+
+        self.current_cache = 0
+        self.process_id = 'main'
+
+    def __del__(self):
+        self.release()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_value, _traceback):
+        self.release()
+
+    def release(self):
+        # getattr guards against partially initialised instances reached via __del__.
+        if getattr(self, 'prefetch', False):
+            self.prefetch_stop.buf[0] = 1
+            self.prefetch_pipe_parent.send(True)  # Unblock a worker waiting in recv()
+            self.prefetch_process.join()
+
+            for shared_mem in self.cache_memory_data + self.cache_memory_label:
+                shared_mem.close()
+                shared_mem.unlink()
+            self.prefetch_stop.close()
+            self.prefetch_stop.unlink()
+            self.prefetch = False  # Avoids double release
+
+    def _prefetch_worker(self):
+        # Runs in the child process; these assignments only affect its copy.
+        self.prefetch = False
+        self.current_cache = 1
+        self.process_id = 'prefetch'
+
+        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
+            try:
+                # Flip to the buffer the main process is not reading from.
+                self.current_cache = 1 - self.current_cache
+                self.global_step += 1
+
+                self._next_batch()
+                self.prefetch_pipe_child.recv()
+                self.prefetch_pipe_child.send(self.current_cache)
+            except KeyboardInterrupt:
+                break
+
+    def _next_batch(self):
+        # Loading data
+        self.data, self.label = self.ipc_processor()
+
+        data = np.asarray([self.data_processor(entry) for entry in self.data]) if self.data_processor else self.data
+        if self.flip_data:
+            # Random flips; assumes a 4-D (batch, channel, height, width) array.
+            flip = np.random.uniform()
+            if flip < 0.25:
+                data = data[:, :, ::-1]
+            elif flip < 0.5:
+                data = data[:, :, :, ::-1]
+            elif flip < 0.75:
+                data = data[:, :, ::-1, ::-1]
+
+        # Loading label
+        label = np.asarray([
+            self.label_processor(entry) for entry in self.label]) if self.label_processor else self.label
+
+        # Process through pipeline
+        if self.pipeline is not None:
+            for data_index, data_entry in enumerate(data):
+                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
+                self.cache_data[self.current_cache][data_index] = piped_data
+                self.cache_label[self.current_cache][data_index] = piped_label
+        else:
+            self.cache_data[self.current_cache][:len(data)] = data
+            self.cache_label[self.current_cache][:len(label)] = label
+
+    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
+        self.global_step += 1
+
+        if self.prefetch:
+            # Handshake: request a batch, receive the index of the ready buffer.
+            self.prefetch_pipe_parent.send(True)
+            self.current_cache = self.prefetch_pipe_parent.recv()
+        else:
+            self._next_batch()
+
+        self.batch_data = self.cache_data[self.current_cache]
+        self.batch_label = self.cache_label[self.current_cache]
+
+        return self.batch_data, self.batch_label
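+
+
+if __name__ == '__main__':
+    # Minimal usage sketch under assumed conditions: dummy_ipc_processor is a
+    # hypothetical stand-in for a real IPC source and returns fixed-shape
+    # batches; a fork start method (Linux default) is assumed.
+    def dummy_ipc_processor():
+        return (np.random.rand(8, 3, 32, 32).astype(np.float32),
+                np.random.randint(0, 10, size=8))
+
+    with IPCBatchGenerator(dummy_ipc_processor, flip_data=True) as generator:
+        for _ in range(4):
+            batch_data, batch_label = generator.next_batch()
+            print(generator.global_step, batch_data.shape, batch_label.shape)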