# torch_utils/utils/batch_generator.py
import multiprocessing as mp
import os
from multiprocessing import shared_memory
from typing import Callable, Iterable, Optional, Tuple

import h5py
import numpy as np
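

# Design overview (inferred from the implementation below):
# - next_batch() serves batches from two shared-memory buffers ("double
#   buffering"): while the caller consumes one buffer, a background process
#   fills the other.
# - With num_workers > 1, each batch is additionally split into contiguous
#   slices that persistent worker processes fill in parallel.
# - Child processes rely on the attribute state captured when they are forked,
#   so the module appears to assume a fork-based multiprocessing start method
#   (the Linux default); under "spawn" the numpy views backed by shared memory
#   would not survive pickling.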
class BatchGenerator:
    def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 prefetch: bool = True, preload: bool = False, num_workers: int = 1,
                 shuffle: bool = True, initial_shuffle: bool = False,
                 left_right_flip: bool = False, save: Optional[str] = None):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Prefetching only pays off while entries still need processing, so it
        # is disabled in preload mode.
        self.prefetch = prefetch and not preload
        self.num_workers = num_workers
        self.left_right_flip = left_right_flip
        if not preload:
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Preload mode: process everything once now (or load it back from
            # an HDF5 cache) and serve raw batches afterwards.
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    # Create the cache directory if the path contains one.
                    os.makedirs(save_dir, exist_ok=True)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.index_list = np.arange(len(self.data))
        if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
        # Ceil division: a trailing partial batch is served as a short batch
        # (its size is last_batch_size) instead of being dropped.
        self.step_per_epoch = int(np.ceil(len(self.index_list) / self.batch_size))
        self.last_batch_size = len(self.index_list) % self.batch_size
        if self.last_batch_size == 0:
            self.last_batch_size = self.batch_size
        self.epoch = 0
        self.global_step = 0
        self.step = 0
        # Build the first batch eagerly; its shape and dtype size the caches.
        # self.data_processor / self.label_processor are used here (they are
        # None in preload mode) so preprocessed entries are not processed twice.
        first_data = np.array([self.data_processor(entry) if self.data_processor else entry
                               for entry in self.data[self.index_list[:batch_size]]])
        first_label = np.array([self.label_processor(entry) if self.label_processor else entry
                                for entry in self.label[self.index_list[:batch_size]]])
        self.batch_data = first_data
        self.batch_label = first_label
        self.main_process = False
        if self.prefetch or self.num_workers > 1:
            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
            self.cache_indices = np.ndarray(
                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
            self.cache_indices[:] = self.index_list
            # Two buffers per array for double buffering: one is consumed while
            # the other is being filled.
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
            self.cache_data = [
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
            self.cache_label = [
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
        else:
            self.cache_memory_indices = None
            self.cache_data = [first_data]
            self.cache_label = [first_label]
        if self.prefetch:
            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_stop.buf[0] = 0
            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_skip.buf[0] = 0
            self.cache_process = mp.Process(target=self._cache_worker)
            self.cache_process.start()
            # The cache process spawns the workers itself (it forked with the
            # original num_workers), so the main process must not.
            self.num_workers = 0
        self._init_workers()
        self.current_cache = 0
        self.main_process = True
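
    # Workers are spawned by whichever process drives batch production: the
    # main process when prefetch is off, otherwise the cache process (see
    # _cache_worker).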
    def _init_workers(self):
        if self.num_workers > 1:
            self.worker_stop = shared_memory.SharedMemory(create=True, size=1)
            self.worker_stop.buf[0] = 0
            self.worker_pipes = []
            self.worker_processes = []
            for _ in range(self.num_workers):
                self.worker_pipes.append(mp.Pipe())
            for worker_index in range(self.num_workers):
                self.worker_processes.append(mp.Process(target=self._worker, args=(worker_index,)))
                self.worker_processes[-1].start()
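
    # The shared-memory blocks must be explicitly unlinked, so prefer using
    # the generator as a context manager (or calling release()) over relying
    # on __del__ alone.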
    def __del__(self):
        self.release()

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.release()

    def release(self):
        if self.num_workers > 1:
            self.worker_stop.buf[0] = 1
            for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
                pipe.send([-1, 0, 0, 0])  # batch_size 0: worker ignores the message
                worker.join()
            self.worker_stop.close()
            self.worker_stop.unlink()
            self.num_workers = 0  # Avoids double release
        if self.prefetch:
            self.prefetch_stop.buf[0] = 1
            self.cache_pipe_parent.send(True)  # Unblocks the cache process
            self.cache_process.join()
            self.prefetch_stop.close()
            self.prefetch_stop.unlink()
            self.prefetch_skip.close()
            self.prefetch_skip.unlink()
            self.prefetch = False  # Avoids double release
        if self.main_process and self.cache_memory_indices is not None:
            # Only the main process unlinks the shared-memory blocks.
            self.cache_memory_indices.close()
            self.cache_memory_indices.unlink()
            for memory in (*self.cache_memory_data, *self.cache_memory_label):
                memory.close()
                memory.unlink()
            self.cache_memory_indices = None  # Avoids double release
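
    # Prefetch protocol (one iteration of the loop below): flip to the idle
    # buffer, fill it with the next batch, block until the main process
    # requests a batch over the pipe, then reply with the index of the buffer
    # that was just filled. The producer therefore always runs one prepared
    # batch ahead of the consumer.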
    def _cache_worker(self):
        self.prefetch = False  # This child copy must not release the prefetcher
        self.current_cache = 1
        self._init_workers()
        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
            try:
                self.current_cache = 1 - self.current_cache
                if self.prefetch_skip.buf[0]:
                    self.prefetch_skip.buf[0] = 0
                    self.step = self.step_per_epoch
                    self.index_list[:] = self.cache_indices
                if self.step >= self.step_per_epoch - 1:  # steps start at 0
                    self.step = 0
                    self.epoch += 1
                    # Pick up the indices the main process re-shuffled for the
                    # new epoch; without this a single-worker prefetcher would
                    # keep the stale order forever.
                    self.index_list[:] = self.cache_indices
                else:
                    self.step += 1
                self.global_step += 1
                self._next_batch()
                self.cache_pipe_child.recv()
                self.cache_pipe_child.send(self.current_cache)
            except KeyboardInterrupt:
                break
        if self.num_workers > 1:
            self.release()
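
    # Each worker fills private arrays first and only then copies its finished
    # slice into the parent's shared cache; self.step stays 0 in a worker, so
    # _next_batch always reads its indices from offset 0.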
    def _worker(self, worker_index: int):
        self.num_workers = 0
        self.current_cache = 0
        parent_cache_data = self.cache_data
        parent_cache_label = self.cache_label
        cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
        cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
        self.cache_data = [cache_data]
        self.cache_label = [cache_label]
        self.index_list[:] = self.cache_indices
        pipe = self.worker_pipes[worker_index][1]
        while self.worker_stop.buf is not None and self.worker_stop.buf[0] == 0:
            try:
                current_cache, batch_index, start_index, self.batch_size = pipe.recv()
                if self.batch_size == 0:
                    continue  # Shutdown ping sent by release()
                self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
                self.cache_data[0] = cache_data[:self.batch_size]
                self.cache_label[0] = cache_label[:self.batch_size]
                self._next_batch()
                parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
                    self.current_cache][:self.batch_size]
                parent_cache_label[current_cache][batch_index:batch_index + self.batch_size] = self.cache_label[
                    self.current_cache][:self.batch_size]
                pipe.send(True)
            except KeyboardInterrupt:
                break
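
    # Skips the rest of the current epoch: in prefetch mode the flag tells the
    # cache process to fast-forward; the local counters are advanced by the
    # next_batch() call.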
    def skip_epoch(self):
        if self.prefetch:
            self.prefetch_skip.buf[0] = 1
        self.step = self.step_per_epoch
        self.next_batch()
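
    # Splits the current batch into contiguous per-worker slices; the last
    # worker takes the remainder (e.g. a batch of 5 over 3 workers is split
    # 1 / 1 / 3). With fewer indices than workers, only the first batch_len
    # workers receive work.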
    def _worker_next_batch(self):
        index_list = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]
        batch_len = len(index_list)
        indices_per_worker = batch_len // self.num_workers
        if indices_per_worker == 0:
            indices_per_worker = 1
        worker_params = []
        batch_index = 0
        start_index = self.step * self.batch_size
        for _worker_index in range(self.num_workers - 1):
            worker_params.append([self.current_cache, batch_index, start_index, indices_per_worker])
            batch_index += indices_per_worker
            start_index += indices_per_worker
        worker_params.append([
            self.current_cache, batch_index, start_index,
            batch_len - ((self.num_workers - 1) * indices_per_worker)])
        if indices_per_worker == 1 and batch_len < self.num_workers:
            worker_params = worker_params[:batch_len]
        for params, (pipe, _) in zip(worker_params, self.worker_pipes):
            pipe.send(params)
        for _, (pipe, _) in zip(worker_params, self.worker_pipes):
            pipe.recv()  # Wait for every dispatched worker to finish
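
    # Fills the active cache buffer with the batch selected by self.step,
    # applying the per-entry processors when they are set.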
    def _next_batch(self):
        if self.num_workers > 1:
            self._worker_next_batch()
            return
        batch_indices = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]
        # Loading data
        if self.data_processor is not None:
            data = [self.data_processor(self.data[entry]) for entry in batch_indices]
            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
        else:
            data = self.data[batch_indices]
            self.cache_data[self.current_cache][:len(data)] = data
        # Loading label
        if self.label_processor is not None:
            label = [self.label_processor(self.label[entry]) for entry in batch_indices]
            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
        else:
            label = self.label[batch_indices]
            self.cache_label[self.current_cache][:len(label)] = label
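
    # Consumer side of the double buffering: advance the counters, publish the
    # (possibly re-shuffled) indices to the producer, then expose whichever
    # buffer was prepared in the background.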
    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.step >= self.step_per_epoch - 1:  # steps start at 0
            self.step = 0
            self.epoch += 1
        else:
            self.step += 1
        self.global_step += 1
        if self.prefetch:
            # Shuffles one step ahead in prefetch mode because the next batch
            # is already being prepared.
            if self.step == self.step_per_epoch - 1 and self.shuffle:
                np.random.shuffle(self.index_list)
                self.cache_indices[:] = self.index_list
            self.cache_pipe_parent.send(True)
            self.current_cache = self.cache_pipe_parent.recv()
        else:
            if self.step == 0 and self.shuffle:
                np.random.shuffle(self.index_list)
                if self.num_workers > 1:
                    self.cache_indices[:] = self.index_list
            self._next_batch()
        if self.step == self.step_per_epoch - 1:
            # Trailing (possibly partial) batch: expose only the valid rows.
            self.batch_data = self.cache_data[self.current_cache][:self.last_batch_size]
            self.batch_label = self.cache_label[self.current_cache][:self.last_batch_size]
        else:
            self.batch_data = self.cache_data[self.current_cache]
            self.batch_label = self.cache_label[self.current_cache]
        if self.left_right_flip and np.random.uniform() > 0.5:
            # Assumes image-like batches of shape (N, H, W, ...): flips widths
            # of data and label together.
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label
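

# Smoke test: exercises every combination of processor / prefetch / workers.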
if __name__ == '__main__':
    def test():
        data = np.array(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=np.uint8)
        # float32 labels: a uint8 dtype would truncate the fractions to 0.
        label = np.array(
            [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18],
            dtype=np.float32)
        for data_processor in [None, lambda x: x]:
            for prefetch in [False, True]:
                for num_workers in [1, 3]:
                    print(f'{data_processor=} {prefetch=} {num_workers=}')
                    with BatchGenerator(data, label, 5, data_processor=data_processor,
                                        prefetch=prefetch, num_workers=num_workers) as batch_generator:
                        for _ in range(19):
                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
                            batch_generator.next_batch()
                    print()

    test()
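
    # A minimal extra sketch (not part of the original tests) of the
    # left_right_flip option. It assumes image-like batches of shape
    # (N, H, W, C): data and label are flipped together along the width axis,
    # as one would want for segmentation-style labels. The flip is applied at
    # random, so the printed edge column may or may not be mirrored.
    def flip_example():
        image_data = np.zeros((8, 4, 6, 3), dtype=np.float32)
        image_data[:, :, 0] = 1.0  # Mark the left edge so a flip is visible
        image_label = image_data.copy()
        with BatchGenerator(image_data, image_label, 4, prefetch=False, shuffle=False,
                            left_right_flip=True) as generator:
            batch_data, _ = generator.next_batch()
            print('edge column (possibly mirrored):', batch_data[0, 0, :, 0])

    flip_example()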