Jax implementation, code factorisation
* Compatibility for older python version (typing)
This commit is contained in:
parent
4b2bcfe7e8
commit
16b7239cd7
37 changed files with 1007 additions and 293 deletions
35
src/tf_2/nn_dense.py
Normal file
35
src/tf_2/nn_dense.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from src.common import DataType, Op
|
||||
from src.tf_2.base import TFBase
|
||||
|
||||
|
||||
class DenseModel(tf.keras.Model):
|
||||
def __init__(self, input_dim: int, dtype=tf.DType):
|
||||
super().__init__()
|
||||
self.dense = tf.keras.layers.Dense(input_dim, dtype=dtype)
|
||||
|
||||
def call(self, input_tensor: tf.Tensor) -> tf.Tensor:
|
||||
return self.dense(input_tensor)
|
||||
|
||||
|
||||
class TFNNDenseBench(TFBase):
|
||||
def __init__(self, output_path: Path, data_type: DataType):
|
||||
super().__init__(output_path, Op.NN_DENSE, data_type)
|
||||
self.tensor: tf.Tensor = None
|
||||
self.network: tf.keras.Model = None
|
||||
|
||||
def pre_experiment(self, experiment_args: Tuple[int, int]):
|
||||
batch_size, dimension = experiment_args
|
||||
with self.device:
|
||||
self.tensor = tf.ones((batch_size, dimension), dtype=self.dtype)
|
||||
self.network = DenseModel(dimension, self.dtype)
|
||||
|
||||
def experiment(self):
|
||||
self.network(self.tensor)
|
||||
|
||||
def run(self, experiment_args: List[Tuple[int, int]], experiment_count: int):
|
||||
super().run(experiment_args, experiment_count)
|
||||
Loading…
Add table
Add a link
Reference in a new issue