Fix MNIST test data loading, add named buffer summary

This commit is contained in:
Corentin Risselin 2020-04-17 12:08:16 +09:00
commit 7db99ffa51
3 changed files with 10 additions and 3 deletions

View file

@ -20,6 +20,9 @@ def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int], s
for name, param in network.named_parameters():
numpy = param.detach().cpu().numpy()
parameter_info.append((name, numpy.shape, human_size(numpy.size * numpy.dtype.itemsize)))
for name, param in network.named_buffers():
numpy = param.detach().cpu().numpy()
parameter_info.append((name, numpy.shape, human_size(numpy.size * numpy.dtype.itemsize)))
return parameter_info