from pathlib import Path
from typing import List, Optional, Tuple

import torch

from src.common import DataType, Op
from src.pytorch.base import TorchBase


class DenseNetwork(torch.nn.Module):
    """A stack of square ``torch.nn.Linear`` layers applied one after another.

    Every layer maps ``input_dim`` features to ``input_dim`` features, so the
    trailing dimension of the output equals that of the input. No activation
    function is inserted between layers.
    """

    def __init__(self, input_dim: int, dtype: torch.dtype, num_layers: int = 5):
        """Build the layer stack.

        Args:
            input_dim: Number of input (and output) features of every layer.
            dtype: Parameter dtype passed to each ``torch.nn.Linear``.
            num_layers: Number of Linear layers to chain. Defaults to 5, the
                configuration measured by the ``NN_DENSE_X5`` benchmark.

        Raises:
            ValueError: If ``num_layers`` is less than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dense = torch.nn.Sequential(
            *[torch.nn.Linear(input_dim, input_dim, dtype=dtype) for _ in range(num_layers)]
        )

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Apply every layer in order and return the final activation."""
        return self.dense(input_data)


class TorchNNDenseX5Bench(TorchBase):
    """Benchmark a forward pass through a 5-layer :class:`DenseNetwork`.

    Each experiment configuration is a ``(batch_size, dimension)`` pair. The
    input is a ``(batch_size, dimension)`` tensor of ones.
    """

    def __init__(self, output_path: Path, data_type: DataType):
        """Register this benchmark with the base class under ``Op.NN_DENSE_X5``.

        Args:
            output_path: Passed through unchanged to ``TorchBase``.
            data_type: Passed through unchanged to ``TorchBase``. Presumably it
                determines ``self.dtype`` used below (TODO confirm in TorchBase).
        """
        super().__init__(output_path, Op.NN_DENSE_X5, data_type)
        # All three are set by pre_experiment(); they stay None until then.
        self.tensor: Optional[torch.Tensor] = None
        self.tensor_result: Optional[torch.Tensor] = None
        self.network: Optional[torch.nn.Module] = None

    def pre_experiment(self, experiment_args: Tuple[int, int]):
        """Allocate the input tensor and network for one configuration.

        Args:
            experiment_args: ``(batch_size, dimension)`` for this configuration.
        """
        batch_size, dimension = experiment_args

        # Drop the previous configuration's tensors, network and (autograd-
        # tracked) result before allocating new ones, so their memory can be
        # reclaimed first and peak usage does not include both sets at once.
        self.tensor_result = None
        self.network = None
        self.tensor = None

        # self.dtype / self.device are not defined in this class; they are
        # presumably provided by TorchBase (TODO confirm).
        self.tensor = torch.ones(
            (batch_size, dimension),
            dtype=self.dtype,
            device=self.device,
            requires_grad=False,
        )
        self.network = DenseNetwork(dimension, self.dtype).to(self.device)
        # One untimed forward pass before measurement. This is presumably a
        # warm-up so that one-time setup cost is kept out of experiment();
        # confirm against how TorchBase schedules these calls.
        self.tensor_result = self.network(self.tensor)

    def experiment(self):
        """Run one forward pass and keep the result (the measured operation)."""
        # NOTE(review): the Linear parameters require grad by default, so this
        # forward pass records an autograd graph. If pure-inference timing is
        # intended, wrapping it in torch.no_grad() would avoid that cost, but
        # doing so changes what is measured. Confirm against sibling benchmarks.
        self.tensor_result = self.network(self.tensor)

    def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
        """Delegate to ``TorchBase.run`` with this benchmark's configurations.

        Args:
            experiment_args: The ``(batch_size, dimension)`` pairs to benchmark.
            experiment_count: Passed through unchanged to ``TorchBase.run``.
        """
        super().run(experiment_args, experiment_count)