Residual blocks, precache for BatchGenerator

parent 7f4a162033
commit 5081cf63fe

3 changed files with 139 additions and 4 deletions
@@ -50,7 +50,8 @@ class Conv2d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
 
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
+                              bias=not self.batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
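Note on the Conv2d change: dropping the convolution bias when a batch-norm layer follows is safe, because BatchNorm2d subtracts the per-channel mean (and adds its own learnable shift), so any constant per-channel offset cancels out. A minimal standalone sketch of that equivalence, not part of this commit:

import torch
import torch.nn as nn

conv_bias = nn.Conv2d(3, 8, kernel_size=3, bias=True)
conv_nobias = nn.Conv2d(3, 8, kernel_size=3, bias=False)
with torch.no_grad():
    conv_nobias.weight.copy_(conv_bias.weight)  # same weights, only the bias differs
bn = nn.BatchNorm2d(8).train()  # train mode: normalize with batch statistics

x = torch.randn(4, 3, 16, 16)
out_bias = bn(conv_bias(x))
out_nobias = bn(conv_nobias(x))
print(torch.allclose(out_bias, out_nobias, atol=1e-5))  # True: the bias cancels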
residual.py (new file, 67 additions)
@@ -0,0 +1,67 @@
+from typing import Union, Tuple
+
+import torch
+import torch.nn as nn
+
+from .layers import LayerInfo, Layer
+
+
+class ResBlock(Layer):
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
+                 activation=None, **kwargs):
+        super().__init__(activation if activation is not None else 0, False)
+
+        self.seq = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs),
+            nn.BatchNorm2d(
+                out_channels,
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING),
+            torch.nn.LeakyReLU(),
+            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False, padding=1),
+            nn.BatchNorm2d(
+                out_channels,
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING))
+        self.batch_norm = nn.BatchNorm2d(
+            out_channels,
+            momentum=Layer.BATCH_NORM_MOMENTUM,
+            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        return super().forward(input_data + self.seq(input_data))
+
+
+class ResBottleneck(Layer):
+    def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3,
+                 stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs):
+        super().__init__(activation if activation is not None else 0, False)
+        self.batch_norm = None
+
+        self.seq = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(
+                out_channels,
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING),
+            torch.nn.LeakyReLU(),
+            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs),
+            nn.BatchNorm2d(
+                out_channels,
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING),
+            torch.nn.LeakyReLU(),
+            nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(
+                planes * out_channels,  # must match the preceding conv's output channels
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING))
+        self.downsample = nn.Sequential(
+            nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
+            nn.BatchNorm2d(
+                planes * out_channels,
+                momentum=Layer.BATCH_NORM_MOMENTUM,
+                track_running_stats=not Layer.BATCH_NORM_TRAINING))
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        return super().forward(self.downsample(input_data) + self.seq(input_data))
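For orientation, a hedged usage sketch of the two new blocks (not from the repo). It assumes Layer.forward applies the configured activation, and that the caller passes padding=1 through **kwargs so the 3x3 convolutions preserve spatial size; without that, the residual addition in ResBlock.forward would fail on a shape mismatch. Note also that the last nn.BatchNorm2d in ResBottleneck.seq must take planes * out_channels features to match the preceding 1x1 convolution (fixed above).

import torch
from residual import ResBlock, ResBottleneck  # import path is illustrative

x = torch.randn(2, 64, 32, 32)

# Identity-shaped block: in_channels == out_channels plus padding=1 keep the
# tensor shape, so input_data + self.seq(input_data) lines up.
block = ResBlock(64, 64, kernel_size=3, padding=1)
print(block(x).shape)       # expected: torch.Size([2, 64, 32, 32])

# Bottleneck: 1x1 reduce -> 3x3 strided -> 1x1 expand to planes * out_channels;
# the downsample branch matches that shape with a strided 1x1 convolution.
bottleneck = ResBottleneck(64, 128, planes=2, kernel_size=3, stride=2, padding=1)
print(bottleneck(x).shape)  # expected: torch.Size([2, 256, 16, 16])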
@@ -1,4 +1,6 @@
 import math
+import multiprocessing as mp
+from multiprocessing import shared_memory
 import os
 from typing import Optional, Tuple
 
@@ -8,11 +10,12 @@ import numpy as np
 
 class BatchGenerator:
 
-    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
+    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None, precache=True,
                  shuffle=True, preload=False, save=None, left_right_flip=False):
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.left_right_flip = left_right_flip
+        self.precache = precache and not preload
 
         if not preload:
             self.data_processor = data_processor
@@ -60,6 +63,64 @@ class BatchGenerator:
         if shuffle:
             np.random.shuffle(self.index_list)
 
+        if self.precache:
+            data_sample = np.array([data_processor(entry) if data_processor else entry
+                                    for entry in self.data[:batch_size]])
+            label_sample = np.array([label_processor(entry) if label_processor else entry
+                                     for entry in self.label[:batch_size]])
+            self.cache_memory_data = [
+                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
+                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
+            self.cache_data = [
+                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
+            self.cache_memory_label = [
+                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
+                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
+            self.cache_label = [
+                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
+                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
+            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
+            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.cache_stop.buf[0] = 0
+            self.cache_process = mp.Process(target=self.cache_worker)
+            self.cache_process.start()
+
+    def __del__(self):
+        if self.precache:
+            self.cache_stop.buf[0] = 1
+            self.cache_pipe_parent.send(True)
+            self.cache_process.join()
+
+            self.cache_stop.close()
+            self.cache_stop.unlink()
+            self.cache_memory_data[0].close()
+            self.cache_memory_data[0].unlink()
+            self.cache_memory_data[1].close()
+            self.cache_memory_data[1].unlink()
+            self.cache_memory_label[0].close()
+            self.cache_memory_label[0].unlink()
+            self.cache_memory_label[1].close()
+            self.cache_memory_label[1].unlink()
+
+    def cache_worker(self):
+        self.precache = False
+        self.next_batch()
+        self.cache_data[0][:] = self.batch_data[:]
+        self.cache_label[0][:] = self.batch_label[:]
+        current_cache = 0
+
+        while not self.cache_stop.buf[0]:
+            try:
+                self.cache_pipe_child.recv()
+                self.cache_pipe_child.send(current_cache)
+                self.next_batch()
+                current_cache = 1 - current_cache
+                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
+                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
+            except KeyboardInterrupt:
+                break
+
     def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.step >= self.step_per_epoch - 1:  # step starts at 0
             self.step = 0
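How the precache path works: __init__ runs one sample batch through the processors to size two shared-memory buffers per stream, then starts a worker process. Inside cache_worker the child first sets self.precache = False on its own copy of the object, so its next_batch calls take the ordinary loading path; it prefills buffer 0, and on every request from the parent it replies with the index of the ready buffer before loading the next batch into the other one. The parent (see the next_batch hunks below) then copies the ready buffer out of shared memory, so data loading overlaps with training. A minimal standalone sketch of that double-buffer handshake, with illustrative names (produce_batch, SHAPE) that are not from the repo:

import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np

SHAPE, DTYPE = (8, 4), np.float32  # batch geometry for the sketch

def produce_batch():
    return np.random.random(SHAPE).astype(DTYPE)  # stand-in for real loading

def worker(shm_names, pipe):
    # Re-attach to the two shared buffers and double-buffer into them.
    shms = [shared_memory.SharedMemory(name=name) for name in shm_names]
    buffers = [np.ndarray(SHAPE, dtype=DTYPE, buffer=shm.buf) for shm in shms]
    current = 0
    buffers[current][:] = produce_batch()       # prefill buffer 0
    while pipe.recv():                          # False means "stop"
        pipe.send(current)                      # hand the ready buffer over
        current = 1 - current
        buffers[current][:] = produce_batch()   # load ahead into the other one

if __name__ == "__main__":
    shms = [shared_memory.SharedMemory(create=True, size=8 * 4 * 4)
            for _ in range(2)]
    views = [np.ndarray(SHAPE, dtype=DTYPE, buffer=shm.buf) for shm in shms]
    parent, child = mp.Pipe()
    proc = mp.Process(target=worker, args=([shm.name for shm in shms], child))
    proc.start()
    for _ in range(3):
        parent.send(True)            # request a batch
        ready = parent.recv()        # index of the buffer that is safe to read
        batch = views[ready].copy()  # copy out before the worker reuses it
    parent.send(False)               # shut the worker down
    proc.join()
    for shm in shms:
        shm.close()
        shm.unlink()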
@@ -71,7 +132,11 @@ class BatchGenerator:
 
         self.global_step += 1
         # Loading data
-        if self.data_processor is not None:
+        if self.precache:
+            self.cache_pipe_parent.send(True)
+            current_cache = self.cache_pipe_parent.recv()
+            self.batch_data = self.cache_data[current_cache].copy()
+        elif self.data_processor is not None:
             self.batch_data = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 self.batch_data.append(self.data_processor(self.data[entry]))
@@ -80,7 +145,9 @@ class BatchGenerator:
             self.batch_data = self.data[
                 self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
         # Loading label
-        if self.label_processor is not None:
+        if self.precache:
+            self.batch_label = self.cache_label[current_cache].copy()
+        elif self.label_processor is not None:
             self.batch_label = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 self.batch_label.append(self.label_processor(self.label[entry]))
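Finally, a hedged sketch of how the new flag might be used, assuming next_batch returns (batch_data, batch_label) as its annotation indicates and a fork-based start method (the lambda processors below would not pickle under spawn). With precache=True, the processors run in the worker process, so next_batch mostly pays for a copy out of shared memory:

import numpy as np

# BatchGenerator imported from the repo's generator module (path not shown in this diff)
data = np.random.random((1024, 32, 32, 3)).astype(np.float32)
labels = np.random.randint(0, 10, size=1024)

gen = BatchGenerator(data, labels, batch_size=32,
                     data_processor=lambda d: d * 2.0,  # illustrative processors
                     label_processor=lambda l: l,
                     precache=True)
for _ in range(100):
    batch_data, batch_label = gen.next_batch()  # served from the shared cache
del gen  # __del__ stops the worker and unlinks the shared memory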