import math
import os
from typing import Optional, Tuple

import h5py
import numpy as np


class BatchGenerator:
    """Yields shuffled (data, label) batches, with optional per-entry preprocessing,
    HDF5 caching of the preprocessed arrays and random left/right flipping."""

    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None, shuffle=True,
                 preload=False, save=None, left_right_flip=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip
        if not preload:
            # Lazy mode: keep the raw entries and apply the processors batch by batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Preload mode: process everything once up front (optionally cached to HDF5).
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                # Reuse the cached, already-processed arrays.
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        self.index_list = np.arange(len(self.data))
        if shuffle:
            np.random.shuffle(self.index_list)

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the next (batch_data, batch_label) pair, reshuffling at each new epoch."""
        if self.step >= self.step_per_epoch - 1:  # steps start at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1

        # Loading data
        if self.data_processor is not None:
            self.batch_data = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(self.data_processor(self.data[entry]))
            self.batch_data = np.asarray(self.batch_data)
        else:
            self.batch_data = self.data[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]

        # Loading label
        if self.label_processor is not None:
            self.batch_label = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(self.label_processor(self.label[entry]))
            self.batch_label = np.asarray(self.batch_label)
        else:
            self.batch_label = self.label[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]

        # print('next_batch : epoch {}, step {}/{}, data {}, label {}'.format(
        #     self.epoch, self.step, self.step_per_epoch - 1, self.batch_data.shape, self.batch_label.shape))

        if self.left_right_flip and np.random.uniform() > 0.5:
            # Flip the whole batch along the width axis (assumes an NHWC-like layout).
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]

        return self.batch_data, self.batch_label
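

# ---------------------------------------------------------------------------
# Usage sketch for BatchGenerator (illustration only, not part of the module's
# API): the demo function name, the array shapes and the synthetic random
# tensors below are assumptions chosen purely to show the call pattern.
# ---------------------------------------------------------------------------
def _demo_batch_generator():
    """Iterate one epoch of batches over synthetic NHWC data with random flips."""
    data = np.random.rand(10, 8, 8, 3).astype(np.float32)  # 10 fake 8x8 RGB images (assumed shapes)
    label = np.random.randint(0, 2, size=(10, 8, 8, 1))    # matching per-pixel labels

    generator = BatchGenerator(data, label, batch_size=4, shuffle=True, left_right_flip=True)
    for _ in range(generator.step_per_epoch):
        batch_data, batch_label = generator.next_batch()
        print('epoch', generator.epoch, 'step', generator.step,
              'data', batch_data.shape, 'label', batch_label.shape)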


class SequenceGenerator:
    """Yields batches of contiguous sub-sequences (sliding windows of `sequence_size`
    consecutive entries) drawn from a collection of variable-length series."""

    def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
                 preload=False, shuffle=True, save=None):
        self.sequence_size = sequence_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        if not preload:
            # Lazy mode: keep the raw series and apply the processors batch by batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Preload mode: process every series once up front (optionally cached to HDF5).
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    data_len = int(np.asarray(h5_file['data_len']))
                    self.data = []
                    self.label = []
                    for sequence_index in range(data_len):
                        self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
                        self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
                    # Object dtype keeps series of different lengths as a ragged array,
                    # mirroring the non-cached branch below.
                    self.data = np.asarray(self.data, dtype=object if data_len > 1 else None)
                    self.label = np.asarray(self.label, dtype=object if data_len > 1 else None)
            else:
                if data_processor:
                    self.data = np.asarray(
                        [np.asarray([data_processor(entry) for entry in serie]) for serie in data],
                        dtype=object if len(data) > 1 else None)
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray(
                        [np.asarray([label_processor(entry) for entry in serie]) for serie in label],
                        dtype=object if len(label) > 1 else None)
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data_len', data=len(self.data))
                        for sequence_index in range(len(self.data)):
                            h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
                            h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])

        # Every (series index, window start) pair addresses one training sub-sequence.
        self.index_list = []
        for sequence_index in range(len(self.data)):
            start_indices = np.expand_dims(
                np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32), axis=-1)
            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
            self.index_list.append(start_indices)
        self.index_list = np.concatenate(self.index_list, axis=0)

        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
        self.epoch = 0
        self.global_step = -1
        self.step = -1
        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None
        if shuffle:
            np.random.shuffle(self.index_list)

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the next (batch_data, batch_label) pair of sub-sequences,
        reshuffling the window index at each new epoch."""
        if self.step >= self.step_per_epoch - 1:  # steps start at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1
        self.global_step += 1

        # Loading data
        if self.data_processor is not None:
            self.batch_data = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(
                    [self.data_processor(input_data) for input_data in
                     self.data[sequence_index][start_index:start_index + self.sequence_size]])
            self.batch_data = np.asarray(self.batch_data)
        else:
            self.batch_data = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_data.append(
                    self.data[sequence_index][start_index:start_index + self.sequence_size])
            self.batch_data = np.asarray(self.batch_data)

        # Loading label
        if self.label_processor is not None:
            self.batch_label = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(
                    [self.label_processor(input_data) for input_data in
                     self.label[sequence_index][start_index:start_index + self.sequence_size]])
            self.batch_label = np.asarray(self.batch_label)
        else:
            self.batch_label = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                self.batch_label.append(
                    self.label[sequence_index][start_index:start_index + self.sequence_size])
            self.batch_label = np.asarray(self.batch_label)

        return self.batch_data, self.batch_label
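

# ---------------------------------------------------------------------------
# Usage sketch for SequenceGenerator (illustration only): the series lengths,
# window size and synthetic tensors below are assumptions made for the demo;
# each batch item is a window of `sequence_size` consecutive frames.
# ---------------------------------------------------------------------------
def _demo_sequence_generator():
    """Iterate one epoch of sliding-window batches over two synthetic series."""
    data = [np.random.rand(12, 8, 8, 3).astype(np.float32),  # series of 12 frames (assumed shapes)
            np.random.rand(9, 8, 8, 3).astype(np.float32)]   # series of 9 frames
    label = [np.random.randint(0, 2, size=(12, 1)),
             np.random.randint(0, 2, size=(9, 1))]

    generator = SequenceGenerator(data, label, sequence_size=4, batch_size=3, shuffle=True)
    for _ in range(generator.step_per_epoch):
        batch_data, batch_label = generator.next_batch()
        print('epoch', generator.epoch, 'step', generator.step,
              'data', batch_data.shape, 'label', batch_label.shape)


if __name__ == '__main__':
    _demo_batch_generator()
    _demo_sequence_generator()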