"""Batch and sequence generators backed by NumPy/h5py, with an optional
shared-memory precache worker that prepares the next batch in a child process."""
import math
import multiprocessing as mp
import os
from multiprocessing import shared_memory
from typing import Optional, Tuple

import h5py
import numpy as np


class BatchGenerator:
    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
                 precache=True, shuffle=True, preload=False, save=None, left_right_flip=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip
        # Precaching in a background process only makes sense when batches are processed lazily.
        self.precache = precache and not preload
        if not preload:
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        self.index_list = np.arange(len(self.data))
        if shuffle:
            np.random.shuffle(self.index_list)
        if self.precache:
            # Process one batch up front to learn the shape and dtype of the cache buffers.
            data_sample = np.array([data_processor(entry) if data_processor else entry
                                    for entry in self.data[:batch_size]])
            label_sample = np.array([label_processor(entry) if label_processor else entry
                                     for entry in self.label[:batch_size]])
            # Two shared-memory buffers per tensor: the consumer reads one while the
            # worker fills the other (double buffering).
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
            self.cache_data = [
                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
            self.cache_label = [
                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
            # A single shared byte acts as the stop flag for the worker process.
            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
            self.cache_stop.buf[0] = 0
            # The worker gets a copy of self; this relies on the process start method
            # making the generator's state available in the child (e.g. fork on Linux).
            self.cache_process = mp.Process(target=self.cache_worker)
            self.cache_process.start()

    def __del__(self):
        # getattr guards against a partially constructed instance (e.g. __init__ raised).
        if getattr(self, 'precache', False):
            self.cache_stop.buf[0] = 1
            self.cache_pipe_parent.send(True)  # unblock the worker if it is waiting on the pipe
            self.cache_process.join()
            for memory in (self.cache_stop, *self.cache_memory_data, *self.cache_memory_label):
                memory.close()
                memory.unlink()
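    # Worker protocol: the child process fills buffer 0 with the first batch up front.
    # On every request from next_batch() it replies with the index of the buffer that is
    # ready to read, then immediately starts preparing the following batch in the other
    # buffer, so preprocessing overlaps with consumption.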
    def cache_worker(self):
        # Inside the worker we build batches directly; only the parent reads the cache.
        self.precache = False
        self.next_batch()
        self.cache_data[0][:] = self.batch_data[:]
        self.cache_label[0][:] = self.batch_label[:]
        current_cache = 0
        while not self.cache_stop.buf[0]:
            try:
                self.cache_pipe_child.recv()
                self.cache_pipe_child.send(current_cache)
                self.next_batch()
                current_cache = 1 - current_cache
                # The last batch of an epoch can be smaller than batch_size; only its first
                # len(batch) rows are valid, the tail of the buffer keeps stale values.
                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
            except KeyboardInterrupt:
                break

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.step >= self.step_per_epoch - 1:  # steps start at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1
        # Loading data
        if self.precache:
            self.cache_pipe_parent.send(True)
            current_cache = self.cache_pipe_parent.recv()
            self.batch_data = self.cache_data[current_cache].copy()
        elif self.data_processor is not None:
            self.batch_data = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(self.data_processor(self.data[entry]))
            self.batch_data = np.asarray(self.batch_data)
        else:
            self.batch_data = self.data[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
        # Loading label
        if self.precache:
            self.batch_label = self.cache_label[current_cache].copy()
        elif self.label_processor is not None:
            self.batch_label = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(self.label_processor(self.label[entry]))
            self.batch_label = np.asarray(self.batch_label)
        else:
            self.batch_label = self.label[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
        if self.left_right_flip and np.random.uniform() > 0.5:
            # Flip the whole batch along the width axis (assumes an NHWC-like layout).
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label

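# Usage sketch (illustrative; `decode` and the path lists below are hypothetical
# placeholders, not part of this module). With a lazy `data_processor`, the precache
# worker overlaps preprocessing with training:
#
#     def decode(path):
#         return np.load(path).astype(np.float32) / 255.0
#
#     generator = BatchGenerator(image_paths, mask_paths, batch_size=16,
#                                data_processor=decode, label_processor=decode)
#     for _ in range(generator.step_per_epoch):
#         batch_x, batch_y = generator.next_batch()
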
class SequenceGenerator:
    def __init__(self, data, label, sequence_size, batch_size, data_processor=None,
                 label_processor=None, preload=False, shuffle=True, save=None):
        self.sequence_size = sequence_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        if not preload:
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    data_len = int(np.asarray(h5_file['data_len']))
                    self.data = []
                    self.label = []
                    for sequence_index in range(data_len):
                        self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
                        self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
                    self.data = np.asarray(self.data)
                    self.label = np.asarray(self.label)
            else:
                # Sequences may have different lengths, so a multi-sequence dataset is
                # stored as an object array.
                if data_processor:
                    self.data = np.asarray(
                        [np.asarray([data_processor(entry) for entry in serie]) for serie in data],
                        dtype=object if len(data) > 1 else None)
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray(
                        [np.asarray([label_processor(entry) for entry in serie]) for serie in label],
                        dtype=object if len(label) > 1 else None)
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data_len', data=len(self.data))
                        for sequence_index in range(len(self.data)):
                            h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
                            h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
        # Each row of index_list is a (sequence_index, start_index) pair identifying one window.
        self.index_list = []
        for sequence_index in range(len(self.data)):
            start_indices = np.expand_dims(
                np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32), axis=-1)
            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
            self.index_list.append(start_indices)
        self.index_list = np.concatenate(self.index_list, axis=0)
        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        if shuffle:
            np.random.shuffle(self.index_list)

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.step >= self.step_per_epoch - 1:  # steps start at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1
        # Loading data
        if self.data_processor is not None:
            self.batch_data = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(
                    [self.data_processor(input_data) for input_data in
                     self.data[sequence_index][start_index:start_index + self.sequence_size]])
            self.batch_data = np.asarray(self.batch_data)
        else:
            self.batch_data = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(
                    self.data[sequence_index][start_index:start_index + self.sequence_size])
            self.batch_data = np.asarray(self.batch_data)
        # Loading label
        if self.label_processor is not None:
            self.batch_label = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(
                    [self.label_processor(input_data) for input_data in
                     self.label[sequence_index][start_index:start_index + self.sequence_size]])
            self.batch_label = np.asarray(self.batch_label)
        else:
            self.batch_label = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(
                    self.label[sequence_index][start_index:start_index + self.sequence_size])
            self.batch_label = np.asarray(self.batch_label)
        return self.batch_data, self.batch_label
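
if __name__ == '__main__':
    # Minimal smoke test with synthetic arrays; a sketch only, the shapes and sizes
    # below are arbitrary assumptions rather than part of the API.
    images = np.random.rand(100, 8, 8, 1).astype(np.float32)
    masks = np.random.randint(0, 2, size=(100, 8, 8, 1)).astype(np.float32)
    batch_gen = BatchGenerator(images, masks, batch_size=16, precache=False)
    x, y = batch_gen.next_batch()
    print('BatchGenerator:', x.shape, y.shape)  # -> (16, 8, 8, 1) (16, 8, 8, 1)

    # One sequence of 50 timesteps, sliced into overlapping windows of length 10.
    series = [np.random.rand(50, 4).astype(np.float32)]
    targets = [np.random.rand(50, 1).astype(np.float32)]
    seq_gen = SequenceGenerator(series, targets, sequence_size=10, batch_size=8)
    x, y = seq_gen.next_batch()
    print('SequenceGenerator:', x.shape, y.shape)  # -> (8, 10, 4) (8, 10, 1)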