dl_bench/benchmark.py
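
"""Entry point for the dl_bench operation micro-benchmarks.

Runs the selected operations on the selected platforms and data types,
launching each benchmark in its own subprocess, and, unless disabled,
compares the results across platforms via src.plot.compare.
"""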
from argparse import ArgumentParser
import multiprocessing as mp
import os
from pathlib import Path
import sys
from typing import List

from config.benchmark import Config
from src.common import DataType, Op, Platform
from src.plot import compare


def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
                  bench_args, bench_count: int):
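    """Run a single benchmark for one (platform, data type, operation) combination.

    bench_args selects the benchmark sizes from Config and bench_count is the
    number of experiments to run; results are written under output_path.
    """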
    if platform == Platform.JAX:
        if data_type == DataType.FLOAT64:
            os.environ['JAX_ENABLE_X64'] = 'true'
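        # JAX reads JAX_ENABLE_X64 when it is first imported, so the flag has
        # to be in the environment before the JAX-backed ops are imported.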
        from src.jax.ops import jax_ops
        if bench_op not in jax_ops:
            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        else:
            jax_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
        print()
    elif platform == Platform.TF1:
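        # TF_CPP_MIN_LOG_LEVEL=3 keeps TensorFlow's C++ runtime from printing
        # INFO/WARNING/ERROR messages in the middle of the benchmark output.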
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf
        if tf.__version__.split('.')[0] != '1':
            print(f'Cannot run benchmark for platform TF1 with tensorflow version: {tf.__version__}')
            return
        from src.tf_1.ops import tf1_ops
        if bench_op not in tf1_ops:
            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        else:
            tf1_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
        print()
    elif platform == Platform.TF2:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        from src.tf_2.ops import tf2_ops
        if bench_op not in tf2_ops:
            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        else:
            tf2_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
        print()
    elif platform == Platform.TF2_V1:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf
        if tf.__version__.split('.')[0] != '2':
            print(f'Cannot run benchmark for platform TF2_V1 with tensorflow version: {tf.__version__}')
            return
        from src.tf_2_v1.ops import tf2v1_ops
        if bench_op not in tf2v1_ops:
            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        else:
            tf2v1_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
        print()
    elif platform == Platform.TORCH:
        from src.pytorch.ops import torch_ops
        if bench_op not in torch_ops:
            print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        else:
            torch_ops[bench_op](output_path, data_type).run(bench_args, bench_count)
        print()
    else:
        print(f'Platform {platform.value} is not implemented yet')


def main():
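    """Parse the command-line arguments, run the requested benchmarks and compare the platforms."""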
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'),
                        help='Path to output files (default: output)')
    parser.add_argument('--no-benchmark', action='store_true', default=False,
                        help='Avoid running benchmarks')
    parser.add_argument('--no-compare', action='store_true', default=False,
                        help='Avoid running the platform comparison')
    parser.add_argument('--count', type=int, default=30,
                        help='Number of experiments per benchmark (for statistical analysis)')
    parser.add_argument('--platform', nargs='*', type=Platform,
                        help='List of platforms to benchmark (see --list-platform; default: all)')
    parser.add_argument('--data', nargs='*', type=DataType,
                        help='List of data types to benchmark [float16, float32, float64] (default: all)')
    parser.add_argument('--op', nargs='*', type=Op,
                        help='List of operations to benchmark (add, mul, div, matmul, etc.) (default: all)')
    parser.add_argument('--list-op', action='store_true',
                        help='List all operations that can be benchmarked (no further action is taken)')
    parser.add_argument('--list-platform', action='store_true',
                        help='List all platforms that can be benchmarked (no further action is taken)')
    parser.add_argument('--list-data', action='store_true',
                        help='List all data types that can be benchmarked (no further action is taken)')
    parser.add_argument(
        '--experiment-time', type=float,
        help=f'Change the time (in s) per experiment (default={Config.EXPERIMENT_TIME:0.3f}s)')
    arguments = parser.parse_args()
    if arguments.list_op:
        print(', '.join([op.value for op in Op]))
        sys.exit(0)
    if arguments.list_platform:
        print(', '.join([p.value for p in Platform]))
        sys.exit(0)
    if arguments.list_data:
        print(', '.join([d.value for d in DataType]))
        sys.exit(0)
    output_path: Path = arguments.output
    no_benchmark: bool = arguments.no_benchmark
    no_compare: bool = arguments.no_compare
    bench_count: int = arguments.count
    platforms: List[Platform] = arguments.platform if arguments.platform is not None else list(Platform)
    data: List[DataType] = arguments.data if arguments.data is not None else list(DataType)
    bench_ops: List[Op] = arguments.op if arguments.op is not None else list(Op)
    if arguments.experiment_time is not None:
        Config.EXPERIMENT_TIME = arguments.experiment_time
    if not output_path.exists():
        output_path.mkdir(parents=True)
    if not no_benchmark:
        benchmarks: List[tuple] = []
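        # Each family of operations shares one argument set from Config:
        # element-wise ops, matrix multiplications and dense-layer benchmarks.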
        for platform in platforms:
            for data_type in data:
                for bench_op in [Op.ADD, Op.MUL, Op.DIV]:
                    if bench_op in bench_ops:
                        benchmarks.append(
                            (output_path, platform, data_type, bench_op, Config.ELEMENT_WISE_ARGS, bench_count))
                for bench_op in [Op.MATMUL, Op.NN_MATMUL]:
                    if bench_op in bench_ops:
                        benchmarks.append(
                            (output_path, platform, data_type, bench_op, Config.MATMUL_ARGS, bench_count))
                for bench_op in [Op.NN_DENSE, Op.NN_DENSE_X5]:
                    if bench_op in bench_ops:
                        benchmarks.append(
                            (output_path, platform, data_type, bench_op, Config.NN_1D_ARGS, bench_count))
        if benchmarks:
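            # Running every benchmark in its own process keeps the environment
            # variables and framework state of one platform from leaking into
            # the next run.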
            for benchmark in benchmarks:
                process = mp.Process(target=run_benchmark, args=benchmark)
                process.start()
                process.join()
        print('Benchmark done')
    if not no_compare:
        compare(output_path)
        print('Compare done')


if __name__ == '__main__':
    main()