import math
import multiprocessing as mp
from multiprocessing import shared_memory
import os
from typing import Optional, Tuple

import h5py
import numpy as np


class BatchGenerator:
    """Batch iterator over paired data/label arrays with optional per-entry
    processors, HDF5 preload caching, and a double-buffered prefetch worker
    that prepares the next batch in shared memory while the current one is
    being consumed.

    Note: the prefetch worker inherits the pipe and shared-memory handles,
    so the precache path assumes a fork-based multiprocessing start method
    (the default on Linux).
    """

    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
                 precache=True, shuffle=True, preload=False, save=None, initial_shuffle=False,
                 left_right_flip=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip
        # Precaching is pointless when everything is preprocessed up front.
        self.precache = precache and not preload
        if not preload:
            # Lazy mode: keep raw entries and apply the processors per batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Preload mode: process everything once, optionally cached in HDF5.
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                # Guard against an empty dirname (bare filename), which would
                # make os.makedirs('') raise.
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        self.index_list = np.arange(len(self.data))
        if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
        if self.precache:
            # Probe one batch to size the two shared-memory buffers.
            data_sample = np.array([data_processor(entry) if data_processor else entry
                                    for entry in self.data[:batch_size]])
            label_sample = np.array([label_processor(entry) if label_processor else entry
                                     for entry in self.label[:batch_size]])
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
            self.cache_data = [
                np.ndarray(data_sample.shape, dtype=data_sample.dtype,
                           buffer=self.cache_memory_data[0].buf),
                np.ndarray(data_sample.shape, dtype=data_sample.dtype,
                           buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
            self.cache_label = [
                np.ndarray(label_sample.shape, dtype=label_sample.dtype,
                           buffer=self.cache_memory_label[0].buf),
                np.ndarray(label_sample.shape, dtype=label_sample.dtype,
                           buffer=self.cache_memory_label[1].buf)]
            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
            # One-byte shared flags: stop terminates the worker loop,
            # skip tells it to jump to the next epoch.
            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
            self.cache_stop.buf[0] = 0
            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
            self.cache_skip.buf[0] = 0
            self.cache_process = mp.Process(target=self.cache_worker)
            self.cache_process.start()

    def __del__(self):
        # getattr guards against a partially constructed instance; the worker
        # sets self.precache = False in its own copy, so it never unlinks.
        if getattr(self, 'precache', False):
            self.cache_stop.buf[0] = 1
            self.cache_pipe_parent.send(True)  # unblock the worker's recv()
            self.cache_process.join()
            self.cache_stop.close()
            self.cache_stop.unlink()
            self.cache_skip.close()
            self.cache_skip.unlink()
            self.cache_memory_data[0].close()
            self.cache_memory_data[0].unlink()
            self.cache_memory_data[1].close()
            self.cache_memory_data[1].unlink()
            self.cache_memory_label[0].close()
            self.cache_memory_label[0].unlink()
            self.cache_memory_label[1].close()
            self.cache_memory_label[1].unlink()

    def cache_worker(self):
        """Worker process: keeps the next batch ready in one of two shared
        buffers while the consumer copies out of the other."""
        self.precache = False  # inside the worker, load batches directly
        self.next_batch()
        self.cache_data[0][:len(self.batch_data)] = self.batch_data[:]
        self.cache_label[0][:len(self.batch_label)] = self.batch_label[:]
        # Track how many rows of each buffer are valid: the last batch of an
        # epoch can be shorter than batch_size, and the consumer must not read
        # stale rows past it.
        lengths = [len(self.batch_data), 0]
        current_cache = 0
        while not self.cache_stop.buf[0]:
            try:
                self.cache_pipe_child.recv()  # wait for a batch request
                # Hand over the ready buffer index and its valid row count.
                self.cache_pipe_child.send((current_cache, lengths[current_cache]))
                if self.cache_skip.buf[0]:
                    self.cache_skip.buf[0] = 0
                    self.step = self.step_per_epoch  # force epoch rollover
                # Prefetch the next batch into the other buffer while the
                # consumer copies the one just handed over.
                self.next_batch()
                current_cache = 1 - current_cache
                lengths[current_cache] = len(self.batch_data)
                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
            except KeyboardInterrupt:
                break

    def skip_epoch(self):
        # With precache, the batch returned here was already prefetched from
        # the old epoch; the worker skips ahead for the batch after it.
        if self.precache:
            self.cache_skip.buf[0] = 1
        self.step = self.step_per_epoch
        self.next_batch()

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1
        batch_slice = slice(self.step * self.batch_size, (self.step + 1) * self.batch_size)
        # Loading data
        if self.precache:
            self.cache_pipe_parent.send(True)
            # The worker replies with the ready buffer index and how many rows
            # of it are valid (the last batch of an epoch may be short).
            current_cache, batch_len = self.cache_pipe_parent.recv()
            self.batch_data = self.cache_data[current_cache][:batch_len].copy()
        elif self.data_processor is not None:
            self.batch_data = np.asarray([self.data_processor(self.data[entry])
                                          for entry in self.index_list[batch_slice]])
        else:
            self.batch_data = self.data[self.index_list[batch_slice]]
        # Loading label
        if self.precache:
            self.batch_label = self.cache_label[current_cache][:batch_len].copy()
        elif self.label_processor is not None:
            self.batch_label = np.asarray([self.label_processor(self.label[entry])
                                           for entry in self.index_list[batch_slice]])
        else:
            self.batch_label = self.label[self.index_list[batch_slice]]
        # print('next_batch : epoch {}, step {}/{}, data {}, label {}'.format(
        #     self.epoch, self.step, self.step_per_epoch - 1, self.batch_data.shape, self.batch_label.shape))
        if self.left_right_flip and np.random.uniform() > 0.5:
            # Assumes batch-first image-like arrays; axis 2 is the horizontal axis.
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label
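

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# array shapes, the lambda processor, and the 'cache/demo' save path below are
# hypothetical, chosen only to show the intended call pattern. The __main__
# guard matters because BatchGenerator can start a multiprocessing worker.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    images = rng.random((100, 32, 32, 3)).astype(np.float32)          # hypothetical NHWC data
    masks = (rng.random((100, 32, 32, 1)) > 0.5).astype(np.float32)   # hypothetical masks

    # Plain in-memory iteration, no worker process.
    gen = BatchGenerator(images, masks, batch_size=16, precache=False, shuffle=True)
    for _ in range(gen.step_per_epoch):
        batch_x, batch_y = gen.next_batch()
        print(f'epoch {gen.epoch} step {gen.step}: {batch_x.shape} {batch_y.shape}')

    # Preprocess once up front and cache the result to HDF5; a second run with
    # the same `save` path loads the stored arrays instead of reprocessing.
    pre = BatchGenerator(images, masks, batch_size=16, preload=True,
                         data_processor=lambda x: x * 2.0,  # hypothetical processor
                         save='cache/demo', shuffle=False)
    print('preloaded data shape:', pre.data.shape)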