Jax implementation, code factorisation
* Compatibility for older python version (typing)
This commit is contained in:
parent
4b2bcfe7e8
commit
16b7239cd7
37 changed files with 1007 additions and 293 deletions
28
src/jax/matmul.py
Normal file
28
src/jax/matmul.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from jax import device_put
|
||||
import jax.numpy as jnp
|
||||
|
||||
from src.common import DataType, Op
|
||||
from src.jax.base import JaxBase
|
||||
|
||||
|
||||
class JaxMatmulBench(JaxBase):
|
||||
def __init__(self, output_path: Path, data_type: DataType):
|
||||
super().__init__(output_path, Op.MATMUL, data_type)
|
||||
self.tensor_1: jnp.DeviceArray = None
|
||||
self.tensor_2: jnp.DeviceArray = None
|
||||
self.tensor_result: jnp.DeviceArray = None
|
||||
|
||||
def pre_experiment(self, experiment_args: Tuple[int, int]):
|
||||
shape_1, shape_2 = experiment_args
|
||||
self.tensor_1 = device_put(jnp.ones(shape_1, dtype=self.dtype))
|
||||
self.tensor_2 = device_put(jnp.ones(shape_2, dtype=self.dtype))
|
||||
self.tensor_result = jnp.matmul(self.tensor_1, self.tensor_2).block_until_ready()
|
||||
|
||||
def experiment(self):
|
||||
self.tensor_result = jnp.matmul(self.tensor_1, self.tensor_2).block_until_ready()
|
||||
|
||||
def run(self, experiment_args: List[Tuple[Tuple[int, int], Tuple[int, int]]], experiment_count: int):
|
||||
super().run(experiment_args, experiment_count)
|
||||
Loading…
Add table
Add a link
Reference in a new issue