torch_utils/train.py
2020-03-31 13:46:01 +09:00

20 lines
593 B
Python

from typing import List, Tuple
import torch
from .utils.memory import human_size
def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int, ...], str]]:
    """Return the name, shape and memory footprint of every parameter of a network.

    Shape and byte size are read directly from tensor metadata (``shape``,
    ``numel()``, ``element_size()``) instead of copying each parameter to host
    memory as a NumPy array. This avoids a device-to-host transfer of the full
    weights. It also keeps working for dtypes NumPy cannot represent (e.g.
    ``torch.bfloat16``) and for parameters that cannot be copied to CPU.

    Args:
        network (torch.nn.Module): network whose ``named_parameters()`` are summarized.

    Returns:
        List[Tuple[str, Tuple[int, ...], str]]: one ``(name, shape, size)`` entry per
        parameter, in ``named_parameters()`` order. ``shape`` is a tuple of ints
        (``()`` for a scalar parameter). ``size`` is the parameter's byte count
        (element count times bytes per element) formatted as a string by ``human_size``.
    """
    return [
        # numel() * element_size() equals numpy's size * dtype.itemsize, without the copy.
        (name, tuple(param.shape), human_size(param.numel() * param.element_size()))
        for name, param in network.named_parameters()
    ]