Add TF1 platform, docker and README

Corentin 2021-10-05 10:51:18 +09:00
commit dbe5490c5b
28 changed files with 655 additions and 34 deletions


@@ -14,12 +14,26 @@ from src.plot import compare
 def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
                   bench_args, bench_count: int):
     if platform == Platform.JAX:
+        if data_type == DataType.FLOAT64:
+            os.environ['JAX_ENABLE_X64'] = 'true'
         from src.jax.ops import jax_ops
         if bench_op not in jax_ops:
             print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
         else:
             jax_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
             print()
+    elif platform == Platform.TF1:
+        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+        import tensorflow as tf
+        if tf.__version__.split('.')[0] != '1':
+            print(f'Cannot run benchmark for platform TF1 with tensorflow version: {tf.__version__}')
+            return
+        from src.tf_1.ops import tf1_ops
+        if bench_op not in tf1_ops:
+            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
+        else:
+            tf1_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
+            print()
     elif platform == Platform.TF2:
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         from src.tf_2.ops import tf2_ops
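
The new TF1 branch mirrors the existing TF2 path: it silences TensorFlow's C++ logging, checks the installed major version, and only then imports the backend ops. The JAX branch likewise sets JAX_ENABLE_X64 before importing src.jax.ops, since JAX reads that flag when it is first imported. A standalone sketch of the version-guard pattern, with an illustrative helper name that is not taken from the repo:

    import os

    def require_tf_major(expected: str) -> bool:
        # Quiet TensorFlow's C++ logging before the import, as the diff does.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf
        if tf.__version__.split('.')[0] != expected:
            print(f'Cannot run benchmark: expected TensorFlow {expected}.x, found {tf.__version__}')
            return False
        return True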
@@ -30,6 +44,10 @@ def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, be
             print()
     elif platform == Platform.TF2_V1:
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+        import tensorflow as tf
+        if tf.__version__.split('.')[0] != '2':
+            print(f'Cannot run benchmark for platform TF2_V1 with tensorflow version: {tf.__version__}')
+            return
         from src.tf_2_v1.ops import tf2v1_ops
         if bench_op not in tf2v1_ops:
             print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
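
The TF2_V1 platform is the converse check: it requires a TensorFlow 2.x install and, judging by the module name, runs v1-style graph code through the compatibility layer. A minimal sketch of that pattern, assuming (not confirmed by this diff) that src.tf_2_v1.ops relies on tf.compat.v1 sessions:

    import tensorflow as tf

    # Run a v1-style static graph on a TensorFlow 2.x install.
    tf.compat.v1.disable_eager_execution()
    a = tf.compat.v1.placeholder(tf.float32, shape=(None,))
    b = tf.compat.v1.placeholder(tf.float32, shape=(None,))
    c = tf.add(a, b)
    with tf.compat.v1.Session() as sess:
        print(sess.run(c, feed_dict={a: [1.0, 2.0], b: [3.0, 4.0]}))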
@@ -49,7 +67,8 @@ def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, be
 def main():
     parser = ArgumentParser()
-    parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
+    parser.add_argument('--output', type=Path, default=Path('output'),
+                        help='Path to output files (default: output)')
     parser.add_argument('--no-benchmark', action='store_true', default=False, help='Avoid running benchmarks')
     parser.add_argument('--no-compare', action='store_true', default=False, help='Avoid running platform comparison')
     parser.add_argument('--count', type=int, default=30,
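
The reworded --output help spells out its default by hand. An alternative, not used in this project, is argparse's built-in default display via ArgumentDefaultsHelpFormatter, sketched here:

    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
    from pathlib import Path

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
    parser.print_help()  # the formatter appends "(default: output)" to the help line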
@@ -62,6 +81,10 @@ def main():
         help='List of operation to benchmark (add, mul, div, matmul, etc) (else all are used)')
     parser.add_argument('--list-op', action='store_true',
                         help='List all possible operation to benchmark (no further action will be done)')
+    parser.add_argument('--list-platform', action='store_true',
+                        help='List all possible platform to benchmark (no further action will be done)')
+    parser.add_argument('--list-data', action='store_true',
+                        help='List all possible data to benchmark (no further action will be done)')
     parser.add_argument(
         '--experiment-time', type=float,
         help=f'Change time (in s) per experiment (default={Config.EXPERIMENT_TIME:0.3f}s)')
@@ -70,6 +93,12 @@ def main():
     if arguments.list_op:
         print(', '.join([op.value for op in Op]))
         sys.exit(0)
+    if arguments.list_platform:
+        print(', '.join([p.value for p in Platform]))
+        sys.exit(0)
+    if arguments.list_data:
+        print(', '.join([d.value for d in DataType]))
+        sys.exit(0)
     output_path: Path = arguments.output
     no_benchmark: bool = arguments.no_benchmark
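
--list-platform and --list-data follow the pattern already used by --list-op: print every value of the corresponding Enum and exit before any benchmark setup. The sketch below shows the assumed shape of the Platform enum; the member names appear in the diff, but the string values are illustrative guesses:

    from enum import Enum

    class Platform(Enum):
        JAX = 'jax'
        TF1 = 'tf1'
        TF2 = 'tf2'
        TF2_V1 = 'tf2_v1'

    # What --list-platform would print before exiting:
    print(', '.join(p.value for p in Platform))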
@@ -91,13 +120,16 @@ def main():
         for data_type in data:
             for bench_op in [Op.ADD, Op.MUL, Op.DIV]:
                 if bench_op in bench_ops:
-                    benchmarks.append((output_path, platform, data_type, bench_op,
-                                       Config.ELEMENT_WISE_ARGS, bench_count))
+                    benchmarks.append(
+                        (output_path, platform, data_type, bench_op, Config.ELEMENT_WISE_ARGS, bench_count))
             for bench_op in [Op.MATMUL, Op.NN_MATMUL]:
                 if bench_op in bench_ops:
-                    benchmarks.append((output_path, platform, data_type, bench_op, Config.MATMUL_ARGS, bench_count))
-            if Op.NN_DENSE in bench_ops:
-                benchmarks.append((output_path, platform, data_type, Op.NN_DENSE, Config.NN_1D_ARGS, bench_count))
+                    benchmarks.append(
+                        (output_path, platform, data_type, bench_op, Config.MATMUL_ARGS, bench_count))
+            for bench_op in [Op.NN_DENSE, Op.NN_DENSE_X5]:
+                if bench_op in bench_ops:
+                    benchmarks.append(
+                        (output_path, platform, data_type, bench_op, Config.NN_1D_ARGS, bench_count))
     if benchmarks:
         for benchmark in benchmarks:
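
Each job is queued as a plain tuple in run_benchmark's parameter order; the dispatch loop itself falls outside this hunk, but the idiom is queue-then-unpack. A self-contained sketch of that pattern with a dummy job function (not the repo's):

    from typing import List, Tuple

    def dummy_run(name: str, size: int, count: int) -> None:
        print(f'{name}: size={size}, repeated {count} times')

    # Queue jobs as tuples in the callee's parameter order, then unpack with *.
    jobs: List[Tuple] = [('matmul', 1024, 30), ('add', 4096, 30)]
    for job in jobs:
        dummy_run(*job)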