Pre summary callback added to Trainer

Corentin 2021-03-10 08:43:42 +09:00
commit 92971be5f0


@@ -113,6 +113,9 @@ class Trainer:
             loss: float, accuracy: float):
         pass
 
+    def pre_summary_callback(self):
+        pass
+
     def summary_callback(
             self,
             train_inputs: torch.Tensor, train_processed: torch.Tensor,
@@ -218,6 +221,7 @@ class Trainer:
 
         if self.batch_generator_val.step != 0:
             self.batch_generator_val.skip_epoch()
+        self.pre_summary_callback()
         val_loss = 0.0
         val_accuracy = 0.0
         val_count = 0
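A minimal usage sketch, not part of the commit: a subclass of Trainer can override the new pre_summary_callback() hook, which the training loop now calls right before the validation loss and accuracy are accumulated for the summary. The import path, the self.model attribute, and the eval-mode body below are assumptions for illustration only.

import torch

from trainer import Trainer  # hypothetical import path for the class in this diff


class EvalModeTrainer(Trainer):
    def pre_summary_callback(self):
        # Invoked once before the validation metrics feeding summary_callback()
        # are computed; here it switches an assumed self.model into eval mode.
        self.model.eval()
        if torch.cuda.is_available():
            # Optionally release cached GPU memory before the validation pass.
            torch.cuda.empty_cache()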