Fix SequenceGenerator's pipeline

This commit is contained in:
Corentin 2021-03-23 17:36:39 +09:00
commit 89240e417f

View file

@ -125,10 +125,9 @@ class SequenceGenerator(BatchGenerator):
first_label.append(
self.label[sequence_index][start_index: start_index + self.sequence_size])
if self.pipeline is not None:
for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)):
first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline(
data_sample, label_sample)
for batch_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
first_data[batch_index], first_label[batch_index] = self.pipeline(
np.asarray(data_sequence), np.asarray(label_sequence))
first_data = np.asarray(first_data)
first_label = np.asarray(first_label)
self.batch_data = first_data
@ -181,8 +180,9 @@ class SequenceGenerator(BatchGenerator):
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append(
[self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
np.asarray(
[self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]))
else:
data = []
for sequence_index, start_index in self.index_list[
@ -195,8 +195,9 @@ class SequenceGenerator(BatchGenerator):
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
label.append(
[self.label_processor(input_data)
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
np.asarray(
[self.label_processor(input_data)
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]))
else:
label = []
for sequence_index, start_index in self.index_list[
@ -205,12 +206,10 @@ class SequenceGenerator(BatchGenerator):
# Process through pipeline
if self.pipeline is not None:
for sequence_index in range(len(data)):
for data_index in range(len(data[sequence_index])):
piped_data, piped_label = self.pipeline(
data[sequence_index][data_index], label[sequence_index][data_index])
self.cache_data[self.current_cache][data_index] = piped_data
self.cache_label[self.current_cache][data_index] = piped_label
for batch_index in range(len(data)):
piped_data, piped_label = self.pipeline(data[batch_index], label[batch_index])
self.cache_data[self.current_cache][batch_index] = piped_data
self.cache_label[self.current_cache][batch_index] = piped_label
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
@ -221,17 +220,22 @@ if __name__ == '__main__':
data = np.array(
[[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
label = np.array(
[[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8)
[[10, 20, 30, 40, 50, 60, 70, 80, 90], [110, 120, 130, 140, 150, 160, 170, 180, 190]], dtype=np.uint8)
for data_processor in [None, lambda x:x]:
for prefetch in [False, True]:
for num_workers in [1, 2]:
print(f'{data_processor=} {prefetch=} {num_workers=}')
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor,
prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
batch_generator.next_batch()
print()
# Identity pipeline: returns the given data and label unchanged.
# It is the non-None entry iterated by the `for pipeline in [None, pipeline]` loop below.
def pipeline(data, label):
return data, label
for pipeline in [None, pipeline]:
for data_processor in [None, lambda x:x]:
for prefetch in [False, True]:
for num_workers in [1, 2]:
print(f'{pipeline=} {data_processor=} {prefetch=} {num_workers=}')
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, pipeline=pipeline,
prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9):
print(batch_generator.batch_data.tolist(), batch_generator.batch_label.tolist(),
batch_generator.epoch, batch_generator.step)
batch_generator.next_batch()
print()
test()