Trainer implementation

* Add Deconv2d
* Fix BatchGenerator save option when using current directory
This commit is contained in:
Corentin 2020-12-25 15:50:38 +09:00
commit f2282e3216
3 changed files with 263 additions and 7 deletions

View file

@ -45,6 +45,20 @@ class Layer(nn.Module):
return output return output
class Linear(Layer):
def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
super().__init__(activation, batch_norm)
self.fc = nn.Linear(in_channels, out_channels, **kwargs)
self.batch_norm = nn.BatchNorm1d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.fc(input_data))
class Conv1d(Layer): class Conv1d(Layer):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
@ -93,15 +107,18 @@ class Conv3d(Layer):
return super().forward(self.conv(input_data)) return super().forward(self.conv(input_data))
class Linear(Layer): class Deconv2d(Layer):
def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
super().__init__(activation, batch_norm) super().__init__(activation, batch_norm)
self.fc = nn.Linear(in_channels, out_channels, **kwargs) self.deconv = nn.ConvTranspose2d(
self.batch_norm = nn.BatchNorm1d( in_channels, out_channels, kernel_size, stride=stride,
bias=not self.batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm2d(
out_channels, out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM, momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.fc(input_data)) return super().forward(self.deconv(input_data))

239
trainer.py Normal file
View file

@ -0,0 +1,239 @@
import glob
import os
import shutil
import time
from typing import Callable, Optional
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from .train import parameter_summary, resource_usage
from .utils.logger import DummyLogger
from .utils.batch_generator import BatchGenerator
class Trainer:
def __init__(
self,
batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
device: str, output_dir: str,
network: nn.Module, pre_process: nn.Module,
optimizer: torch.optim.Optimizer, criterion: nn.Module,
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
data_dtype=None, label_dtype=None,
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
logger=DummyLogger):
super().__init__()
self.device = device
self.output_dir = output_dir
self.data_is_label = data_is_label
self.logger = logger
self.batch_generator_train = batch_generator_train
self.batch_generator_val = batch_generator_val
self.pre_process = pre_process
self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
self.data_dtype = data_dtype
self.label_dtype = label_dtype
self.network = network
self.optimizer = optimizer
self.criterion = criterion
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
# Save network graph
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
processed_inputs = pre_process(batch_inputs)
self.writer_train.add_graph(network, (processed_inputs,))
self.writer_val.add_graph(network, (processed_inputs,))
# Save parameters info
with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
param_file.write(
'\n'.join(
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
# Calculate summary periods
self.epoch_skip = epoch_skip
self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
if self.summary_period == 0:
self.summary_period = 1
self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
if self.image_period == 0:
self.image_period = 1
# Avoid different period between image and summaries (due to odd number division)
if summary_per_epoch % image_per_epoch == 0:
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
# Save source files
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
if not os.path.exists(dirname):
os.makedirs(dirname)
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
# Initialize training loop variables
self.batch_inputs = batch_inputs
self.batch_labels = batch_labels
self.processed_inputs = processed_inputs
self.network_outputs = processed_inputs # Placeholder
self.train_loss = 0.0
self.running_loss = 0.0
self.running_count = 0
self.benchmark_step = 0
self.benchmark_time = time.time()
def train_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
pass
def val_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
pass
def summary_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
pass
def image_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor):
pass
def fit(self, epochs: int) -> torch.Tensor:
train_start_time = time.time()
try:
while self.batch_generator_train.epoch < epochs:
epoch = self.batch_generator_train.epoch
print()
print(' ' * os.get_terminal_size()[0], end='\r')
print(f'Epoch {self.batch_generator_train.epoch}')
while epoch == self.batch_generator_train.epoch:
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
if self.benchmark_step > 1:
speed = self.benchmark_step / (time.time() - self.benchmark_time)
print(
f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
f', {speed * self.batch_generator_train.batch_size:0.02f} input/sec', end='\r')
# Faster zero grad
for param in self.network.parameters():
param.grad = None
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
loss = self.criterion(
self.network_outputs,
self.batch_labels if not self.data_is_label else self.processed_inputs)
loss.backward()
self.optimizer.step()
self.train_loss = loss.item()
self.running_loss += self.train_loss
self.running_count += len(self.batch_generator_train.batch_data)
self.train_step_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels,
self.network_outputs, self.train_loss)
self.benchmark_step += 1
self.save_summaries()
self.batch_generator_train.next_batch()
except KeyboardInterrupt:
print()
train_stop_time = time.time()
self.writer_train.close()
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
memory_peak, gpu_memory = resource_usage()
self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
train_stop_time - train_start_time,
memory_peak // 1024,
gpu_memory))
def save_summaries(self):
global_step = self.batch_generator_train.global_step
if self.batch_generator_train.epoch < self.epoch_skip:
return
if self.batch_generator_train.step % self.summary_period != self.summary_period - 1 and (
self.batch_generator_train.step % self.image_period != self.image_period - 1):
return
if self.batch_generator_val.step != 0:
self.batch_generator_val.skip_epoch()
val_loss = 0.0
val_count = 0
self.network.train(False)
with torch.no_grad():
val_epoch = self.batch_generator_val.epoch
while val_epoch == self.batch_generator_val.epoch:
val_inputs = torch.as_tensor(
self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
val_labels = torch.as_tensor(
self.batch_generator_val.batch_data, dtype=self.label_dtype, device=self.device)
val_pre_process = self.pre_process(val_inputs)
val_outputs = self.network(val_pre_process)
loss = self.criterion(
val_outputs,
val_labels if not self.data_is_label else val_pre_process).item()
val_loss += loss
val_count += len(self.batch_generator_val.batch_data)
self.val_step_callback(
val_inputs, val_pre_process, val_labels, val_outputs, loss)
self.batch_generator_val.next_batch()
self.network.train(True)
# Add summaries
if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
self.writer_train.add_scalar(
'loss', self.running_loss / self.running_count, global_step=global_step)
self.writer_val.add_scalar(
'loss', val_loss / val_count, global_step=global_step)
self.summary_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
val_inputs, val_pre_process, val_labels, val_outputs, val_count)
# Add image
if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
self.image_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
val_inputs, val_pre_process, val_labels, val_outputs)
speed = self.benchmark_step / (time.time() - self.benchmark_time)
print(f'Step {global_step}, '
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0
self.running_count = 0

View file

@ -31,7 +31,7 @@ class BatchGenerator:
if save is not None: if save is not None:
if '.' not in os.path.basename(save_path): if '.' not in os.path.basename(save_path):
save_path += '.hdf5' save_path += '.hdf5'
if not os.path.exists(os.path.dirname(save_path)): if os.path.dirname(save_path) and not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path)) os.makedirs(os.path.dirname(save_path))
if save and os.path.exists(save_path): if save and os.path.exists(save_path):