Sequence stride option for SequenceBatchGenerator

This commit is contained in:
Corentin 2021-03-18 10:36:17 +09:00
commit 51776d6999

View file

@ -17,7 +17,7 @@ class SequenceGenerator(BatchGenerator):
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
flip_data=False, save: Optional[str] = None): sequence_stride=1, flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
self.sequence_size = sequence_size self.sequence_size = sequence_size
self.shuffle = shuffle self.shuffle = shuffle
@ -72,6 +72,11 @@ class SequenceGenerator(BatchGenerator):
self.index_list = [] self.index_list = []
for sequence_index in range(len(self.data)): for sequence_index in range(len(self.data)):
if sequence_stride > 1:
start_indices = np.expand_dims(
np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32),
axis=-1)
else:
start_indices = np.expand_dims( start_indices = np.expand_dims(
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32), np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
axis=-1) axis=-1)