20 lines
593 B
Python
20 lines
593 B
Python
from typing import List, Tuple
|
|
|
|
import torch
|
|
|
|
from .utils.memory import human_size
|
|
|
|
|
|
def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int, ...], str]]:
    """Summarize the parameters of a network.

    Args:
        network (torch.nn.Module): network to parse

    Returns:
        A list with one tuple per entry of ``network.named_parameters()``,
        in iteration order: name, shape (tuple of ints), size (string
        produced by ``human_size`` from the parameter's byte count).
    """
    # Shape and byte count are read straight from the tensor metadata.
    # Converting to a NumPy array first would copy accelerator-resident
    # parameters to host memory and would raise for dtypes NumPy cannot
    # represent (e.g. bfloat16), while yielding the same numbers.
    return [
        (
            name,
            # torch.Size is a tuple subclass; convert so callers keep
            # receiving a plain tuple of ints.
            tuple(param.shape),
            human_size(param.numel() * param.element_size()),
        )
        for name, param in network.named_parameters()
    ]
|