Add index_list parameter to SequenceGenerator

This commit is contained in:
Corentin 2021-03-22 20:14:27 +09:00
commit d0fd6b2642

View file

@ -16,7 +16,7 @@ class SequenceGenerator(BatchGenerator):
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
sequence_stride=1, flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size
self.sequence_size = sequence_size
@ -70,19 +70,26 @@ class SequenceGenerator(BatchGenerator):
h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
self.index_list = []
for sequence_index in range(len(self.data)):
if sequence_stride > 1:
start_indices = np.expand_dims(
np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32),
axis=-1)
else:
start_indices = np.expand_dims(
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
axis=-1)
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
self.index_list.append(start_indices)
self.index_list = np.concatenate(self.index_list, axis=0)
if index_list is not None:
self.index_list = index_list
else:
self.index_list = []
for sequence_index in range(len(self.data)):
if sequence_stride > 1:
start_indices = np.expand_dims(
np.arange(0,
len(self.data[sequence_index]) - sequence_size + 1,
sequence_stride,
dtype=np.uint32),
axis=-1)
else:
start_indices = np.expand_dims(
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
axis=-1)
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
self.index_list.append(start_indices)
self.index_list = np.concatenate(self.index_list, axis=0)
if shuffle or initial_shuffle:
np.random.shuffle(self.index_list)
self.step_per_epoch = len(self.index_list) // self.batch_size