Pipeline implementation for BatchGenerator

* Add end_epoch_callback in Trainer
* Fix Layer.ACTIVATION in case of nn.Module
Corentin 2021-03-18 21:30:59 +09:00
commit 86787f6517
3 changed files with 27 additions and 10 deletions

@@ -99,6 +99,9 @@ class Trainer:
         self.benchmark_step = 0
         self.benchmark_time = time.time()
 
+    def end_epoch_callback(self):
+        pass
+
     def train_step_callback(
         self,
         batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
@@ -177,6 +180,7 @@ class Trainer:
             self.benchmark_step += 1
             self.save_summaries()
             self.batch_generator_train.next_batch()
+            self.end_epoch_callback()
         except KeyboardInterrupt:
             if self.verbose:
                 print()
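
The new hook is a no-op in the base Trainer and is invoked in the training loop right after batch_generator_train.next_batch(), so a subclass can override it to run logic at epoch boundaries. Below is a minimal sketch of such an override, assuming the Trainer above is importable from a module named trainer and exposes its network as self.model; neither of those, nor the CheckpointingTrainer class and checkpoint path, appear in this commit and are purely illustrative.

    import torch

    from trainer import Trainer  # hypothetical import path for the Trainer shown above


    class CheckpointingTrainer(Trainer):
        """Hypothetical subclass using the end_epoch_callback hook added in this commit."""

        def __init__(self, *args, checkpoint_path: str = "checkpoint.pt", **kwargs):
            super().__init__(*args, **kwargs)
            self.checkpoint_path = checkpoint_path
            self.epoch = 0

        def end_epoch_callback(self):
            # Called by the training loop right after
            # self.batch_generator_train.next_batch().
            self.epoch += 1
            # Assumes the trainer exposes its network as self.model,
            # which is not shown in the diff above.
            torch.save(self.model.state_dict(), self.checkpoint_path)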