diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index 524e58b..005caf3 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -125,10 +125,9 @@ class SequenceGenerator(BatchGenerator): first_label.append( self.label[sequence_index][start_index: start_index + self.sequence_size]) if self.pipeline is not None: - for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)): - for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)): - first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline( - data_sample, label_sample) + for batch_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)): + first_data[batch_index], first_label[batch_index] = self.pipeline( + np.asarray(data_sequence), np.asarray(label_sequence)) first_data = np.asarray(first_data) first_label = np.asarray(first_label) self.batch_data = first_data @@ -181,8 +180,9 @@ class SequenceGenerator(BatchGenerator): for sequence_index, start_index in self.index_list[ self.step * self.batch_size:(self.step + 1) * self.batch_size]: data.append( - [self.data_processor(input_data) - for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) + np.asarray( + [self.data_processor(input_data) + for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])) else: data = [] for sequence_index, start_index in self.index_list[ @@ -195,8 +195,9 @@ class SequenceGenerator(BatchGenerator): for sequence_index, start_index in self.index_list[ self.step * self.batch_size:(self.step + 1) * self.batch_size]: label.append( - [self.label_processor(input_data) - for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]) + np.asarray( + [self.label_processor(input_data) + for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])) else: label = [] for sequence_index, start_index in self.index_list[ @@ -205,12 +206,10 @@ class SequenceGenerator(BatchGenerator): # Process through pipeline if self.pipeline is not None: - for sequence_index in range(len(data)): - for data_index in range(len(data[sequence_index])): - piped_data, piped_label = self.pipeline( - data[sequence_index][data_index], label[sequence_index][data_index]) - self.cache_data[self.current_cache][data_index] = piped_data - self.cache_label[self.current_cache][data_index] = piped_label + for batch_index in range(len(data)): + piped_data, piped_label = self.pipeline(data[batch_index], label[batch_index]) + self.cache_data[self.current_cache][batch_index] = piped_data + self.cache_label[self.current_cache][batch_index] = piped_label else: self.cache_data[self.current_cache][:len(data)] = np.asarray(data) self.cache_label[self.current_cache][:len(label)] = np.asarray(label) @@ -221,17 +220,22 @@ if __name__ == '__main__': data = np.array( [[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8) label = np.array( - [[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8) + [[10, 20, 30, 40, 50, 60, 70, 80, 90], [110, 120, 130, 140, 150, 160, 170, 180, 190]], dtype=np.uint8) - for data_processor in [None, lambda x:x]: - for prefetch in [False, True]: - for num_workers in [1, 2]: - print(f'{data_processor=} {prefetch=} {num_workers=}') - with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, - prefetch=prefetch, num_workers=num_workers) as batch_generator: - for _ in range(9): - print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) - batch_generator.next_batch() - print() + def pipeline(data, label): + return data, label + + for pipeline in [None, pipeline]: + for data_processor in [None, lambda x:x]: + for prefetch in [False, True]: + for num_workers in [1, 2]: + print(f'{pipeline=} {data_processor=} {prefetch=} {num_workers=}') + with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, pipeline=pipeline, + prefetch=prefetch, num_workers=num_workers) as batch_generator: + for _ in range(9): + print(batch_generator.batch_data.tolist(), batch_generator.batch_label.tolist(), + batch_generator.epoch, batch_generator.step) + batch_generator.next_batch() + print() test()