from pathlib import Path
from typing import List, Optional, Tuple

from jax import device_put
import jax.numpy as jnp

from src.common import DataType, Op
from src.jax.base import JaxBase


class JaxMatmulBench(JaxBase):
    """Matrix-multiplication benchmark built on ``jnp.matmul``.

    The two operands are created in ``pre_experiment`` and multiplied in
    ``experiment``. Scheduling and measurement are delegated to ``JaxBase``
    (not visible here).
    """

    def __init__(self, output_path: Path, data_type: DataType):
        """Register the benchmark as ``Op.MATMUL`` and clear the tensor slots.

        Args:
            output_path: Forwarded unchanged to ``JaxBase.__init__``.
            data_type: Forwarded unchanged to ``JaxBase.__init__``.
        """
        super().__init__(output_path, Op.MATMUL, data_type)
        # Operands and result stay None until pre_experiment() populates them.
        # NOTE(review): newer JAX releases expose the array type as
        # ``jax.Array`` instead of ``jnp.DeviceArray`` - confirm against the
        # pinned JAX version. Annotations inside a function body are never
        # evaluated, so this has no runtime effect either way.
        self.tensor_1: Optional[jnp.DeviceArray] = None
        self.tensor_2: Optional[jnp.DeviceArray] = None
        self.tensor_result: Optional[jnp.DeviceArray] = None

    def pre_experiment(
        self, experiment_args: Tuple[Tuple[int, int], Tuple[int, int]]
    ):
        """Build the two operands and run the multiplication once.

        Args:
            experiment_args: Pair ``(shape_1, shape_2)`` giving the shapes of
                the left and right operands.
        """
        shape_1, shape_2 = experiment_args
        # self.dtype is not set in this class; presumably JaxBase derives it
        # from ``data_type`` - verify against the base class.
        self.tensor_1 = device_put(jnp.ones(shape_1, dtype=self.dtype))
        self.tensor_2 = device_put(jnp.ones(shape_2, dtype=self.dtype))
        # Same operation as experiment(), executed once up front; presumably a
        # warm-up so one-off first-call cost is not part of the measured run.
        self.tensor_result = jnp.matmul(
            self.tensor_1, self.tensor_2
        ).block_until_ready()

    def experiment(self):
        """Multiply the prepared operands and wait for the result."""
        # block_until_ready() waits for the computation to complete rather
        # than returning as soon as the operation has been dispatched.
        self.tensor_result = jnp.matmul(
            self.tensor_1, self.tensor_2
        ).block_until_ready()

    def run(
        self,
        experiment_args: List[Tuple[Tuple[int, int], Tuple[int, int]]],
        experiment_count: int,
    ):
        """Delegate to ``JaxBase.run`` with matmul-specific argument types.

        Args:
            experiment_args: List of ``(shape_1, shape_2)`` operand-shape
                pairs.
            experiment_count: Forwarded unchanged to ``JaxBase.run``.
        """
        super().run(experiment_args, experiment_count)