Implement TF2 and add, mul and div benchmark
This commit is contained in:
parent
fbf6898dd9
commit
4b2bcfe7e8
18 changed files with 649 additions and 171 deletions
102
benchmark.py
102
benchmark.py
|
|
@ -1,39 +1,99 @@
|
|||
from argparse import ArgumentParser
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
from src.base import DataType
|
||||
from src.torch.matmul import TorchMatmulBench
|
||||
from src.base import BenchBase
|
||||
from src.common import DataType, Op, Platform
|
||||
|
||||
|
||||
def run_benchmark(output_path: Path, platform: Platform, data_type: DataType, bench_op: Op,
|
||||
bench_args, bench_count: int):
|
||||
if platform == Platform.TF2:
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
from src.tf_2.ops import tf2_ops
|
||||
if bench_op not in tf2_ops:
|
||||
print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
|
||||
else:
|
||||
tf2_ops[bench_op](output_path).run(bench_args, bench_count, data_type)
|
||||
print()
|
||||
elif platform == Platform.TORCH:
|
||||
from src.pytorch.ops import torch_ops
|
||||
if bench_op not in torch_ops:
|
||||
print(f'Operation {bench_op.value} is not implemented for {platform.value} yet')
|
||||
else:
|
||||
torch_ops[bench_op](output_path).run(bench_args, bench_count, data_type)
|
||||
print()
|
||||
else:
|
||||
print(f'Platform {platform.value} is not implemented yet')
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
|
||||
parser.add_argument('--count', type=int, default=30,
|
||||
help='Number of experiments per benchmark (for stastistical analysis)')
|
||||
parser.add_argument('--platform', nargs='*', type=Platform,
|
||||
help='List of platform to benchmark [TF1, TF2, Torch] (else all are used)')
|
||||
parser.add_argument('--data', nargs='*', type=DataType,
|
||||
help='List of data type to benchmark [float16, float32, float64] (else all are used)')
|
||||
parser.add_argument('--op', nargs='*', type=Op,
|
||||
help='List of operation to benchmark [add, mul, div, matmul] (else all are used)')
|
||||
arguments = parser.parse_args()
|
||||
|
||||
output_path: Path = arguments.output
|
||||
bench_count: int = arguments.count
|
||||
platforms: list[Platform] = arguments.platform if arguments.platform is not None else list(Platform)
|
||||
data: list[DataType] = arguments.data if arguments.data is not None else list(DataType)
|
||||
bench_ops: list[Op] = arguments.op if arguments.op is not None else list(Op)
|
||||
|
||||
if not output_path.exists():
|
||||
output_path.mkdir(parents=True)
|
||||
|
||||
for data_type in DataType:
|
||||
TorchMatmulBench(output_path).run(
|
||||
[
|
||||
((100, 100), (100, 100)),
|
||||
((100, 200), (200, 100)),
|
||||
((128, 128), (128, 128)),
|
||||
((200, 100), (100, 200)),
|
||||
((200, 200), (200, 200)),
|
||||
((256, 256), (256, 256)),
|
||||
((256, 512), (512, 256)),
|
||||
((400, 400), (400, 400)),
|
||||
((512, 256), (256, 512)),
|
||||
((512, 512), (512, 512)),
|
||||
((800, 800), (800, 800)),
|
||||
((1000, 1000), (1000, 1000)),
|
||||
((1200, 1200), (1200, 1200)),
|
||||
],
|
||||
12,
|
||||
data_type)
|
||||
benchmarks: list[dict[Op, Type[BenchBase]]] = []
|
||||
element_wise_args = [
|
||||
(100, 100),
|
||||
(100, 200),
|
||||
(128, 128),
|
||||
(200, 100),
|
||||
(200, 200),
|
||||
(256, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(400, 400),
|
||||
(512, 512),
|
||||
(800, 800),
|
||||
(1024, 1024),
|
||||
(1800, 1800)]
|
||||
matmul_args = [
|
||||
((100, 100), (100, 100)),
|
||||
((100, 200), (200, 100)),
|
||||
((128, 128), (128, 128)),
|
||||
((200, 100), (100, 200)),
|
||||
((200, 200), (200, 200)),
|
||||
((256, 256), (256, 256)),
|
||||
((256, 512), (512, 256)),
|
||||
((400, 400), (400, 400)),
|
||||
((512, 256), (256, 512)),
|
||||
((512, 512), (512, 512)),
|
||||
((800, 800), (800, 800)),
|
||||
((1000, 1000), (1000, 1000)),
|
||||
((1200, 1200), (1200, 1200))]
|
||||
|
||||
for platform in platforms:
|
||||
for data_type in data:
|
||||
for bench_op in [Op.ADD, Op.MUL, Op.DIV]:
|
||||
if bench_op in bench_ops:
|
||||
benchmarks.append((output_path, platform, data_type, bench_op, element_wise_args, bench_count))
|
||||
if Op.MATMUL in bench_ops:
|
||||
benchmarks.append((output_path, platform, data_type, Op.MATMUL, matmul_args, bench_count))
|
||||
|
||||
for benchmark in benchmarks:
|
||||
process = mp.Process(target=run_benchmark, args=benchmark)
|
||||
process.start()
|
||||
process.join()
|
||||
|
||||
print('Benchmark done')
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue