From 58310ba0773bc10ad1163f89114e29259ce2c619 Mon Sep 17 00:00:00 2001 From: Corentin Risselin Date: Sat, 4 Jul 2020 13:05:07 +0900 Subject: [PATCH] HDF5 save and Conv2d stride typing change --- layers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/layers.py b/layers.py index 4648b70..d698dc5 100644 --- a/layers.py +++ b/layers.py @@ -1,3 +1,5 @@ +from typing import Union, Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -44,8 +46,8 @@ class Layer(nn.Module): class Conv2d(Layer): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, - activation=0, batch_norm=None, **kwargs): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, + stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): super().__init__(activation, batch_norm) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)