HDF5 save and Conv2d stride typing change
This commit is contained in:
parent
95bd1850b5
commit
58310ba077
1 changed files with 4 additions and 2 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -44,8 +46,8 @@ class Layer(nn.Module):
|
|||
|
||||
|
||||
class Conv2d(Layer):
|
||||
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
|
||||
activation=0, batch_norm=None, **kwargs):
|
||||
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
|
||||
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
|
||||
super().__init__(activation, batch_norm)
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue