Jax implementation, code factorisation
* Compatibility for older python version (typing)
This commit is contained in:
parent
4b2bcfe7e8
commit
16b7239cd7
37 changed files with 1007 additions and 293 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
|
@ -7,28 +8,21 @@ from src.tf_2.base import TFBase
|
|||
|
||||
|
||||
class TFMulBench(TFBase):
|
||||
def __init__(self, output_path: Path):
|
||||
super().__init__(output_path, Op.MUL)
|
||||
def __init__(self, output_path: Path, data_type: DataType):
|
||||
super().__init__(output_path, Op.MUL, data_type)
|
||||
self.tensor_1: tf.Tensor = None
|
||||
self.tensor_2: tf.Tensor = None
|
||||
self.tensor_result: tf.Tensor = None
|
||||
|
||||
def experiment(self, experiment_args: tuple[int, int], length: int, dtype: tf.DType, device: tf.device):
|
||||
def pre_experiment(self, experiment_args: Tuple[int, int]):
|
||||
shape_1 = experiment_args
|
||||
with device:
|
||||
tensor_1 = tf.ones(shape_1, dtype=dtype)
|
||||
tensor_2 = tf.ones(shape_1, dtype=dtype)
|
||||
with self.device:
|
||||
self.tensor_1 = tf.ones(shape_1, dtype=self.dtype)
|
||||
self.tensor_2 = tf.ones(shape_1, dtype=self.dtype)
|
||||
self.tensor_result = self.tensor_1 * self.tensor_2
|
||||
|
||||
for _ in range(length):
|
||||
_ = tensor_1 * tensor_2
|
||||
def experiment(self):
|
||||
self.tensor_result = self.tensor_1 * self.tensor_2
|
||||
|
||||
def name(self, experiment_args: tuple[int, int]) -> str:
|
||||
shape_1 = experiment_args
|
||||
return f'{shape_1[0]}x{shape_1[1]} * {shape_1[0]}x{shape_1[1]}'
|
||||
|
||||
def mop(self, experiment_args: tuple[int, int]) -> float:
|
||||
shape_1 = experiment_args
|
||||
return shape_1[0] * shape_1[1] / 1000_000
|
||||
|
||||
def run(self,
|
||||
experiment_args: list[tuple[int, int]],
|
||||
experiment_count: int,
|
||||
data_type: DataType):
|
||||
super().run(experiment_args, experiment_count, data_type)
|
||||
def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
|
||||
super().run(experiment_args, experiment_count)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue