diff --git a/layers.py b/layers.py index 4648b70..d698dc5 100644 --- a/layers.py +++ b/layers.py @@ -1,3 +1,5 @@ +from typing import Union, Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -44,8 +46,8 @@ class Layer(nn.Module): class Conv2d(Layer): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, - activation=0, batch_norm=None, **kwargs): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, + stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): super().__init__(activation, batch_norm) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)