Fix SequenceBatchGenerator's pipeline
This commit is contained in:
parent
73457750a8
commit
89240e417f
1 changed file with 29 additions and 25 deletions
|
|
@@ -125,10 +125,9 @@ class SequenceGenerator(BatchGenerator):
|
|||
first_label.append(
|
||||
self.label[sequence_index][start_index: start_index + self.sequence_size])
|
||||
if self.pipeline is not None:
|
||||
for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
|
||||
for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)):
|
||||
first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline(
|
||||
data_sample, label_sample)
|
||||
for batch_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
|
||||
first_data[batch_index], first_label[batch_index] = self.pipeline(
|
||||
np.asarray(data_sequence), np.asarray(label_sequence))
|
||||
first_data = np.asarray(first_data)
|
||||
first_label = np.asarray(first_label)
|
||||
self.batch_data = first_data
|
||||
|
|
@@ -181,8 +180,9 @@ class SequenceGenerator(BatchGenerator):
|
|||
for sequence_index, start_index in self.index_list[
|
||||
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||
data.append(
|
||||
[self.data_processor(input_data)
|
||||
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
||||
np.asarray(
|
||||
[self.data_processor(input_data)
|
||||
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]))
|
||||
else:
|
||||
data = []
|
||||
for sequence_index, start_index in self.index_list[
|
||||
|
|
@@ -195,8 +195,9 @@ class SequenceGenerator(BatchGenerator):
|
|||
for sequence_index, start_index in self.index_list[
|
||||
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||
label.append(
|
||||
[self.label_processor(input_data)
|
||||
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
||||
np.asarray(
|
||||
[self.label_processor(input_data)
|
||||
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]))
|
||||
else:
|
||||
label = []
|
||||
for sequence_index, start_index in self.index_list[
|
||||
|
|
@@ -205,12 +206,10 @@ class SequenceGenerator(BatchGenerator):
|
|||
|
||||
# Process through pipeline
|
||||
if self.pipeline is not None:
|
||||
for sequence_index in range(len(data)):
|
||||
for data_index in range(len(data[sequence_index])):
|
||||
piped_data, piped_label = self.pipeline(
|
||||
data[sequence_index][data_index], label[sequence_index][data_index])
|
||||
self.cache_data[self.current_cache][data_index] = piped_data
|
||||
self.cache_label[self.current_cache][data_index] = piped_label
|
||||
for batch_index in range(len(data)):
|
||||
piped_data, piped_label = self.pipeline(data[batch_index], label[batch_index])
|
||||
self.cache_data[self.current_cache][batch_index] = piped_data
|
||||
self.cache_label[self.current_cache][batch_index] = piped_label
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
|
||||
|
|
@@ -221,17 +220,22 @@ if __name__ == '__main__':
|
|||
data = np.array(
|
||||
[[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
|
||||
label = np.array(
|
||||
[[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8)
|
||||
[[10, 20, 30, 40, 50, 60, 70, 80, 90], [110, 120, 130, 140, 150, 160, 170, 180, 190]], dtype=np.uint8)
|
||||
|
||||
for data_processor in [None, lambda x:x]:
|
||||
for prefetch in [False, True]:
|
||||
for num_workers in [1, 2]:
|
||||
print(f'{data_processor=} {prefetch=} {num_workers=}')
|
||||
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor,
|
||||
prefetch=prefetch, num_workers=num_workers) as batch_generator:
|
||||
for _ in range(9):
|
||||
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
||||
batch_generator.next_batch()
|
||||
print()
|
||||
def pipeline(data, label):
|
||||
return data, label
|
||||
|
||||
for pipeline in [None, pipeline]:
|
||||
for data_processor in [None, lambda x:x]:
|
||||
for prefetch in [False, True]:
|
||||
for num_workers in [1, 2]:
|
||||
print(f'{pipeline=} {data_processor=} {prefetch=} {num_workers=}')
|
||||
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, pipeline=pipeline,
|
||||
prefetch=prefetch, num_workers=num_workers) as batch_generator:
|
||||
for _ in range(9):
|
||||
print(batch_generator.batch_data.tolist(), batch_generator.batch_label.tolist(),
|
||||
batch_generator.epoch, batch_generator.step)
|
||||
batch_generator.next_batch()
|
||||
print()
|
||||
|
||||
test()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue