From d0fd6b2642ddf2e6010c60ff8cead8495e70ffc3 Mon Sep 17 00:00:00 2001 From: Corentin Date: Mon, 22 Mar 2021 20:14:27 +0900 Subject: [PATCH] Index list parameter for SequenceBatchGenerator --- utils/sequence_batch_generator.py | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index 041bd60..7affb81 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -16,7 +16,7 @@ class SequenceGenerator(BatchGenerator): def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, - prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, + index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, sequence_stride=1, flip_data=False, save: Optional[str] = None): self.batch_size = batch_size self.sequence_size = sequence_size @@ -70,19 +70,26 @@ class SequenceGenerator(BatchGenerator): h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index]) h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index]) - self.index_list = [] - for sequence_index in range(len(self.data)): - if sequence_stride > 1: - start_indices = np.expand_dims( - np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32), - axis=-1) - else: - start_indices = np.expand_dims( - np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32), - axis=-1) - start_indices = np.insert(start_indices, 0, sequence_index, axis=1) - self.index_list.append(start_indices) - self.index_list = np.concatenate(self.index_list, axis=0) + if index_list is not None: + self.index_list = index_list + else: + self.index_list = [] + for sequence_index in range(len(self.data)): + if sequence_stride > 1: + start_indices = np.expand_dims( + np.arange(0, + len(self.data[sequence_index]) - sequence_size + 1, + sequence_stride, + dtype=np.uint32), + axis=-1) + else: + start_indices = np.expand_dims( + np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32), + axis=-1) + start_indices = np.insert(start_indices, 0, sequence_index, axis=1) + self.index_list.append(start_indices) + self.index_list = np.concatenate(self.index_list, axis=0) + if shuffle or initial_shuffle: np.random.shuffle(self.index_list) self.step_per_epoch = len(self.index_list) // self.batch_size