Add TF1 platform, docker and README
This commit is contained in:
parent
16b7239cd7
commit
dbe5490c5b
28 changed files with 655 additions and 34 deletions
44
benchmark.py
44
benchmark.py
|
|
@ -14,12 +14,26 @@ from src.plot import compare
|
|||
def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
|
||||
bench_args, bench_count: int):
|
||||
if platform == Platform.JAX:
|
||||
if data_type == DataType.FLOAT64:
|
||||
os.environ['JAX_ENABLE_X64'] = 'true'
|
||||
from src.jax.ops import jax_ops
|
||||
if bench_op not in jax_ops:
|
||||
print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
|
||||
else:
|
||||
jax_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
|
||||
print()
|
||||
elif platform == Platform.TF1:
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import tensorflow as tf
|
||||
if tf.__version__.split('.')[0] != '1':
|
||||
print(f'Cannot run benchmark for platform TF1 with tensorflow version: {tf.__version__}')
|
||||
return
|
||||
from src.tf_1.ops import tf1_ops
|
||||
if bench_op not in tf1_ops:
|
||||
print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
|
||||
else:
|
||||
tf1_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
|
||||
print()
|
||||
elif platform == Platform.TF2:
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
from src.tf_2.ops import tf2_ops
|
||||
|
|
@ -30,6 +44,10 @@ def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, be
|
|||
print()
|
||||
elif platform == Platform.TF2_V1:
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import tensorflow as tf
|
||||
if tf.__version__.split('.')[0] != '2':
|
||||
print(f'Cannot run benchmark for platform TF2_V1 with tensorflow version: {tf.__version__}')
|
||||
return
|
||||
from src.tf_2_v1.ops import tf2v1_ops
|
||||
if bench_op not in tf2v1_ops:
|
||||
print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
|
||||
|
|
@ -49,7 +67,8 @@ def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, be
|
|||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
|
||||
parser.add_argument('--output', type=Path, default=Path('output'),
|
||||
help='Path to output files (default: output)')
|
||||
parser.add_argument('--no-benchmark', action='store_true', default=False, help='Avoid running benchmarks')
|
||||
parser.add_argument('--no-compare', action='store_true', default=False, help='Avoid running platform comparaison')
|
||||
parser.add_argument('--count', type=int, default=30,
|
||||
|
|
@ -62,6 +81,10 @@ def main():
|
|||
help='List of operation to benchmark (add, mul, div, matmul, etc) (else all are used)')
|
||||
parser.add_argument('--list-op', action='store_true',
|
||||
help='List all possible operation to benchmark (no further action will be done)')
|
||||
parser.add_argument('--list-platform', action='store_true',
|
||||
help='List all possible platform to benchmark (no further action will be done)')
|
||||
parser.add_argument('--list-data', action='store_true',
|
||||
help='List all possible data to benchmark (no further action will be done)')
|
||||
parser.add_argument(
|
||||
'--experiment-time', type=float,
|
||||
help=f'Change time (in s) per experiment (default={Config.EXPERIMENT_TIME:0.3f}s)')
|
||||
|
|
@ -70,6 +93,12 @@ def main():
|
|||
if arguments.list_op:
|
||||
print(', '.join([op.value for op in Op]))
|
||||
sys.exit(0)
|
||||
if arguments.list_platform:
|
||||
print(', '.join([p.value for p in Platform]))
|
||||
sys.exit(0)
|
||||
if arguments.list_data:
|
||||
print(', '.join([d.value for d in DataType]))
|
||||
sys.exit(0)
|
||||
|
||||
output_path: Path = arguments.output
|
||||
no_benchmark: bool = arguments.no_benchmark
|
||||
|
|
@ -91,13 +120,16 @@ def main():
|
|||
for data_type in data:
|
||||
for bench_op in [Op.ADD, Op.MUL, Op.DIV]:
|
||||
if bench_op in bench_ops:
|
||||
benchmarks.append((output_path, platform, data_type, bench_op,
|
||||
Config.ELEMENT_WISE_ARGS, bench_count))
|
||||
benchmarks.append(
|
||||
(output_path, platform, data_type, bench_op, Config.ELEMENT_WISE_ARGS, bench_count))
|
||||
for bench_op in [Op.MATMUL, Op.NN_MATMUL]:
|
||||
if bench_op in bench_ops:
|
||||
benchmarks.append((output_path, platform, data_type, bench_op, Config.MATMUL_ARGS, bench_count))
|
||||
if Op.NN_DENSE in bench_ops:
|
||||
benchmarks.append((output_path, platform, data_type, Op.NN_DENSE, Config.NN_1D_ARGS, bench_count))
|
||||
benchmarks.append(
|
||||
(output_path, platform, data_type, bench_op, Config.MATMUL_ARGS, bench_count))
|
||||
for bench_op in [Op.NN_DENSE, Op.NN_DENSE_X5]:
|
||||
if bench_op in bench_ops:
|
||||
benchmarks.append(
|
||||
(output_path, platform, data_type, bench_op, Config.NN_1D_ARGS, bench_count))
|
||||
|
||||
if benchmarks:
|
||||
for benchmark in benchmarks:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue