Add IPCBatchGenerator

This commit is contained in:
Corentin 2021-04-26 17:01:24 +09:00
commit dcd99a55ea

139
utils/ipc_data_generator.py Normal file
View file

@@ -0,0 +1,139 @@
import multiprocessing as mp
from multiprocessing import shared_memory
from typing import Callable, Optional, Tuple
import numpy as np
class IPCBatchGenerator:
    """Generates training batches from an IPC source, with optional
    background prefetching through shared-memory double buffers.

    Every batch is produced by calling ``ipc_processor()``, which must
    return a ``(data, label)`` pair of sequences.  Optional per-entry
    ``data_processor`` / ``label_processor`` callables and a joint
    ``pipeline(data_entry, label_entry) -> (data, label)`` callable are
    applied to each batch.

    NOTE(review): the shared buffers are sized from the very first batch,
    so every later batch is assumed to have the same dtype and the same
    (or, on the no-pipeline path, a smaller leading) shape — confirm with
    the callers of ``ipc_processor``.
    """

    def __init__(self, ipc_processor: Callable,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 pipeline: Optional[Callable] = None,
                 prefetch=True, flip_data=False):
        """Pulls a first batch to size the buffers and, when ``prefetch``
        is true, starts the background worker process.

        :param ipc_processor: callable returning a ``(data, label)`` pair.
        :param data_processor: optional per-entry transform for data.
        :param label_processor: optional per-entry transform for labels.
        :param pipeline: optional joint per-entry transform, applied after
            the individual processors.
        :param prefetch: prepare the next batch in a background process.
        :param flip_data: randomly mirror each batch along its two trailing
            axes (assumes >= 4-D image-like data — TODO confirm).
        """
        self.flip_data = flip_data
        self.pipeline = pipeline
        self.prefetch = prefetch
        self.ipc_processor = ipc_processor
        self.data_processor = data_processor
        self.label_processor = label_processor
        self.global_step = 0
        # First batch: used both as the initial output and to learn the
        # batch shape/dtype that sizes the shared-memory buffers.
        self.data, self.label = ipc_processor()
        first_data = [data_processor(entry) for entry in self.data] if data_processor else self.data
        first_label = [label_processor(entry) for entry in self.label] if label_processor else self.label
        if self.pipeline is not None:
            for data_index, sample_data in enumerate(first_data):
                first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
        first_data = np.asarray(first_data)
        first_label = np.asarray(first_label)
        self.batch_data = first_data
        self.batch_label = first_label
        self.process_id = 'NA'
        if self.prefetch:
            # Double buffering: the worker fills one buffer while the
            # consumer reads the other; `current_cache` selects which.
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
            self.cache_data = [
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
            self.cache_label = [
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
            # Handshake pipe: the parent sends a request, the worker answers
            # with the index of the buffer that now holds a ready batch.
            self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
            # One shared byte acting as the worker loop's stop flag.
            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_stop.buf[0] = 0
            self.prefetch_process = mp.Process(target=self._prefetch_worker)
            self.prefetch_process.start()
        else:
            self.cache_data = [first_data]
            self.cache_label = [first_label]
        self.current_cache = 0
        self.process_id = 'main'

    def __del__(self):
        self.release()

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.release()

    def release(self):
        """Stops the prefetch worker and frees the shared memory.

        Idempotent, and safe on a partially-constructed instance (e.g.
        when ``__init__`` raised before ``self.prefetch`` was assigned) —
        this matters because ``__del__`` calls it unconditionally.
        """
        # getattr guard: __del__ can run before __init__ set `prefetch`,
        # which previously raised AttributeError during garbage collection.
        if getattr(self, 'prefetch', False):
            self.prefetch_stop.buf[0] = 1
            # Wake the worker if it is blocked waiting for a request.
            self.prefetch_pipe_parent.send(True)
            self.prefetch_process.join()
            for shared_mem in self.cache_memory_data + self.cache_memory_label:
                shared_mem.close()
                shared_mem.unlink()
            self.prefetch_stop.close()
            self.prefetch_stop.unlink()
            self.prefetch = False  # Avoids double release

    def _prefetch_worker(self):
        """Worker-process loop: fill the idle buffer, wait for the
        consumer's request, and answer with the ready buffer's index."""
        # This copy of the object lives in the child process; disable
        # prefetch locally so release()/__del__ here never double-frees.
        self.prefetch = False
        self.current_cache = 1
        self.process_id = 'prefetch'
        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
            try:
                # Toggle to the buffer the consumer is NOT reading.
                self.current_cache = 1 - self.current_cache
                self.global_step += 1  # worker-local step counter
                self._next_batch()
                self.prefetch_pipe_child.recv()
                self.prefetch_pipe_child.send(self.current_cache)
            except KeyboardInterrupt:
                break

    def _next_batch(self):
        """Builds one batch in place into ``cache_*[current_cache]``."""
        # Loading data
        self.data, self.label = self.ipc_processor()
        data = np.asarray([self.data_processor(entry) for entry in self.data]) if self.data_processor else self.data
        if self.flip_data:
            # Whole-batch mirroring: ~25% chance each of an axis-2 flip,
            # an axis-3 flip, both, or no flip at all.
            # Assumes >= 4-D arrays — TODO confirm with callers.
            flip = np.random.uniform()
            if flip < 0.25:
                data = data[:, :, ::-1]
            elif flip < 0.5:
                data = data[:, :, :, ::-1]
            elif flip < 0.75:
                data = data[:, :, ::-1, ::-1]
        # Loading label
        label = np.asarray([
            self.label_processor(entry) for entry in self.label]) if self.label_processor else self.label
        # Process through pipeline
        if self.pipeline is not None:
            for data_index, data_entry in enumerate(data):
                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
                self.cache_data[self.current_cache][data_index] = piped_data
                self.cache_label[self.current_cache][data_index] = piped_label
        else:
            # Copy into the (possibly shared) buffer; a short final batch
            # only overwrites its first len(...) rows.
            self.cache_data[self.current_cache][:len(data)] = data
            self.cache_label[self.current_cache][:len(label)] = label

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Returns the next ``(data, label)`` batch as numpy arrays.

        In prefetch mode this is a pipe handshake with the worker; the
        returned arrays are views over shared memory and get overwritten
        two calls later (double buffering).  Without prefetch the batch is
        built synchronously in this process.
        """
        self.global_step += 1
        if self.prefetch:
            self.prefetch_pipe_parent.send(True)
            self.current_cache = self.prefetch_pipe_parent.recv()
        else:
            self._next_batch()
        self.batch_data = self.cache_data[self.current_cache]
        self.batch_label = self.cache_label[self.current_cache]
        return self.batch_data, self.batch_label