From 86787f6517635acd49318363e24a4caa59771a2f Mon Sep 17 00:00:00 2001
From: Corentin
Date: Thu, 18 Mar 2021 21:30:59 +0900
Subject: [PATCH] Pipeline implementation for BatchGenerator

* Add end_epoch_callback in Trainer
* Fix Layer.ACTIVATION in case of nn.Module
---
 layers.py                |  7 ++++++-
 trainer.py               |  4 ++++
 utils/batch_generator.py | 26 +++++++++++++++++---------
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/layers.py b/layers.py
index 142aac0..6d511f6 100644
--- a/layers.py
+++ b/layers.py
@@ -33,7 +33,12 @@ class Layer(nn.Module):
         self.info = LayerInfo()
 
         # Preload default
-        self.activation = Layer.ACTIVATION if activation == 0 else activation
+        if activation == 0:
+            activation = Layer.ACTIVATION
+        if isinstance(activation, type):
+            self.activation = activation()
+        else:
+            self.activation = activation
         self.batch_norm = Layer.BATCH_NORM if batch_norm is None else batch_norm
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
diff --git a/trainer.py b/trainer.py
index 1bb7352..4059b9f 100644
--- a/trainer.py
+++ b/trainer.py
@@ -99,6 +99,9 @@ class Trainer:
         self.benchmark_step = 0
         self.benchmark_time = time.time()
 
+    def end_epoch_callback(self):
+        pass
+
     def train_step_callback(
             self, batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
@@ -177,6 +180,7 @@
                     self.benchmark_step += 1
                     self.save_summaries()
                 self.batch_generator_train.next_batch()
+                self.end_epoch_callback()
         except KeyboardInterrupt:
             if self.verbose:
                 print()
diff --git a/utils/batch_generator.py b/utils/batch_generator.py
index a39e396..cfedc4c 100644
--- a/utils/batch_generator.py
+++ b/utils/batch_generator.py
@@ -11,6 +11,7 @@ class BatchGenerator:
     def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                  data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 pipeline: Optional[Callable] = None,
                  prefetch=True, preload=False, num_workers=1, shuffle=True,
                  initial_shuffle=False, flip_data=False, save: Optional[str] = None):
         self.batch_size = batch_size
@@ -18,6 +19,7 @@
         self.prefetch = prefetch and not preload
         self.num_workers = num_workers
         self.flip_data = flip_data
+        self.pipeline = pipeline
 
         if not preload:
             self.data_processor = data_processor
@@ -261,31 +263,37 @@ class BatchGenerator:
             data = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 data.append(self.data_processor(self.data[entry]))
-            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            data = np.asarray(data)
         else:
             data = self.data[
                 self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
             if self.flip_data:
                 flip = np.random.uniform()
                 if flip < 0.25:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
+                    data = data[:, :, ::-1]
                 elif flip < 0.5:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
+                    data = data[:, :, :, ::-1]
                 elif flip < 0.75:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
-                else:
-                    self.cache_data[self.current_cache][:len(data)] = data
-            else:
-                self.cache_data[self.current_cache][:len(data)] = data
+                    data = data[:, :, ::-1, ::-1]
+
         # Loading label
         if self.label_processor is not None:
             label = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 label.append(self.label_processor(self.label[entry]))
-            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
+            label = np.asarray(label)
         else:
             label = self.label[
                 self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
+
+        # Process through pipeline
+        if self.pipeline is not None:
+            for data_index, data_entry in enumerate(data):
+                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
+                self.cache_data[self.current_cache][data_index] = piped_data
+                self.cache_label[self.current_cache][data_index] = piped_label
+        else:
+            self.cache_data[self.current_cache][:len(data)] = data
             self.cache_label[self.current_cache][:len(label)] = label
 
     def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
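
With the layers.py change, Layer's activation argument accepts either a
ready-made nn.Module or an nn.Module subclass; a bare class is instantiated
on the spot, and 0 still selects the Layer.ACTIVATION default. A minimal
sketch of the three call styles, assuming a Layer that can be constructed
from this keyword argument alone (any other required constructor arguments
are elided here):

    import torch.nn as nn
    from layers import Layer

    layer_a = Layer(activation=nn.LeakyReLU)       # class: instantiated internally
    layer_b = Layer(activation=nn.LeakyReLU(0.1))  # instance: stored as-is
    layer_c = Layer(activation=0)                  # falls back to Layer.ACTIVATION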
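The new end_epoch_callback is a deliberate no-op that Trainer invokes right
after batch_generator_train.next_batch() at the end of each epoch, so
subclasses can hook in per-epoch work without touching the training loop.
A minimal sketch; the logging body is an assumption for illustration:

    from trainer import Trainer

    class LoggingTrainer(Trainer):
        def end_epoch_callback(self):
            # Runs once the epoch's batches have been consumed.
            print(f"epoch done, benchmark_step={self.benchmark_step}")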
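The pipeline hook runs after loading (and optional flipping) and is applied
per sample: it receives one (data, label) pair and must return the processed
pair, which is written directly into the prefetch caches. A sketch with a
hypothetical noise-augmentation callable and made-up array shapes:

    import numpy as np
    from utils.batch_generator import BatchGenerator

    def add_noise(sample: np.ndarray, label: np.ndarray):
        # Per-sample transform; output shapes must match the cache entries.
        noisy = sample + np.random.normal(0.0, 0.05, sample.shape)
        return noisy.astype(sample.dtype), label

    data = np.random.rand(128, 3, 32, 32).astype(np.float32)  # assumed NCHW
    labels = np.random.randint(0, 10, size=128)

    batches = BatchGenerator(data, labels, batch_size=16, pipeline=add_noise)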