diff --git a/utils/batch_generator.py b/utils/batch_generator.py
index 008e5c6..9e8d419 100644
--- a/utils/batch_generator.py
+++ b/utils/batch_generator.py
@@ -104,16 +104,6 @@ class SequenceGenerator:
         self.batch_size = batch_size
         self.shuffle = shuffle
 
-        self.index_list = []
-        for sequence_index in range(len(data)):
-            start_indices = np.expand_dims(
-                np.arange(len(data[sequence_index]) - sequence_size, dtype=np.uint8),
-                axis=-1)
-            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
-            self.index_list.append(start_indices)
-        self.index_list = np.concatenate(self.index_list, axis=0)
-        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
-
         if not preload:
             self.data_processor = data_processor
             self.label_processor = label_processor
@@ -122,25 +112,52 @@ class SequenceGenerator:
         else:
             self.data_processor = None
             self.label_processor = None
+            save_path = save
+            if save is not None:
+                if '.' not in os.path.basename(save_path):
+                    save_path += '.hdf5'
+                if not os.path.exists(os.path.dirname(save_path)):
+                    os.makedirs(os.path.dirname(save_path))
 
-            if save and os.path.exists(save + '_data.npy'):
-                self.data = np.load(save + '_data.npy', allow_pickle=True)
-                self.label = np.load(save + '_label.npy', allow_pickle=True)
+            if save and os.path.exists(save_path):
+                with h5py.File(save_path, 'r') as h5_file:
+                    data_len = np.asarray(h5_file['data_len'])
+                    self.data = []
+                    self.label = []
+                    for sequence_index in range(data_len):
+                        self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
+                        self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
+                    self.data = np.asarray(self.data)
+                    self.label = np.asarray(self.label)
             else:
                 if data_processor:
                     self.data = np.asarray(
-                        [np.asarray([data_processor(entry) for entry in serie]) for serie in data])
+                        [np.asarray([data_processor(entry) for entry in serie]) for serie in data],
+                        dtype=np.object if len(data) > 1 else None)
                 else:
                     self.data = data
                 if label_processor:
                     self.label = np.asarray(
-                        [np.asarray([label_processor(entry) for entry in serie]) for serie in label])
+                        [np.asarray([label_processor(entry) for entry in serie]) for serie in label],
+                        dtype=np.object if len(label) > 1 else None)
                 else:
                     self.label = label
+                if save:
+                    with h5py.File(save_path, 'w') as h5_file:
+                        h5_file.create_dataset(f'data_len', data=len(self.data))
+                        for sequence_index in range(len(self.data)):
+                            h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
+                            h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
 
-            if save:
-                np.save(save + '_data.npy', self.data, allow_pickle=True)
-                np.save(save + '_label.npy', self.label, allow_pickle=True)
+        self.index_list = []
+        for sequence_index in range(len(self.data)):
+            start_indices = np.expand_dims(
+                np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32),
+                axis=-1)
+            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
+            self.index_list.append(start_indices)
+        self.index_list = np.concatenate(self.index_list, axis=0)
+        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
 
         self.epoch = 0
         self.global_step = -1
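
For context, here is a minimal standalone sketch of the per-sequence HDF5 caching pattern the patch switches to. It assumes only h5py and numpy; the file name 'cache.hdf5' and the toy arrays are hypothetical stand-ins, not part of the actual change.

# Minimal sketch of the per-sequence HDF5 caching pattern used in the patch.
import h5py
import numpy as np

sequences = [np.random.rand(5, 3), np.random.rand(8, 3)]  # ragged lengths

# Write: one dataset per sequence plus a sequence count, because ragged
# sequences cannot be packed into a single rectangular HDF5 dataset.
with h5py.File('cache.hdf5', 'w') as h5_file:
    h5_file.create_dataset('data_len', data=len(sequences))
    for i, seq in enumerate(sequences):
        h5_file.create_dataset(f'data_{i}', data=seq)

# Read back in the same order, mirroring the load branch of the patch.
with h5py.File('cache.hdf5', 'r') as h5_file:
    count = int(np.asarray(h5_file['data_len']))
    restored = [np.asarray(h5_file[f'data_{i}']) for i in range(count)]

assert all(np.array_equal(a, b) for a, b in zip(sequences, restored))

Two smaller points visible in the diff itself: the sliding-window index table is now built from self.data after the load/processing step, with np.uint32 start offsets instead of np.uint8 (which would overflow for sequences longer than 255 entries); and np.object is deprecated in recent NumPy releases, so plain object is the forward-compatible spelling.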