Sequence stride option for SequenceBatchGenerator
This commit is contained in:
parent
92971be5f0
commit
51776d6999
1 changed files with 9 additions and 4 deletions
|
|
@ -17,7 +17,7 @@ class SequenceGenerator(BatchGenerator):
|
|||
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||
flip_data=False, save: Optional[str] = None):
|
||||
sequence_stride=1, flip_data=False, save: Optional[str] = None):
|
||||
self.batch_size = batch_size
|
||||
self.sequence_size = sequence_size
|
||||
self.shuffle = shuffle
|
||||
|
|
@ -72,6 +72,11 @@ class SequenceGenerator(BatchGenerator):
|
|||
|
||||
self.index_list = []
|
||||
for sequence_index in range(len(self.data)):
|
||||
if sequence_stride > 1:
|
||||
start_indices = np.expand_dims(
|
||||
np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32),
|
||||
axis=-1)
|
||||
else:
|
||||
start_indices = np.expand_dims(
|
||||
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
|
||||
axis=-1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue