from pathlib import Path
from typing import List, Optional, Tuple

import tensorflow as tf

from src.common import DataType, Op
from src.tf_2.base import TFBase


class DenseModel(tf.keras.Model):
    """Minimal Keras model wrapping a single ``Dense`` layer.

    Args:
        input_dim: Number of units of the wrapped ``Dense`` layer, i.e. the
            size of the output's last dimension. (``TFNNDenseBench`` feeds
            inputs whose last dimension equals this value, so the layer is
            square there; the parameter name is kept for compatibility.)
        dtype: Dtype passed to the ``Dense`` layer. ``None`` (the default)
            lets Keras fall back to its default floating-point policy.
    """

    def __init__(self, input_dim: int, dtype: Optional[tf.DType] = None):
        super().__init__()
        # The previous default was the class ``tf.DType`` itself (a typo for
        # a type annotation), which is not a usable dtype value.
        self.dense = tf.keras.layers.Dense(input_dim, dtype=dtype)

    def call(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """Apply the dense layer to ``input_tensor`` and return the result."""
        return self.dense(input_tensor)


class TFNNDenseBench(TFBase):
    """Benchmark of a forward pass through a single ``Dense`` layer.

    Each experiment is parameterised by ``(batch_size, dimension)``: a
    ``(batch_size, dimension)`` tensor of ones is passed through a
    ``dimension``-unit dense layer.
    """

    def __init__(self, output_path: Path, data_type: DataType):
        super().__init__(output_path, Op.NN_DENSE, data_type)
        # Populated per experiment by pre_experiment().
        self.tensor: Optional[tf.Tensor] = None
        self.network: Optional[tf.keras.Model] = None

    def pre_experiment(self, experiment_args: Tuple[int, int]):
        """Create the input tensor and the network for one experiment.

        Args:
            experiment_args: ``(batch_size, dimension)``.
        """
        batch_size, dimension = experiment_args
        # NOTE(review): ``self.device`` and ``self.dtype`` are provided by
        # TFBase (not visible here); presumably a ``tf.device`` context and
        # the TF dtype matching ``data_type`` — confirm.
        with self.device:
            self.tensor = tf.ones((batch_size, dimension), dtype=self.dtype)
            self.network = DenseModel(dimension, self.dtype)
            # Keras creates layer weights lazily on the first call. Do that
            # here, inside the device scope, so the weights are allocated on
            # ``self.device`` and their creation is not folded into the first
            # experiment() call.
            self.network(self.tensor)

    def experiment(self):
        """Run one forward pass of the prepared network on the prepared tensor.

        Presumably this is the operation timed by TFBase — confirm.
        """
        self.network(self.tensor)

    def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
        """Run the benchmark over every ``(batch_size, dimension)`` pair.

        Delegates entirely to ``TFBase.run``.
        """
        super().run(experiment_args, experiment_count)