import math
|
|
import os
|
|
from typing import Optional, Tuple
|
|
|
|
import h5py
|
|
import numpy as np
|
|
|
|
|
|
class BatchGenerator:
    """Serves (data, label) pairs in mini-batches, reshuffling each epoch.

    Parameters
    ----------
    data, label : indexable collections of equal length.
    batch_size : samples per batch (the last batch of an epoch may be smaller).
    data_processor, label_processor : optional callables applied per entry.
        When ``preload`` is False they run lazily at batch time; when True
        they run once up front and the processed arrays are kept in memory.
    shuffle : reshuffle sample order at construction and at each epoch wrap.
    preload : process the whole dataset eagerly (optionally cached via ``save``).
    save : HDF5 cache path used with ``preload`` — read if it already exists,
        written otherwise. A ``.hdf5`` suffix is appended when the basename
        has no extension.
    left_right_flip : with probability 0.5 per batch, mirror both data and
        label along axis 2 (assumes image-like batches, e.g. NHWC — TODO
        confirm with callers).
    """

    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
                 shuffle=True, preload=False, save=None, left_right_flip=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip

        if not preload:
            # Lazy mode: keep raw entries, run the processors per batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Eager mode: everything is processed now; next_batch() slices directly.
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                # Guard against '' (bare filename): os.makedirs('') raises.
                # exist_ok also removes the exists()/makedirs() race.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)

            if save and os.path.exists(save_path):
                # Cache hit: reload the previously processed arrays.
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)

        self.step_per_epoch = math.ceil(len(self.data) / batch_size)

        self.epoch = 0
        self.global_step = -1
        # Pre-incremented by next_batch(), so the first batch is step 0.
        self.step = -1

        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None

        self.index_list = np.arange(len(self.data))
        if shuffle:
            np.random.shuffle(self.index_list)

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the next (batch_data, batch_label) pair, wrapping per epoch."""
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1

        self.global_step += 1
        # Sample indices making up this batch (hoisted: used by all branches below).
        batch_indices = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]

        # Loading data
        if self.data_processor is not None:
            self.batch_data = np.asarray(
                [self.data_processor(self.data[entry]) for entry in batch_indices])
        else:
            self.batch_data = self.data[batch_indices]
        # Loading label
        if self.label_processor is not None:
            self.batch_label = np.asarray(
                [self.label_processor(self.label[entry]) for entry in batch_indices])
        else:
            self.batch_label = self.label[batch_indices]

        if self.left_right_flip and np.random.uniform() > 0.5:
            # Mirror the whole batch along axis 2 (width for NHWC images).
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label
|
|
|
|
|
|
class SequenceGenerator:
|
|
|
|
def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
|
|
preload=False, shuffle=True, save=None):
|
|
self.sequence_size = sequence_size
|
|
self.batch_size = batch_size
|
|
self.shuffle = shuffle
|
|
|
|
if not preload:
|
|
self.data_processor = data_processor
|
|
self.label_processor = label_processor
|
|
self.data = data
|
|
self.label = label
|
|
else:
|
|
self.data_processor = None
|
|
self.label_processor = None
|
|
save_path = save
|
|
if save is not None:
|
|
if '.' not in os.path.basename(save_path):
|
|
save_path += '.hdf5'
|
|
if not os.path.exists(os.path.dirname(save_path)):
|
|
os.makedirs(os.path.dirname(save_path))
|
|
|
|
if save and os.path.exists(save_path):
|
|
with h5py.File(save_path, 'r') as h5_file:
|
|
data_len = np.asarray(h5_file['data_len'])
|
|
self.data = []
|
|
self.label = []
|
|
for sequence_index in range(data_len):
|
|
self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
|
|
self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
|
|
self.data = np.asarray(self.data)
|
|
self.label = np.asarray(self.label)
|
|
else:
|
|
if data_processor:
|
|
self.data = np.asarray(
|
|
[np.asarray([data_processor(entry) for entry in serie]) for serie in data],
|
|
dtype=np.object if len(data) > 1 else None)
|
|
else:
|
|
self.data = data
|
|
if label_processor:
|
|
self.label = np.asarray(
|
|
[np.asarray([label_processor(entry) for entry in serie]) for serie in label],
|
|
dtype=np.object if len(label) > 1 else None)
|
|
else:
|
|
self.label = label
|
|
if save:
|
|
with h5py.File(save_path, 'w') as h5_file:
|
|
h5_file.create_dataset(f'data_len', data=len(self.data))
|
|
for sequence_index in range(len(self.data)):
|
|
h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
|
|
h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
|
|
|
|
self.index_list = []
|
|
for sequence_index in range(len(self.data)):
|
|
start_indices = np.expand_dims(
|
|
np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32),
|
|
axis=-1)
|
|
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
|
|
self.index_list.append(start_indices)
|
|
self.index_list = np.concatenate(self.index_list, axis=0)
|
|
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
|
|
|
|
self.epoch = 0
|
|
self.global_step = -1
|
|
self.step = -1
|
|
|
|
self.batch_data: Optional[np.ndarray] = None
|
|
self.batch_label: Optional[np.ndarray] = None
|
|
|
|
if shuffle:
|
|
np.random.shuffle(self.index_list)
|
|
|
|
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
|
|
if self.step >= self.step_per_epoch - 1: # step start at 0
|
|
self.step = 0
|
|
self.epoch += 1
|
|
if self.shuffle:
|
|
np.random.shuffle(self.index_list)
|
|
else:
|
|
self.step += 1
|
|
|
|
self.global_step += 1
|
|
# Loading data
|
|
if self.data_processor is not None:
|
|
self.batch_data = []
|
|
for sequence_index, start_index in self.index_list[
|
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
|
self.batch_data.append(
|
|
[self.data_processor(input_data)
|
|
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
|
self.batch_data = np.asarray(self.batch_data)
|
|
else:
|
|
self.batch_data = []
|
|
for sequence_index, start_index in self.index_list[
|
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
|
self.batch_data.append(
|
|
self.data[sequence_index][start_index: start_index + self.sequence_size])
|
|
self.batch_data = np.asarray(self.batch_data)
|
|
# Loading label
|
|
if self.label_processor is not None:
|
|
self.batch_label = []
|
|
for sequence_index, start_index in self.index_list[
|
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
|
self.batch_label.append(
|
|
[self.label_processor(input_data)
|
|
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
|
self.batch_label = np.asarray(self.batch_label)
|
|
else:
|
|
self.batch_label = []
|
|
for sequence_index, start_index in self.index_list[
|
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
|
self.batch_label.append(
|
|
self.label[sequence_index][start_index: start_index + self.sequence_size])
|
|
self.batch_label = np.asarray(self.batch_label)
|
|
|
|
return self.batch_data, self.batch_label
|