From 144ff4a00423051e0d05fdbece329010a1105046 Mon Sep 17 00:00:00 2001
From: corentin
Date: Fri, 18 Dec 2020 20:30:36 +0900
Subject: [PATCH] Data flip implementation

---
 utils/batch_generator.py          | 27 ++++++++++++++-------------
 utils/sequence_batch_generator.py | 30 ++++++++++++++++++++++++++----
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/utils/batch_generator.py b/utils/batch_generator.py
index aa278e3..4049045 100644
--- a/utils/batch_generator.py
+++ b/utils/batch_generator.py
@@ -12,12 +12,12 @@ class BatchGenerator:
     def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                  data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                  prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
-                 left_right_flip=False, save: Optional[str] = None):
+                 flip_data=False, save: Optional[str] = None):
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.prefetch = prefetch and not preload
         self.num_workers = num_workers
-        self.left_right_flip = left_right_flip
+        self.flip_data = flip_data
 
         if not preload:
             self.data_processor = data_processor
@@ -129,7 +129,6 @@ class BatchGenerator:
             self.release()
 
     def release(self):
-        print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
         if self.num_workers > 1:
             self.worker_stop.buf[0] = 1
             for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
@@ -141,7 +140,6 @@ class BatchGenerator:
             self.num_workers = 0  # Avoids double release
 
         if self.prefetch:
-            print(f'{self.process_id} cleaning prefetch')
             self.prefetch_stop.buf[0] = 1
             self.prefetch_pipe_parent.send(True)
             self.prefetch_process.join()
@@ -153,14 +151,12 @@ class BatchGenerator:
             self.prefetch = False  # Avoids double release
 
         if self.process_id == 'main' and self.cache_memory_indices is not None:
-            print(f'{self.process_id} cleaning memory')
            self.cache_memory_indices.close()
             self.cache_memory_indices.unlink()
             for shared_mem in self.cache_memory_data + self.cache_memory_label:
                 shared_mem.close()
                 shared_mem.unlink()
             self.cache_memory_indices = None  # Avoids double release
-        print(f'{self.process_id} released')
 
     def _prefetch_worker(self):
         self.prefetch = False
@@ -191,7 +187,6 @@ class BatchGenerator:
 
         if self.num_workers > 1:
             self.release()
-        print(f'{self.process_id} exiting')
 
     def _worker(self, worker_index: int):
         self.process_id = f'worker_{worker_index}'
@@ -225,7 +220,6 @@ class BatchGenerator:
                     break
             except ValueError:
                 break
-        print(f'{self.process_id} exiting')
 
     def skip_epoch(self):
         if self.prefetch:
@@ -273,7 +267,18 @@ class BatchGenerator:
             else:
                 data = self.data[
                     self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
-            self.cache_data[self.current_cache][:len(data)] = data
+            if self.flip_data:
+                flip = np.random.uniform()
+                if flip < 0.25:
+                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
+                elif flip < 0.5:
+                    self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
+                elif flip < 0.75:
+                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
+                else:
+                    self.cache_data[self.current_cache][:len(data)] = data
+            else:
+                self.cache_data[self.current_cache][:len(data)] = data
             # Loading label
             if self.label_processor is not None:
                 label = []
@@ -314,10 +319,6 @@ class BatchGenerator:
             self.batch_data = self.cache_data[self.current_cache]
             self.batch_label = self.cache_label[self.current_cache]
 
-        if self.left_right_flip and np.random.uniform() > 0.5:
-            self.batch_data = self.batch_data[:, :, ::-1]
-            self.batch_label = self.batch_label[:, :, ::-1]
-
         return self.batch_data, self.batch_label
 
 
diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py
index 3b0cbd2..04cd55e 100644
--- a/utils/sequence_batch_generator.py
+++ b/utils/sequence_batch_generator.py
@@ -17,13 +17,13 @@ class SequenceGenerator(BatchGenerator):
     def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
                  data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                  prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
-                 save: Optional[str] = None):
+                 flip_data=False, save: Optional[str] = None):
         self.batch_size = batch_size
         self.sequence_size = sequence_size
         self.shuffle = shuffle
         self.prefetch = prefetch and not preload
         self.num_workers = num_workers
-        self.left_right_flip = False
+        self.flip_data = flip_data
 
         if not preload:
             self.data_processor = data_processor
@@ -167,13 +167,35 @@ class SequenceGenerator(BatchGenerator):
                 data.append(
                     [self.data_processor(input_data) for input_data in
                      self.data[sequence_index][start_index: start_index + self.sequence_size]])
-            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            if self.flip_data:
+                flip = np.random.uniform()
+                if flip < 0.25:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
+                elif flip < 0.5:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
+                elif flip < 0.75:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
+                else:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            else:
+                self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
         else:
             data = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
-            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            if self.flip_data:
+                flip = np.random.uniform()
+                if flip < 0.25:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
+                elif flip < 0.5:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
+                elif flip < 0.75:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
+                else:
+                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            else:
+                self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
 
         # Loading label
         if self.label_processor is not None:
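
Note on the augmentation (a sketch for review, not part of the patch): both
generators draw a single uniform sample per batch and apply one of four flips
with equal probability, so every sample in a batch receives the same flip and
the shared-memory cache write stays a single slice assignment. A minimal
standalone sketch of that logic, assuming channels-first data so that the
trailing two axes are height and width (the name `random_flip` is
illustrative, not from the codebase):

    import numpy as np

    def random_flip(batch: np.ndarray) -> np.ndarray:
        """Apply one of four flips to a whole batch with equal probability.

        Works for (N, C, H, W) batches and (N, T, C, H, W) sequences alike,
        because only the last two axes (height, width) are reversed.
        """
        flip = np.random.uniform()
        if flip < 0.25:
            return batch[..., ::-1, :]     # flip height: vertical mirror
        elif flip < 0.5:
            return batch[..., :, ::-1]     # flip width: horizontal mirror
        elif flip < 0.75:
            return batch[..., ::-1, ::-1]  # flip both: 180-degree rotation
        return batch                       # leave unchanged

Unlike the removed left_right_flip path, labels are not flipped here; this
assumes the labels carry no spatial layout that would need to mirror with
the data.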