diff --git a/trainer.py b/trainer.py index 4f94396..1bb7352 100644 --- a/trainer.py +++ b/trainer.py @@ -113,6 +113,9 @@ class Trainer: loss: float, accuracy: float): pass + def pre_summary_callback(self): + pass + def summary_callback( self, train_inputs: torch.Tensor, train_processed: torch.Tensor, @@ -218,6 +221,7 @@ class Trainer: if self.batch_generator_val.step != 0: self.batch_generator_val.skip_epoch() + self.pre_summary_callback() val_loss = 0.0 val_accuracy = 0.0 val_count = 0