From 73457750a8ca0be957fb0ffbca4a15c5a71d3b2b Mon Sep 17 00:00:00 2001 From: Corentin Date: Tue, 23 Mar 2021 02:13:40 +0900 Subject: [PATCH] Add pipeline to SequenceBatchGenerator * Remove flip_data parameter (can be done in the pipeline) --- utils/batch_generator.py | 13 +++++--- utils/sequence_batch_generator.py | 54 +++++++++++++------------------ 2 files changed, 31 insertions(+), 36 deletions(-) diff --git a/utils/batch_generator.py b/utils/batch_generator.py index cfedc4c..b8257ef 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -68,10 +68,15 @@ class BatchGenerator: self.global_step = 0 self.step = 0 - first_data = np.array([data_processor(entry) if data_processor else entry - for entry in self.data[self.index_list[:batch_size]]]) - first_label = np.array([label_processor(entry) if label_processor else entry - for entry in self.label[self.index_list[:batch_size]]]) + first_data = [data_processor(entry) if data_processor else entry + for entry in self.data[self.index_list[:batch_size]]] + first_label = [label_processor(entry) if label_processor else entry + for entry in self.label[self.index_list[:batch_size]]] + if self.pipeline is not None: + for data_index, sample_data in enumerate(first_data): + first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index]) + first_data = np.asarray(first_data) + first_label = np.asarray(first_label) self.batch_data = first_data self.batch_label = first_label diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index 7affb81..524e58b 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -16,14 +16,15 @@ class SequenceGenerator(BatchGenerator): def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, - index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, - sequence_stride=1, flip_data=False, save: Optional[str] = None): + pipeline: Optional[Callable] = None, index_list=None, + prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, + sequence_stride=1, save: Optional[str] = None): self.batch_size = batch_size self.sequence_size = sequence_size self.shuffle = shuffle self.prefetch = prefetch and not preload self.num_workers = num_workers - self.flip_data = flip_data + self.pipeline = pipeline if not preload: self.data_processor = data_processor @@ -107,26 +108,29 @@ class SequenceGenerator(BatchGenerator): first_data.append( [data_processor(input_data) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) - first_data = np.asarray(first_data) else: first_data = [] for sequence_index, start_index in self.index_list[:batch_size]: first_data.append( self.data[sequence_index][start_index: start_index + self.sequence_size]) - first_data = np.asarray(first_data) if label_processor: first_label = [] for sequence_index, start_index in self.index_list[:batch_size]: first_label.append( [label_processor(input_label) for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]]) - first_label = np.asarray(first_label) else: first_label = [] for sequence_index, start_index in self.index_list[:batch_size]: first_label.append( self.label[sequence_index][start_index: start_index + self.sequence_size]) - first_label = np.asarray(first_label) + if self.pipeline is not None: + for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)): + for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)): + first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline( + data_sample, label_sample) + first_data = np.asarray(first_data) + first_label = np.asarray(first_label) self.batch_data = first_data self.batch_label = first_label @@ -179,35 +183,11 @@ class SequenceGenerator(BatchGenerator): data.append( [self.data_processor(input_data) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) - if self.flip_data: - flip = np.random.uniform() - if flip < 0.25: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1] - elif flip < 0.5: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1] - elif flip < 0.75: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1] - else: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data) - else: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data) else: data = [] for sequence_index, start_index in self.index_list[ self.step * self.batch_size:(self.step + 1) * self.batch_size]: data.append(self.data[sequence_index][start_index: start_index + self.sequence_size]) - if self.flip_data: - flip = np.random.uniform() - if flip < 0.25: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1] - elif flip < 0.5: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1] - elif flip < 0.75: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1] - else: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data) - else: - self.cache_data[self.current_cache][:len(data)] = np.asarray(data) # Loading label if self.label_processor is not None: @@ -217,12 +197,22 @@ class SequenceGenerator(BatchGenerator): label.append( [self.label_processor(input_data) for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]) - self.cache_label[self.current_cache][:len(label)] = np.asarray(label) else: label = [] for sequence_index, start_index in self.index_list[ self.step * self.batch_size:(self.step + 1) * self.batch_size]: label.append(self.label[sequence_index][start_index: start_index + self.sequence_size]) + + # Process through pipeline + if self.pipeline is not None: + for sequence_index in range(len(data)): + for data_index in range(len(data[sequence_index])): + piped_data, piped_label = self.pipeline( + data[sequence_index][data_index], label[sequence_index][data_index]) + self.cache_data[self.current_cache][data_index] = piped_data + self.cache_label[self.current_cache][data_index] = piped_label + else: + self.cache_data[self.current_cache][:len(data)] = np.asarray(data) self.cache_label[self.current_cache][:len(label)] = np.asarray(label)