147 lines
6.6 KiB
Python
147 lines
6.6 KiB
Python
from argparse import ArgumentParser
|
|
import multiprocessing as mp
|
|
import os
|
|
from pathlib import Path
|
|
import sys
|
|
from typing import List, Type
|
|
|
|
from config.benchmark import Config
|
|
from src.base import BenchBase
|
|
from src.common import DataType, Op, Platform
|
|
from src.plot import compare
|
|
|
|
|
|
def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
                  bench_args, bench_count: int):
    """Run a single benchmark operation on a single platform.

    The operation table of the selected platform is imported inside this
    function, so only the framework that is actually requested gets loaded.

    Args:
        output_path: Forwarded as first argument to the benchmark class.
        platform: Platform whose operation table is used.
        data_type: Forwarded to the benchmark class; ``FLOAT64`` on JAX also
            sets the ``JAX_ENABLE_X64`` environment variable.
        bench_op: Operation looked up in the platform's operation table.
        bench_args: Forwarded unchanged to the benchmark's ``run`` method.
        bench_count: Forwarded unchanged to the benchmark's ``run`` method.

    A message is printed (and nothing is run) when the platform is unknown,
    when the installed tensorflow major version does not match the requested
    TF platform, or when the operation is missing from the platform's table.
    """
    # The log level must be set before tensorflow is imported below.
    if platform in (Platform.TF1, Platform.TF2, Platform.TF2_V1):
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    if platform == Platform.JAX:
        if data_type == DataType.FLOAT64:
            os.environ['JAX_ENABLE_X64'] = 'true'
        from src.jax.ops import jax_ops as ops_table
    elif platform == Platform.TF1:
        import tensorflow as tf
        major_version = tf.__version__.split('.')[0]
        if major_version != '1':
            print(f'Cannot run benchmark for platform TF1 with tensorflow version: {tf.__version__}')
            return
        from src.tf_1.ops import tf1_ops as ops_table
    elif platform == Platform.TF2:
        from src.tf_2.ops import tf2_ops as ops_table
    elif platform == Platform.TF2_V1:
        import tensorflow as tf
        major_version = tf.__version__.split('.')[0]
        if major_version != '2':
            print(f'Cannot run benchmark for platform TF2_V1 with tensorflow version: {tf.__version__}')
            return
        from src.tf_2_v1.ops import tf2v1_ops as ops_table
    elif platform == Platform.TORCH:
        from src.pytorch.ops import torch_ops as ops_table
    else:
        print(f'Platform {platform.value} is not implemented yet')
        return

    # Shared tail: every supported platform looks the operation up the same way.
    if bench_op not in ops_table:
        print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        return

    bench_class = ops_table[bench_op]
    bench_class(output_path, data_type).run(bench_args, bench_count)
    print()
|
|
|
|
|
|
def _run_benchmark_in_child(experiment_time: float, benchmark: tuple):
    """Process entry point: re-apply the experiment time, then run one benchmark.

    A child started with the ``spawn`` or ``forkserver`` start method
    re-imports ``config.benchmark`` and therefore does not see a
    ``Config.EXPERIMENT_TIME`` override made in the parent process. Assigning
    it here keeps ``--experiment-time`` effective whatever the start method.

    Args:
        experiment_time: Value of ``Config.EXPERIMENT_TIME`` in the parent.
        benchmark: Positional arguments for ``run_benchmark``.
    """
    Config.EXPERIMENT_TIME = experiment_time
    run_benchmark(*benchmark)


