41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
|
|
from src.base import DataType
|
|
from src.torch.matmul import TorchMatmulBench
|
|
|
|
|
|
def main(argv=None):
    """Run the Torch matmul benchmark once for every ``DataType`` member.

    Parses the command line, makes sure the output directory exists, then
    calls ``TorchMatmulBench(output_path).run(...)`` for each data type with
    a fixed set of matrix-shape pairs, and finally prints a completion line.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes ``parse_args`` read
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'), help='Path to output files')
    arguments = parser.parse_args(argv)

    output_path: Path = arguments.output
    # exist_ok=True replaces the former exists()/mkdir() pair: there is no
    # check-then-create race, and a path that already exists as a regular
    # file raises FileExistsError here instead of failing later in the bench.
    output_path.mkdir(parents=True, exist_ok=True)

    # (lhs_shape, rhs_shape) pairs. In every pair the inner dimensions match,
    # so the operands are matmul-compatible. The set is identical for every
    # data type, so it is built once outside the loop.
    shape_pairs = (
        ((100, 100), (100, 100)),
        ((100, 200), (200, 100)),
        ((128, 128), (128, 128)),
        ((200, 100), (100, 200)),
        ((200, 200), (200, 200)),
        ((256, 256), (256, 256)),
        ((256, 512), (512, 256)),
        ((400, 400), (400, 400)),
        ((512, 256), (256, 512)),
        ((512, 512), (512, 512)),
        ((800, 800), (800, 800)),
        ((1000, 1000), (1000, 1000)),
        ((1200, 1200), (1200, 1200)),
    )
    # Second positional argument of run(). NOTE(review): presumably a
    # repeat/iteration count -- confirm against TorchMatmulBench.run.
    run_count = 12

    for data_type in DataType:
        # Pass a fresh list on each call, as the original did, so one run()
        # cannot see changes a previous call made to its shapes argument.
        TorchMatmulBench(output_path).run(list(shape_pairs), run_count, data_type)

    print('Benchmark done')
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the benchmark only when this file is executed as a script;
    # importing the module must not trigger a run.
    main()
|