Pipeline implementation for BatchGenerator

* Add end_epoch_callback in Trainer
* Fix Layer.ACTIVATION in case of nn.Module
Corentin 2021-03-18 21:30:59 +09:00
commit 86787f6517
3 changed files with 27 additions and 10 deletions
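Only the Layer diff is reproduced below; the Trainer and BatchGenerator changes are among the other changed files. As context for the first bullet above, here is a minimal sketch of what an end_epoch_callback hook could look like in a conventional training loop. The Trainer class, its attribute names, and the callback signature are assumptions for illustration, not the repository's actual code:

    from typing import Callable, Optional

    class Trainer:
        def __init__(self, model, end_epoch_callback: Optional[Callable[[int], None]] = None):
            self.model = model
            # Hypothetical hook: called with the epoch index after each epoch, if provided.
            self.end_epoch_callback = end_epoch_callback

        def fit(self, batch_generator, epochs: int) -> None:
            for epoch in range(epochs):
                for batch in batch_generator:
                    ...  # forward pass, loss, backward pass, optimizer step
                if self.end_epoch_callback is not None:
                    self.end_epoch_callback(epoch)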


@@ -33,7 +33,12 @@ class Layer(nn.Module):
         self.info = LayerInfo()
         # Preload default
-        self.activation = Layer.ACTIVATION if activation == 0 else activation
+        if activation == 0:
+            activation = Layer.ACTIVATION
+        if isinstance(activation, type):
+            self.activation = activation()
+        else:
+            self.activation = activation
         self.batch_norm = Layer.BATCH_NORM if batch_norm is None else batch_norm

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
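The isinstance(activation, type) branch is what fixes the nn.Module case mentioned in the commit message: if Layer.ACTIVATION is set to a module class such as nn.ReLU rather than an instance, the old ternary stored the class itself, so calling self.activation in forward would construct a module instead of applying the activation. A minimal sketch of the normalization, reusing the names from the diff (the nn.ReLU default is an assumption):

    import torch
    from torch import nn

    # Accept either an activation class or an already-constructed module,
    # and normalize to an instance before use.
    activation = nn.ReLU             # a class (type), e.g. a Layer.ACTIVATION default
    if isinstance(activation, type):
        activation = activation()    # instantiate: nn.ReLU -> nn.ReLU()

    x = torch.randn(4)
    print(activation(x))             # applies ReLU element-wise, as forward() expects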