28 lines
959 B
Python
28 lines
959 B
Python
from pathlib import Path
|
|
from typing import List, Tuple
|
|
|
|
import tensorflow as tf
|
|
|
|
from src.common import DataType, Op
|
|
from src.tf_2.base import TFBase
|
|
|
|
|
|
class TFMulBench(TFBase):
    """Element-wise multiplication (``Op.MUL``) benchmark built on ``TFBase``.

    Holds two operand tensors and their product. The operands are allocated
    per experiment shape in ``pre_experiment`` and multiplied in
    ``experiment``.
    """

    def __init__(self, output_path: Path, data_type: DataType):
        """Initialise the base benchmark for ``Op.MUL`` and clear tensor slots.

        Args:
            output_path: Passed through unchanged to ``TFBase.__init__``.
            data_type: Passed through unchanged to ``TFBase.__init__``.
        """
        super().__init__(output_path, Op.MUL, data_type)

        # Operands and product stay unset until pre_experiment() fills them.
        self.tensor_1: tf.Tensor = None
        self.tensor_2: tf.Tensor = None
        self.tensor_result: tf.Tensor = None

    def pre_experiment(self, experiment_args: Tuple[int, int]):
        """Allocate both operands for the given shape and multiply them once.

        Args:
            experiment_args: Shape handed to ``tf.ones`` for both operands.
        """
        with self.device:
            self.tensor_1 = lhs = tf.ones(experiment_args, dtype=self.dtype)
            self.tensor_2 = rhs = tf.ones(experiment_args, dtype=self.dtype)
            # One multiplication up front; presumably a warm-up so that
            # one-time costs are not attributed to experiment() -- confirm
            # against how TFBase.run sequences these hooks.
            self.tensor_result = lhs * rhs

    def experiment(self):
        """Multiply the two prepared operands and store the product."""
        self.tensor_result = self.tensor_1 * self.tensor_2

    def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
        """Delegate to ``TFBase.run`` with the shapes and repetition count.

        Args:
            experiment_args: One shape per experiment configuration.
            experiment_count: Forwarded unchanged to ``TFBase.run``.
        """
        super().run(experiment_args, experiment_count)
|