"""Command-line entry point that runs the Torch matmul benchmark for every DataType."""
from argparse import ArgumentParser
from pathlib import Path

from src.base import DataType
from src.torch.matmul import TorchMatmulBench

# (lhs_shape, rhs_shape) operand pairs to benchmark. In every pair the inner
# dimensions agree (lhs columns == rhs rows), so each product is well-defined.
MATMUL_SHAPES = (
    ((100, 100), (100, 100)),
    ((100, 200), (200, 100)),
    ((128, 128), (128, 128)),
    ((200, 100), (100, 200)),
    ((200, 200), (200, 200)),
    ((256, 256), (256, 256)),
    ((256, 512), (512, 256)),
    ((400, 400), (400, 400)),
    ((512, 256), (256, 512)),
    ((512, 512), (512, 512)),
    ((800, 800), (800, 800)),
    ((1000, 1000), (1000, 1000)),
    ((1200, 1200), (1200, 1200)),
)

# Second argument forwarded to TorchMatmulBench.run.
# NOTE(review): presumably a repeat/iteration count -- confirm against run().
BENCH_RUNS = 12


def main() -> None:
    """Parse CLI arguments and run the matmul benchmark once per DataType member.

    The ``--output`` path (default ``output``) is created, parents included,
    if it does not exist. It is then handed to ``TorchMatmulBench``,
    presumably as the destination for its result files.

    Raises:
        FileExistsError: if ``--output`` already exists but is not a directory.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output'),
                        help='Path to output files')
    arguments = parser.parse_args()

    output_path: Path = arguments.output
    # exist_ok avoids the exists()/mkdir() race and still fails loudly when
    # the path is an existing regular file.
    output_path.mkdir(parents=True, exist_ok=True)

    for data_type in DataType:
        # A fresh bench and a fresh list per data type, so no state or
        # mutation leaks from one data type's run into the next.
        TorchMatmulBench(output_path).run(list(MATMUL_SHAPES), BENCH_RUNS, data_type)

    print('Benchmark done')


if __name__ == '__main__':
    main()