# torch_utils/utils/batch_generator.py

import math
import multiprocessing as mp
from multiprocessing import shared_memory
import os
from typing import Optional, Tuple
import h5py
import numpy as np
class BatchGenerator:
    """Mini-batch generator over paired (data, label) collections.

    Supports optional per-entry preprocessing callables, per-epoch shuffling,
    one-off "preload" preprocessing with an HDF5 on-disk cache, an optional
    left/right flip augmentation, and a "precache" mode in which a background
    worker process keeps the next batch precomputed in double-buffered shared
    memory so the training loop never waits on preprocessing.
    """

    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None, precache=True,
                 shuffle=True, preload=False, save=None, initial_shuffle=False, left_right_flip=False):
        """Build the generator.

        Args:
            data: indexable collection of raw samples.
            label: indexable collection of raw labels (same length as data).
            batch_size: number of samples per batch.
            data_processor: optional callable applied to each data entry.
            label_processor: optional callable applied to each label entry.
            precache: precompute batches in a worker process (ignored if preload).
            shuffle: reshuffle the index order at the start of every epoch.
            preload: apply the processors to ALL entries now, once, up front.
            save: optional path for an HDF5 cache of the preloaded arrays
                ('.hdf5' appended when the basename has no extension).
            initial_shuffle: shuffle once at construction even if shuffle=False.
            left_right_flip: randomly flip batches along axis 2 (assumes
                (batch, H, W, ...) image layout -- TODO confirm with callers).
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip
        # Precaching buys nothing once everything is preprocessed up front.
        self.precache = precache and not preload
        if not preload:
            # Lazy mode: processors run per-batch inside next_batch().
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Preload mode: process everything now; batches are plain slices later.
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                # BUGFIX: dirname is '' for a bare filename and os.makedirs('')
                # raises FileNotFoundError; exist_ok=True also removes the
                # check-then-create race of the old exists()/makedirs() pair.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
            if save and os.path.exists(save_path):
                # Reuse the previously saved preprocessed arrays.
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1  # next_batch() pre-increments, so first served step is 0
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        self.index_list = np.arange(len(self.data))
        if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
        if self.precache:
            # Size the shared buffers from one fully-processed batch; all
            # batches are assumed to share this shape/dtype except a possibly
            # shorter final batch (handled by the worker's prefix write).
            data_sample = np.array([data_processor(entry) if data_processor else entry
                                    for entry in self.data[:batch_size]])
            label_sample = np.array([label_processor(entry) if label_processor else entry
                                     for entry in self.label[:batch_size]])
            # Two buffers each for data/label: the worker fills one while the
            # parent reads the other (classic double buffering).
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
            self.cache_data = [
                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
            self.cache_label = [
                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
            # One-byte shared flags: stop terminates the worker loop, skip asks
            # it to jump to the next epoch before producing the next batch.
            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
            self.cache_stop.buf[0] = 0
            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
            self.cache_skip.buf[0] = 0
            self.cache_process = mp.Process(target=self.cache_worker)
            self.cache_process.start()

    def __del__(self):
        """Stop the worker and release shared-memory segments (precache only)."""
        # BUGFIX: __del__ can run after a partially-failed __init__ (or at
        # interpreter shutdown), so never assume attributes exist; cleanup
        # failures here are deliberately swallowed -- there is no caller to
        # recover, and raising from __del__ only produces noise.
        if not getattr(self, 'precache', False):
            return
        try:
            self.cache_stop.buf[0] = 1
            # Unblock the worker's recv(); it notices the stop flag and exits.
            self.cache_pipe_parent.send(True)
            self.cache_process.join()
            for segment in (self.cache_stop, self.cache_skip,
                            self.cache_memory_data[0], self.cache_memory_data[1],
                            self.cache_memory_label[0], self.cache_memory_label[1]):
                segment.close()
                segment.unlink()
        except Exception:
            pass

    def cache_worker(self):
        """Background-process body: keep one batch ready in shared memory.

        Runs in the child process, which has its own copy of self; setting
        self.precache = False here makes next_batch() compute batches directly
        instead of re-entering the cache protocol.
        """
        self.precache = False
        # Prime buffer 0 before serving the first request.
        self.next_batch()
        self.cache_data[0][:] = self.batch_data[:]
        self.cache_label[0][:] = self.batch_label[:]
        current_cache = 0
        while not self.cache_stop.buf[0]:
            try:
                self.cache_pipe_child.recv()                # wait for a request
                self.cache_pipe_child.send(current_cache)   # hand over the ready buffer
                if self.cache_skip.buf[0]:
                    # Parent asked to skip: force epoch rollover on next_batch().
                    # NOTE(review): the buffer just handed over was produced
                    # before the skip -- the parent still receives one pre-skip
                    # batch; preserved as existing behavior.
                    self.cache_skip.buf[0] = 0
                    self.step = self.step_per_epoch
                self.next_batch()
                current_cache = 1 - current_cache
                # The final batch of an epoch can be short; only its prefix of
                # the fixed-size buffer is valid (tail keeps stale values).
                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
            except KeyboardInterrupt:
                break

    def skip_epoch(self):
        """Abandon the rest of the current epoch and serve the next one."""
        if self.precache:
            self.cache_skip.buf[0] = 1
        self.step = self.step_per_epoch
        self.next_batch()

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the next (batch_data, batch_label) pair.

        Advances step/global_step counters, rolls the epoch over (reshuffling
        if enabled) when the previous call served the last step, and applies
        the optional random left/right flip to both arrays together.
        """
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1
        # Loading data
        if self.precache:
            self.cache_pipe_parent.send(True)
            current_cache = self.cache_pipe_parent.recv()
            # Copy out so the worker may overwrite the shared buffer.
            # NOTE(review): the full fixed-size buffer is copied, so a short
            # final batch carries stale trailing rows -- preserved as existing
            # behavior; confirm downstream tolerates it.
            self.batch_data = self.cache_data[current_cache].copy()
        elif self.data_processor is not None:
            self.batch_data = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(self.data_processor(self.data[entry]))
            self.batch_data = np.asarray(self.batch_data)
        else:
            self.batch_data = self.data[
                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
        # Loading label
        if self.precache:
            self.batch_label = self.cache_label[current_cache].copy()
        elif self.label_processor is not None:
            self.batch_label = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(self.label_processor(self.label[entry]))
            self.batch_label = np.asarray(self.batch_label)
        else:
            self.batch_label = self.label[
                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
        if self.left_right_flip and np.random.uniform() > 0.5:
            # Flip data and label together along axis 2 (presumed width axis).
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label