def main():
    """Parse the command line, run the selected benchmarks, then compare them.

    The ``--list-*`` options print the available values and exit immediately.
    Otherwise every selected (platform, data type, operation) combination is
    run in its own process, one after the other, and ``compare`` is finally
    called on the output directory unless ``--no-compare`` is given.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'),
                        help='Path to output files (default: output)')
    parser.add_argument('--no-benchmark', action='store_true', default=False, help='Avoid running benchmarks')
    parser.add_argument('--no-compare', action='store_true', default=False, help='Avoid running platform comparaison')
    parser.add_argument('--count', type=int, default=30,
                        help='Number of experiments per benchmark (for stastistical analysis)')
    parser.add_argument('--platform', nargs='*', type=Platform,
                        help='List of platform to benchmark [TF1, TF2, Torch] (else all are used)')
    parser.add_argument('--data', nargs='*', type=DataType,
                        help='List of data type to benchmark [float16, float32, float64] (else all are used)')
    parser.add_argument('--op', nargs='*', type=Op,
                        help='List of operation to benchmark (add, mul, div, matmul, etc) (else all are used)')
    parser.add_argument('--list-op', action='store_true',
                        help='List all possible operation to benchmark (no further action will be done)')
    parser.add_argument('--list-platform', action='store_true',
                        help='List all possible platform to benchmark (no further action will be done)')
    parser.add_argument('--list-data', action='store_true',
                        help='List all possible data to benchmark (no further action will be done)')
    parser.add_argument(
        '--experiment-time', type=float,
        help=f'Change time (in s) per experiment (default={Config.EXPERIMENT_TIME:0.3f}s)')
    arguments = parser.parse_args()

    # Listing options are informational only: print and stop.
    if arguments.list_op:
        print(', '.join([op.value for op in Op]))
        sys.exit(0)
    if arguments.list_platform:
        print(', '.join([p.value for p in Platform]))
        sys.exit(0)
    if arguments.list_data:
        print(', '.join([d.value for d in DataType]))
        sys.exit(0)

    output_path: Path = arguments.output
    no_benchmark: bool = arguments.no_benchmark
    no_compare: bool = arguments.no_compare
    bench_count: int = arguments.count
    # An option that was not given at all (None) selects every enum member.
    platforms: List[Platform] = arguments.platform if arguments.platform is not None else list(Platform)
    data: List[DataType] = arguments.data if arguments.data is not None else list(DataType)
    bench_ops: List[Op] = arguments.op if arguments.op is not None else list(Op)

    # A value of 0 is treated the same as the option being absent.
    if arguments.experiment_time:
        Config.EXPERIMENT_TIME = arguments.experiment_time

    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_path.mkdir(parents=True, exist_ok=True)

    if not no_benchmark:
        # Operations grouped by the argument set they are benchmarked with.
        op_groups = (
            ((Op.ADD, Op.MUL, Op.DIV), Config.ELEMENT_WISE_ARGS),
            ((Op.MATMUL, Op.NN_MATMUL), Config.MATMUL_ARGS),
            ((Op.NN_DENSE, Op.NN_DENSE_X5), Config.NN_1D_ARGS),
        )

        # Each entry is the positional argument tuple of one run_benchmark call.
        benchmarks: List[tuple] = []
        for platform in platforms:
            for data_type in data:
                for group_ops, bench_args in op_groups:
                    for bench_op in group_ops:
                        if bench_op in bench_ops:
                            benchmarks.append(
                                (output_path, platform, data_type, bench_op, bench_args, bench_count))

        if benchmarks:
            for benchmark in benchmarks:
                # One process per benchmark, joined before the next one starts,
                # so benchmarks never run concurrently.
                process = mp.Process(target=_run_benchmark_in_child,
                                     args=(Config.EXPERIMENT_TIME, benchmark))
                process.start()
                process.join()
                if process.exitcode != 0:
                    # Report a crashed child instead of silently moving on.
                    _, failed_platform, failed_data, failed_op, _, _ = benchmark
                    print(f'Benchmark {failed_op.value} ({failed_data.value}) on {failed_platform.value} '
                          f'exited with code {process.exitcode}', file=sys.stderr)
            print('Benchmark done')

    if not no_compare:
        compare(output_path)
        print('Compare done')
|
|
|
|
|
|
# Only start a run when executed as a script: main() launches multiprocessing
# children, and importing this module must not trigger another run.
if __name__ == '__main__':
    main()
|