Index list parameter for SequenceBatchGenerator
This commit is contained in:
parent
86787f6517
commit
d0fd6b2642
1 changed files with 21 additions and 14 deletions
|
|
@ -16,7 +16,7 @@ class SequenceGenerator(BatchGenerator):
|
|||
|
||||
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||
index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||
sequence_stride=1, flip_data=False, save: Optional[str] = None):
|
||||
self.batch_size = batch_size
|
||||
self.sequence_size = sequence_size
|
||||
|
|
@ -70,11 +70,17 @@ class SequenceGenerator(BatchGenerator):
|
|||
h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
|
||||
h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
|
||||
|
||||
if index_list is not None:
|
||||
self.index_list = index_list
|
||||
else:
|
||||
self.index_list = []
|
||||
for sequence_index in range(len(self.data)):
|
||||
if sequence_stride > 1:
|
||||
start_indices = np.expand_dims(
|
||||
np.arange(0, len(self.data[sequence_index]) - sequence_size + 1, sequence_stride, dtype=np.uint32),
|
||||
np.arange(0,
|
||||
len(self.data[sequence_index]) - sequence_size + 1,
|
||||
sequence_stride,
|
||||
dtype=np.uint32),
|
||||
axis=-1)
|
||||
else:
|
||||
start_indices = np.expand_dims(
|
||||
|
|
@ -83,6 +89,7 @@ class SequenceGenerator(BatchGenerator):
|
|||
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
|
||||
self.index_list.append(start_indices)
|
||||
self.index_list = np.concatenate(self.index_list, axis=0)
|
||||
|
||||
if shuffle or initial_shuffle:
|
||||
np.random.shuffle(self.index_list)
|
||||
self.step_per_epoch = len(self.index_list) // self.batch_size
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue