dl_bench/benchmark.py
2021-09-28 00:41:53 +09:00

41 lines
1.2 KiB
Python

from argparse import ArgumentParser
from pathlib import Path
from src.base import DataType
from src.torch.matmul import TorchMatmulBench
def main():
    """Run the Torch matmul benchmark once for every ``DataType`` member.

    Command-line options:
        --output  Directory handed to ``TorchMatmulBench``. It is created,
                  together with any missing parents, when it does not exist.
                  Defaults to ``output``.

    Raises:
        FileExistsError: If the ``--output`` path exists but is not a
            directory.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
    arguments = parser.parse_args()

    output_path: Path = arguments.output
    # exist_ok=True replaces the exists()-then-mkdir() sequence, so there is no
    # window in which another process can create the directory between the
    # check and the call.
    output_path.mkdir(parents=True, exist_ok=True)

    # Pairs of 2-tuples forwarded to run(); presumably the (left, right)
    # operand shapes of each matmul case -- confirm against TorchMatmulBench.
    # Built once here instead of once per loop iteration.
    matrix_shapes = (
        ((100, 100), (100, 100)),
        ((100, 200), (200, 100)),
        ((128, 128), (128, 128)),
        ((200, 100), (100, 200)),
        ((200, 200), (200, 200)),
        ((256, 256), (256, 256)),
        ((256, 512), (512, 256)),
        ((400, 400), (400, 400)),
        ((512, 256), (256, 512)),
        ((512, 512), (512, 512)),
        ((800, 800), (800, 800)),
        ((1000, 1000), (1000, 1000)),
        ((1200, 1200), (1200, 1200)),
    )

    for data_type in DataType:
        # A new bench object and a fresh list are used for every data type so
        # that no state or in-place modification carries over between runs.
        TorchMatmulBench(output_path).run(
            list(matrix_shapes),
            # NOTE(review): meaning of this positional argument is defined by
            # TorchMatmulBench.run, which is not visible here -- confirm.
            12,
            data_type,
        )

    print('Benchmark done')
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()