Add IPCBatchGenerator
This commit is contained in:
parent
89240e417f
commit
dcd99a55ea
1 changed files with 139 additions and 0 deletions
139
utils/ipc_data_generator.py
Normal file
139
utils/ipc_data_generator.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import multiprocessing as mp
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class IPCBatchGenerator:
    """Double-buffered batch generator fed by a user-supplied IPC callable.

    ``ipc_processor()`` is called once per batch and must return a
    ``(data, label)`` pair of equal-length sequences (they are unpacked and
    iterated below).  Optional per-entry ``data_processor`` /
    ``label_processor`` callables and a per-sample ``pipeline(data, label)``
    callable transform each batch.  With ``prefetch=True`` a child process
    fills one of two shared-memory buffers while the consumer reads the
    other; with ``prefetch=False`` batches are produced synchronously in
    ``next_batch``.

    NOTE(review): every batch is assumed to have the same shape/dtype as the
    first one fetched in ``__init__`` — the shared-memory buffers are sized
    from that first batch and never grow.  Confirm against callers.
    """

    def __init__(self, ipc_processor: Callable,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 pipeline: Optional[Callable] = None,
                 prefetch=True, flip_data=False):
        """Fetch one batch eagerly to size the buffers, then optionally
        start the prefetch worker process.

        Args:
            ipc_processor: zero-argument callable returning ``(data, label)``.
            data_processor: optional per-entry transform applied to ``data``.
            label_processor: optional per-entry transform applied to ``label``.
            pipeline: optional ``(data_entry, label_entry) -> (data, label)``
                joint transform applied per sample.
            prefetch: when True, batches are produced asynchronously in a
                child process using two shared-memory buffers.
            flip_data: when True, randomly mirror the batch along the last
                two axes in ``_next_batch`` (assumes rank >= 4 data — TODO
                confirm expected layout).
        """
        self.flip_data = flip_data
        self.pipeline = pipeline
        self.prefetch = prefetch

        self.ipc_processor = ipc_processor
        self.data_processor = data_processor
        self.label_processor = label_processor

        # Counts calls to next_batch (and loop iterations in the worker).
        self.global_step = 0

        # Eagerly pull the first batch: it doubles as the template that
        # fixes the shape/dtype of the shared-memory caches below.
        self.data, self.label = ipc_processor()
        first_data = [data_processor(entry) for entry in self.data] if data_processor else self.data
        first_label = [label_processor(entry) for entry in self.label] if label_processor else self.label
        if self.pipeline is not None:
            for data_index, sample_data in enumerate(first_data):
                first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
        first_data = np.asarray(first_data)
        first_label = np.asarray(first_label)
        self.batch_data = first_data
        self.batch_label = first_label

        self.process_id = 'NA'
        if self.prefetch:
            # Two shared-memory segments per stream (data, label) form a
            # double buffer: the worker writes one while the consumer reads
            # the other.
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
            self.cache_data = [
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
            self.cache_label = [
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
            # Pipe carries the per-batch handshake; the 1-byte shared
            # segment is the worker's stop flag (0 = run, 1 = stop).
            self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_stop.buf[0] = 0
            self.prefetch_process = mp.Process(target=self._prefetch_worker)
            self.prefetch_process.start()
        else:
            # Synchronous mode: a single in-process buffer, filled in place
            # by _next_batch.
            self.cache_data = [first_data]
            self.cache_label = [first_label]

        self.current_cache = 0
        self.process_id = 'main'

    def __del__(self):
        # release() is safe on partially-constructed instances (see below),
        # so GC of a failed __init__ does not raise.
        self.release()

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.release()

    def release(self):
        """Stop the prefetch worker and free all shared-memory segments.

        Idempotent, and safe to call on an instance whose __init__ raised
        partway: ``self.prefetch`` is set before the IPC resources are
        created, so every attribute is checked before use instead of
        assumed (previously this raised AttributeError from __del__ and
        could leak SharedMemory segments).
        """
        if not getattr(self, "prefetch", False):
            return
        if hasattr(self, "prefetch_process"):
            # Raise the stop flag, then unblock the worker's recv() so it
            # re-checks the flag and exits; join before unlinking memory.
            self.prefetch_stop.buf[0] = 1
            self.prefetch_pipe_parent.send(True)
            self.prefetch_process.join()

            for shared_mem in self.cache_memory_data + self.cache_memory_label:
                shared_mem.close()
                shared_mem.unlink()
            self.prefetch_stop.close()
            self.prefetch_stop.unlink()
        self.prefetch = False  # Avoids double release

    def _prefetch_worker(self):
        """Child-process loop: fill the idle cache, then wait for a request.

        Runs in a forked copy of ``self``; ``prefetch`` is cleared locally
        so the child's own release() is a no-op.
        """
        self.prefetch = False
        self.current_cache = 1
        self.process_id = 'prefetch'

        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
            try:
                # Flip to the buffer the consumer is NOT reading, fill it,
                # then block until next_batch asks for it.
                self.current_cache = 1 - self.current_cache
                self.global_step += 1

                self._next_batch()
                self.prefetch_pipe_child.recv()
                self.prefetch_pipe_child.send(self.current_cache)
            except KeyboardInterrupt:
                break

    def _next_batch(self):
        """Produce one batch into ``cache_*[self.current_cache]`` in place."""
        # Loading data
        self.data, self.label = self.ipc_processor()

        data = np.asarray([self.data_processor(entry) for entry in self.data]) if self.data_processor else self.data
        if self.flip_data:
            # Random mirror augmentation: 25% each of flip axis 2, flip
            # axis 3, flip both, or no flip.
            flip = np.random.uniform()
            if flip < 0.25:
                data = data[:, :, ::-1]
            elif flip < 0.5:
                data = data[:, :, :, ::-1]
            elif flip < 0.75:
                data = data[:, :, ::-1, ::-1]

        # Loading label
        label = np.asarray([
            self.label_processor(entry) for entry in self.label]) if self.label_processor else self.label

        # Process through pipeline
        if self.pipeline is not None:
            for data_index, data_entry in enumerate(data):
                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
                self.cache_data[self.current_cache][data_index] = piped_data
                self.cache_label[self.current_cache][data_index] = piped_label
        else:
            # Partial batches only overwrite their own prefix; the tail
            # keeps stale values from the previous batch.
            self.cache_data[self.current_cache][:len(data)] = data
            self.cache_label[self.current_cache][:len(label)] = label

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the next ``(data, label)`` batch.

        In prefetch mode this hands back the buffer the worker just filled
        (and lets it start on the other one); otherwise the batch is built
        synchronously here.  The returned arrays are views into the active
        cache and are overwritten by later calls.
        """
        self.global_step += 1

        if self.prefetch:
            self.prefetch_pipe_parent.send(True)
            self.current_cache = self.prefetch_pipe_parent.recv()
        else:
            self._next_batch()

        self.batch_data = self.cache_data[self.current_cache]
        self.batch_label = self.cache_label[self.current_cache]

        return self.batch_data, self.batch_label
|
||||
Loading…
Add table
Add a link
Reference in a new issue