Jax implementation, code factorisation

* Compatibility for older python version (typing)
This commit is contained in:
Corentin 2021-10-01 20:14:00 +09:00
commit 16b7239cd7
37 changed files with 1007 additions and 293 deletions

41
config/benchmark.py Normal file
View file

@ -0,0 +1,41 @@
class Config:
EXPERIMENT_TIME = 1.0
ELEMENT_WISE_ARGS = [
(100, 100),
(100, 200),
(128, 128),
(200, 100),
(200, 200),
(256, 256),
(256, 512),
(512, 256),
(400, 400),
(512, 512),
(800, 800),
(1024, 1024),
(1800, 1800)]
MATMUL_ARGS = [
((100, 100), (100, 100)),
((100, 200), (200, 100)),
((128, 128), (128, 128)),
((200, 100), (100, 200)),
((200, 200), (200, 200)),
((256, 256), (256, 256)),
((256, 512), (512, 256)),
((400, 400), (400, 400)),
((512, 256), (256, 512)),
((512, 512), (512, 512)),
((800, 800), (800, 800)),
((1000, 1000), (1000, 1000)),
((1200, 1200), (1200, 1200))]
NN_1D_ARGS = [
(1, 16), (16, 16), (64, 16),
(1, 64), (16, 64),
(1, 150), (16, 150),
(1, 256), (16, 256),
(1, 400), (16, 400), (64, 400),
(1, 512), (16, 512), (64, 512),
(1, 800), (16, 800), (64, 800),
(1, 1024), (16, 1024),
(1, 2000), (16, 2000), (64, 2000),
(1, 4000), (16, 4000), (64, 4000)]