"""TensorFlow benchmark for element-wise multiplication (``Op.MUL``)."""
from pathlib import Path
from typing import List, Optional, Tuple

import tensorflow as tf

from src.common import DataType, Op
from src.tf_2.base import TFBase


class TFMulBench(TFBase):
    """Benchmark element-wise multiplication of two equally shaped tensors.

    Both operands are filled with ones by ``tf.ones`` in ``pre_experiment``.
    The timed step, ``experiment``, multiplies them with the ``*`` operator.
    """

    def __init__(self, output_path: Path, data_type: DataType):
        """Create the benchmark.

        Args:
            output_path: Passed through to ``TFBase`` together with ``Op.MUL``.
            data_type: Passed through to ``TFBase``. Presumably ``TFBase``
                derives ``self.dtype`` from it; confirm in the base class.
        """
        super().__init__(output_path, Op.MUL, data_type)
        # Operands and result. They stay None until pre_experiment runs,
        # hence Optional.
        self.tensor_1: Optional[tf.Tensor] = None
        self.tensor_2: Optional[tf.Tensor] = None
        self.tensor_result: Optional[tf.Tensor] = None

    def pre_experiment(self, experiment_args: Tuple[int, int]):
        """Allocate the operand tensors for one experiment.

        Args:
            experiment_args: Shape used for BOTH operands, passed directly
                to ``tf.ones``.
        """
        shape = experiment_args
        with self.device:
            self.tensor_1 = tf.ones(shape, dtype=self.dtype)
            self.tensor_2 = tf.ones(shape, dtype=self.dtype)
            # Multiply once before the timed step. This looks like a warm-up
            # so first-call overhead is not measured (assumption; confirm).
            self.tensor_result = self.tensor_1 * self.tensor_2

    def experiment(self):
        """Multiply the two operands element-wise and store the result."""
        # NOTE(review): unlike pre_experiment, this is not wrapped in
        # `with self.device`. Confirm that TFBase enters the device scope
        # around experiment(), or that placement follows the operands.
        self.tensor_result = self.tensor_1 * self.tensor_2

    def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
        """Run the benchmark by delegating to ``TFBase.run``.

        Args:
            experiment_args: One shape per experiment. Each shape is later
                handed to ``pre_experiment``, presumably by the base class.
            experiment_count: Forwarded unchanged to ``TFBase.run``.
        """
        super().run(experiment_args, experiment_count)