HDF5 save and Conv2d stride typing change
This commit is contained in:
parent
95bd1850b5
commit
58310ba077
1 changed files with 4 additions and 2 deletions
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -44,8 +46,8 @@ class Layer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Conv2d(Layer):
|
class Conv2d(Layer):
|
||||||
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
|
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
|
||||||
activation=0, batch_norm=None, **kwargs):
|
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
|
||||||
super().__init__(activation, batch_norm)
|
super().__init__(activation, batch_norm)
|
||||||
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue