from argparse import ArgumentParser
import multiprocessing as mp
import os
from pathlib import Path
from typing import Dict, Optional, Type

from src.base import BenchBase
from src.common import DataType, Op, Platform


def _load_platform_ops(platform: Platform) -> Optional[Dict[Op, Type[BenchBase]]]:
    """Import and return the operation table for ``platform``, or None if unsupported.

    Framework modules are imported lazily here so that only the selected
    platform's library is loaded in the current process.
    """
    if platform == Platform.TF2:
        # Presumably silences TensorFlow's C++ log output; it has to be set
        # before the TensorFlow-backed module is imported to take effect.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        from src.tf_2.ops import tf2_ops
        return tf2_ops
    if platform == Platform.TORCH:
        from src.pytorch.ops import torch_ops
        return torch_ops
    return None


def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
                  bench_args, bench_count: int):
    """Run one benchmark (platform x data type x operation) and print a trailing blank line.

    Args:
        output_path: Directory passed to the benchmark class constructor.
        platform: Framework to benchmark; only TF2 and TORCH are wired up here.
        data_type: Data type forwarded to the benchmark's ``run``.
        bench_op: Operation to benchmark; looked up in the platform's op table.
        bench_args: Sequence of argument shapes forwarded to ``run``.
        bench_count: Number of experiments forwarded to ``run``.

    Unsupported platforms or operations are reported on stdout and skipped.
    """
    ops = _load_platform_ops(platform)
    if ops is None:
        print(f'Platform {platform.value} is not implemented yet')
        return
    if bench_op not in ops:
        print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
        return
    ops[bench_op](output_path).run(bench_args, bench_count, data_type)
    print()


def main():
    """Parse CLI arguments and run every selected benchmark, each in its own process.

    Benchmarks run sequentially: each child process is started and joined before
    the next one begins. Children that exit with a non-zero code are reported.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
    parser.add_argument('--count', type=int, default=30,
                        help='Number of experiments per benchmark (for statistical analysis)')
    parser.add_argument('--platform', nargs='*', type=Platform,
                        help='List of platform to benchmark [TF1, TF2, Torch] (else all are used)')
    parser.add_argument('--data', nargs='*', type=DataType,
                        help='List of data type to benchmark [float16, float32, float64] (else all are used)')
    parser.add_argument('--op', nargs='*', type=Op,
                        help='List of operation to benchmark [add, mul, div, matmul] (else all are used)')
    arguments = parser.parse_args()

    output_path: Path = arguments.output
    bench_count: int = arguments.count
    # With nargs='*', an absent flag gives None and a flag with no values gives [];
    # both mean "use all", as the help texts state.
    platforms: list[Platform] = arguments.platform or list(Platform)
    data: list[DataType] = arguments.data or list(DataType)
    bench_ops: list[Op] = arguments.op or list(Op)

    output_path.mkdir(parents=True, exist_ok=True)

    element_wise_args = [
        (100, 100), (100, 200), (128, 128), (200, 100), (200, 200), (256, 256),
        (256, 512), (512, 256), (400, 400), (512, 512), (800, 800), (1024, 1024),
        (1800, 1800)]
    matmul_args = [
        ((100, 100), (100, 100)), ((100, 200), (200, 100)), ((128, 128), (128, 128)),
        ((200, 100), (100, 200)), ((200, 200), (200, 200)), ((256, 256), (256, 256)),
        ((256, 512), (512, 256)), ((400, 400), (400, 400)), ((512, 256), (256, 512)),
        ((512, 512), (512, 512)), ((800, 800), (800, 800)), ((1000, 1000), (1000, 1000)),
        ((1200, 1200), (1200, 1200))]

    # Each entry is the positional argument tuple for run_benchmark.
    element_wise_ops = [op for op in (Op.ADD, Op.MUL, Op.DIV) if op in bench_ops]
    benchmarks: list[tuple] = []
    for platform in platforms:
        for data_type in data:
            for bench_op in element_wise_ops:
                benchmarks.append(
                    (output_path, platform, data_type, bench_op, element_wise_args, bench_count))
            if Op.MATMUL in bench_ops:
                benchmarks.append(
                    (output_path, platform, data_type, Op.MATMUL, matmul_args, bench_count))

    failed = []
    for benchmark in benchmarks:
        process = mp.Process(target=run_benchmark, args=benchmark)
        process.start()
        process.join()
        # A non-zero exit code means the child raised or was killed (negative = signal).
        if process.exitcode != 0:
            _, platform, data_type, bench_op, _, _ = benchmark
            label = f'{platform.value}/{data_type.value}/{bench_op.value}'
            failed.append(label)
            print(f'Benchmark {label} failed with exit code {process.exitcode}')

    if failed:
        print(f'{len(failed)} benchmark(s) failed: {", ".join(failed)}')
    print('Benchmark done')


if __name__ == '__main__':
    main()