Sequence stride option for SequenceBatchGenerator
This commit is contained in:
parent
92971be5f0
commit
51776d6999
1 changed files with 9 additions and 4 deletions
|
|
@ -17,7 +17,7 @@ class SequenceGenerator(BatchGenerator):
|
||||||
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
||||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||||
flip_data=False, save: Optional[str] = None):
|
sequence_stride=1, flip_data=False, save: Optional[str] = None):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.sequence_size = sequence_size
|
self.sequence_size = sequence_size
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
@ -72,6 +72,11 @@ class SequenceGenerator(BatchGenerator):
|
||||||
|
|
||||||
self.index_list = []
|
self.index_list = []
|
||||||
for sequence_index in range(len(self.data)):
|
for sequence_index in range(len(self.data)):
|
||||||
|
if sequence_stride > 1:
|
||||||
|
start_indices = np.expand_dims(
|
||||||
|
np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32),
|
||||||
|
axis=-1)
|
||||||
|
else:
|
||||||
start_indices = np.expand_dims(
|
start_indices = np.expand_dims(
|
||||||
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
|
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue