from argparse import ArgumentParser
import multiprocessing as mp
import os
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional, Tuple, Type

from config.benchmark import Config
from src.base import BenchBase
from src.common import DataType, Op, Platform
from src.plot import compare

# Positional arguments of run_benchmark as queued by main():
# (output_path, platform, data_type, bench_op, bench_args, bench_count).
BenchmarkArgs = Tuple[Path, Platform, DataType, Op, Any, int]


def _tensorflow_major_matches(platform_name: str, expected_major: str) -> bool:
    """Import tensorflow quietly and check that its major version equals ``expected_major``.

    Prints an explanatory message and returns False on mismatch so the caller can skip the benchmark.
    """
    # Set before the import so tensorflow picks the log level up when it loads.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    if tf.__version__.split('.')[0] != expected_major:
        print(f'Cannot run benchmark for platform {platform_name} with tensorflow version: {tf.__version__}')
        return False
    return True


def _run_op(ops: Dict[Op, Type[BenchBase]], platform: Platform, output_path: Path, data_type: DataType,
            bench_op: Op, bench_args: Any, bench_count: int) -> None:
    """Run ``bench_op`` from a platform's operation registry, or report that it is not implemented."""
    if bench_op not in ops:
        print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        return
    ops[bench_op](output_path, data_type).run(bench_args, bench_count)
    print()


def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op, bench_args: Any,
                  bench_count: int, experiment_time: Optional[float] = None):
    """Run a single benchmark (one platform, one data type, one operation).

    Args:
        output_path: directory the benchmark writes its results into.
        platform: framework to benchmark.
        data_type: element type used by the benchmark.
        bench_op: operation to benchmark.
        bench_args: argument set from Config forwarded to the benchmark's ``run``.
        bench_count: number of experiments to run.
        experiment_time: optional override for ``Config.EXPERIMENT_TIME``. It is applied here, inside the
            process that runs the benchmark, because a value assigned in the parent process is not
            visible to a child created with the 'spawn' start method.
    """
    if experiment_time is not None:
        Config.EXPERIMENT_TIME = experiment_time

    # Framework imports are deferred so only the selected platform is loaded in this process.
    if platform == Platform.JAX:
        if data_type == DataType.FLOAT64:
            # Set before importing jax so the flag is picked up.
            os.environ['JAX_ENABLE_X64'] = 'true'
        from src.jax.ops import jax_ops as ops
    elif platform == Platform.TF1:
        if not _tensorflow_major_matches('TF1', '1'):
            return
        from src.tf_1.ops import tf1_ops as ops
    elif platform == Platform.TF2:
        if not _tensorflow_major_matches('TF2', '2'):
            return
        from src.tf_2.ops import tf2_ops as ops
    elif platform == Platform.TF2_V1:
        if not _tensorflow_major_matches('TF2_V1', '2'):
            return
        from src.tf_2_v1.ops import tf2v1_ops as ops
    elif platform == Platform.TORCH:
        from src.pytorch.ops import torch_ops as ops
    else:
        print(f'Platform {platform.value} is not implemented yet')
        return

    _run_op(ops, platform, output_path, data_type, bench_op, bench_args, bench_count)


def _build_parser() -> ArgumentParser:
    """Build the command-line parser for the benchmark runner."""
    platform_names = ', '.join(p.value for p in Platform)
    data_names = ', '.join(d.value for d in DataType)

    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'),
                        help='Path to output files (default: output)')
    parser.add_argument('--no-benchmark', action='store_true', default=False,
                        help='Avoid running benchmarks')
    parser.add_argument('--no-compare', action='store_true', default=False,
                        help='Avoid running platform comparison')
    parser.add_argument('--count', type=int, default=30,
                        help='Number of experiments per benchmark (for statistical analysis)')
    parser.add_argument('--platform', nargs='*', type=Platform,
                        help=f'List of platform to benchmark [{platform_names}] (else all are used)')
    parser.add_argument('--data', nargs='*', type=DataType,
                        help=f'List of data type to benchmark [{data_names}] (else all are used)')
    parser.add_argument('--op', nargs='*', type=Op,
                        help='List of operation to benchmark (add, mul, div, matmul, etc) (else all are used)')
    parser.add_argument('--list-op', action='store_true',
                        help='List all possible operation to benchmark (no further action will be done)')
    parser.add_argument('--list-platform', action='store_true',
                        help='List all possible platform to benchmark (no further action will be done)')
    parser.add_argument('--list-data', action='store_true',
                        help='List all possible data to benchmark (no further action will be done)')
    parser.add_argument(
        '--experiment-time', type=float,
        help=f'Change time (in s) per experiment (default={Config.EXPERIMENT_TIME:0.3f}s)')
    return parser


def _collect_benchmarks(output_path: Path, platforms: List[Platform], data: List[DataType],
                        bench_ops: List[Op], bench_count: int) -> List[BenchmarkArgs]:
    """Build the run_benchmark argument tuples for every requested platform / data type / operation."""
    # Operations are queued in this fixed order, each paired with the Config argument set it runs on.
    # Requested operations not listed here are not queued.
    op_args = [
        (Op.ADD, Config.ELEMENT_WISE_ARGS),
        (Op.MUL, Config.ELEMENT_WISE_ARGS),
        (Op.DIV, Config.ELEMENT_WISE_ARGS),
        (Op.MATMUL, Config.MATMUL_ARGS),
        (Op.NN_MATMUL, Config.MATMUL_ARGS),
        (Op.NN_DENSE, Config.NN_1D_ARGS),
        (Op.NN_DENSE_X5, Config.NN_1D_ARGS),
    ]
    benchmarks: List[BenchmarkArgs] = []
    for platform in platforms:
        for data_type in data:
            for bench_op, bench_args in op_args:
                if bench_op in bench_ops:
                    benchmarks.append((output_path, platform, data_type, bench_op, bench_args, bench_count))
    return benchmarks


def main():
    """Parse the command line, run the requested benchmarks, then compare platforms."""
    parser = _build_parser()
    arguments = parser.parse_args()

    if arguments.list_op:
        print(', '.join([op.value for op in Op]))
        sys.exit(0)
    if arguments.list_platform:
        print(', '.join([p.value for p in Platform]))
        sys.exit(0)
    if arguments.list_data:
        print(', '.join([d.value for d in DataType]))
        sys.exit(0)

    if arguments.count < 1:
        parser.error('--count must be at least 1')
    experiment_time: Optional[float] = arguments.experiment_time
    if experiment_time is not None and experiment_time <= 0:
        parser.error('--experiment-time must be positive')

    output_path: Path = arguments.output
    bench_count: int = arguments.count
    platforms: List[Platform] = arguments.platform if arguments.platform is not None else list(Platform)
    data: List[DataType] = arguments.data if arguments.data is not None else list(DataType)
    bench_ops: List[Op] = arguments.op if arguments.op is not None else list(Op)

    if experiment_time is not None:
        # Applied in this process too; children receive it explicitly through run_benchmark.
        Config.EXPERIMENT_TIME = experiment_time

    output_path.mkdir(parents=True, exist_ok=True)

    if not arguments.no_benchmark:
        benchmarks = _collect_benchmarks(output_path, platforms, data, bench_ops, bench_count)
        for benchmark in benchmarks:
            # One process per benchmark: the framework imports and environment variables set inside
            # run_benchmark (e.g. JAX_ENABLE_X64) do not leak into later runs.
            process = mp.Process(target=run_benchmark, args=benchmark,
                                 kwargs={'experiment_time': experiment_time})
            process.start()
            process.join()
            if process.exitcode != 0:
                _, platform, data_type, bench_op, _, _ = benchmark
                print(f'Benchmark {bench_op.value} on {platform.value} ({data_type.value}) '
                      f'failed with exit code {process.exitcode}', file=sys.stderr)
        if benchmarks:
            print('Benchmark done')

    if not arguments.no_compare:
        compare(output_path)
        print('Compare done')


if __name__ == '__main__':
    main()