Trainer implementation
* Add Deconv2d * Fix BatchGenerator save option when using current directory
This commit is contained in:
parent
144ff4a004
commit
f2282e3216
3 changed files with 263 additions and 7 deletions
29
layers.py
29
layers.py
|
|
@ -45,6 +45,20 @@ class Layer(nn.Module):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(Layer):
    """Fully connected layer with the shared activation / batch-norm post-processing.

    Extra keyword arguments are forwarded verbatim to ``nn.Linear``.
    """

    def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
        super().__init__(activation, batch_norm)

        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
        # NOTE(review): unlike the conv layers, bias is kept even when batch norm
        # follows (where it is redundant) — confirm this asymmetry is intended.
        if self.batch_norm:
            self.batch_norm = nn.BatchNorm1d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=Layer.BATCH_NORM_TRAINING)
        else:
            self.batch_norm = None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Apply the linear map, then the base-class post-processing."""
        return super().forward(self.fc(input_data))
||||||
class Conv1d(Layer):
|
class Conv1d(Layer):
|
||||||
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
|
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
|
||||||
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
|
stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
|
||||||
|
|
@ -93,15 +107,18 @@ class Conv3d(Layer):
|
||||||
return super().forward(self.conv(input_data))
|
return super().forward(self.conv(input_data))
|
||||||
|
|
||||||
|
|
||||||
class Deconv2d(Layer):
    """2D transposed-convolution layer with the shared activation / batch-norm
    post-processing.

    Extra keyword arguments are forwarded verbatim to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                 stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
        super().__init__(activation, batch_norm)

        # Bias is redundant when a batch norm immediately follows the deconvolution.
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                         stride=stride, bias=not self.batch_norm, **kwargs)
        if self.batch_norm:
            # NOTE(review): `not Layer.BATCH_NORM_TRAINING` differs from Linear,
            # which passes Layer.BATCH_NORM_TRAINING directly — confirm intended.
            self.batch_norm = nn.BatchNorm2d(out_channels,
                                             momentum=Layer.BATCH_NORM_MOMENTUM,
                                             track_running_stats=not Layer.BATCH_NORM_TRAINING)
        else:
            self.batch_norm = None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution, then the base-class post-processing."""
        return super().forward(self.deconv(input_data))
|
|
|
||||||
239
trainer.py
Normal file
239
trainer.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from .train import parameter_summary, resource_usage
|
||||||
|
from .utils.logger import DummyLogger
|
||||||
|
from .utils.batch_generator import BatchGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
    """Training-loop driver.

    Runs epochs over a train/validation ``BatchGenerator`` pair, logs scalar and
    image summaries to TensorBoard, checkpoints the model, and backs up the run's
    source files. Subclasses may override the ``*_callback`` hooks to add custom
    metrics or images.
    """

    def __init__(
            self,
            batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
            device: str, output_dir: str,
            network: nn.Module, pre_process: nn.Module,
            optimizer: torch.optim.Optimizer, criterion: nn.Module,
            epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
            data_dtype=None, label_dtype=None,
            train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
            logger=DummyLogger):
        """Set up writers, trace the network graph, and snapshot the run.

        Args:
            batch_generator_train / batch_generator_val: data sources for each phase.
            device: torch device string the tensors are created on.
            output_dir: directory receiving event files, checkpoints and code copies.
            network: model to train. pre_process: applied to inputs before the network.
            optimizer / criterion: standard torch training objects.
            epoch_skip: number of initial epochs without any summaries.
            summary_per_epoch / image_per_epoch: how many scalar/image summaries per epoch.
            data_dtype / label_dtype: dtypes passed to ``torch.as_tensor`` (None keeps source dtype).
            train_pre_process: optional training-only pre-processing (defaults to ``pre_process``).
            data_is_label: when True, the loss target is the pre-processed input (autoencoder-style).
            logger: object with an ``info`` method; defaults to a no-op logger.
        """
        super().__init__()
        self.device = device
        self.output_dir = output_dir
        self.data_is_label = data_is_label
        self.logger = logger

        self.batch_generator_train = batch_generator_train
        self.batch_generator_val = batch_generator_val

        self.pre_process = pre_process
        # Fall back to the shared pre-processing when no training-specific one is given.
        self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
        self.data_dtype = data_dtype
        self.label_dtype = label_dtype

        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)

        # Save network graph, traced on a 2-sample batch.
        batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
        batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
        processed_inputs = pre_process(batch_inputs)
        self.writer_train.add_graph(network, (processed_inputs,))
        self.writer_val.add_graph(network, (processed_inputs,))

        self._write_parameter_summary()

        # Calculate summary periods in steps; clamp to 1 so a low step_per_epoch
        # never yields a zero period (and a later modulo-by-zero).
        self.epoch_skip = epoch_skip
        self.summary_period = max(1, batch_generator_train.step_per_epoch // summary_per_epoch)
        self.image_period = max(1, batch_generator_train.step_per_epoch // image_per_epoch)
        # Avoid different periods between image and scalar summaries (due to odd number division).
        if summary_per_epoch % image_per_epoch == 0:
            self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)

        torch.save(network, os.path.join(output_dir, 'model_init.pt'))
        self._backup_source_files()

        # Initialize training loop variables.
        self.batch_inputs = batch_inputs
        self.batch_labels = batch_labels
        self.processed_inputs = processed_inputs
        self.network_outputs = processed_inputs  # Placeholder until the first step.
        self.train_loss = 0.0
        self.running_loss = 0.0
        self.running_count = 0
        self.benchmark_step = 0
        self.benchmark_time = time.time()

    def _write_parameter_summary(self):
        """Dump an aligned name/shape/size table of network parameters to parameters.csv."""
        with open(os.path.join(self.output_dir, 'parameters.csv'), 'w') as param_file:
            param_summary = parameter_summary(self.network)
            # default=0 keeps max() from raising on a parameter-less network.
            name_width = max((len(name) for name, _, _ in param_summary), default=0)
            shape_width = max((len(str(shape)) for _, shape, _ in param_summary), default=0)
            param_file.write(
                '\n'.join(
                    [f'{name: <{name_width}} {str(shape): <{shape_width}} {size}'
                     for name, shape, size in param_summary]))

    def _backup_source_files(self):
        """Copy config/ and src/ python sources under output_dir/code for reproducibility."""
        for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
                os.path.join('src', '**', '*.py'), recursive=True):
            dirname = os.path.join(self.output_dir, 'code', os.path.dirname(entry))
            os.makedirs(dirname, exist_ok=True)
            shutil.copy2(entry, os.path.join(self.output_dir, 'code', entry))

    def train_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float):
        """Hook called after every training step. Override to add custom behavior."""
        pass

    def val_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float):
        """Hook called after every validation step. Override to add custom behavior."""
        pass

    def summary_callback(
            self,
            train_inputs: torch.Tensor, train_processed: torch.Tensor,
            train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
            val_inputs: torch.Tensor, val_processed: torch.Tensor,
            val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
        """Hook called when scalar summaries are written. Override to add scalars."""
        pass

    def image_callback(
            self,
            train_inputs: torch.Tensor, train_processed: torch.Tensor,
            train_labels: torch.Tensor, train_outputs: torch.Tensor,
            val_inputs: torch.Tensor, val_processed: torch.Tensor,
            val_labels: torch.Tensor, val_outputs: torch.Tensor):
        """Hook called when image summaries are due. Override to add images."""
        pass

    def fit(self, epochs: int) -> None:
        """Train for `epochs` epochs; Ctrl-C interrupts cleanly and still saves the model."""
        train_start_time = time.time()
        try:
            while self.batch_generator_train.epoch < epochs:
                epoch = self.batch_generator_train.epoch
                print()
                # Clear the in-place speed line from the previous epoch.
                # shutil.get_terminal_size has a fallback when stdout is not a tty,
                # unlike os.get_terminal_size which raises OSError.
                print(' ' * shutil.get_terminal_size()[0], end='\r')
                print(f'Epoch {self.batch_generator_train.epoch}')
                while epoch == self.batch_generator_train.epoch:
                    self.batch_inputs = torch.as_tensor(
                        self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                    self.batch_labels = torch.as_tensor(
                        self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)

                    if self.benchmark_step > 1:
                        speed = self.benchmark_step / (time.time() - self.benchmark_time)
                        print(
                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
                            f', {speed * self.batch_generator_train.batch_size:0.02f} input/sec', end='\r')

                    # Faster zero grad than optimizer.zero_grad().
                    for param in self.network.parameters():
                        param.grad = None

                    self.processed_inputs = self.train_pre_process(self.batch_inputs)
                    self.network_outputs = self.network(self.processed_inputs)
                    loss = self.criterion(
                        self.network_outputs,
                        self.batch_labels if not self.data_is_label else self.processed_inputs)
                    loss.backward()
                    self.optimizer.step()

                    self.train_loss = loss.item()
                    self.running_loss += self.train_loss
                    self.running_count += len(self.batch_generator_train.batch_data)
                    self.train_step_callback(
                        self.batch_inputs, self.processed_inputs, self.batch_labels,
                        self.network_outputs, self.train_loss)

                    self.benchmark_step += 1
                    self.save_summaries()
                    self.batch_generator_train.next_batch()
        except KeyboardInterrupt:
            # Manual interruption is expected; the final model is still saved below.
            print()

        train_stop_time = time.time()
        self.writer_train.close()
        # Close the validation writer too (it was previously leaked).
        self.writer_val.close()
        torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))

        memory_peak, gpu_memory = resource_usage()
        self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
            train_stop_time - train_start_time,
            memory_peak // 1024,
            gpu_memory))

    def save_summaries(self):
        """Run a validation epoch and write scalar/image summaries when a period is due."""
        global_step = self.batch_generator_train.global_step
        if self.batch_generator_train.epoch < self.epoch_skip:
            return
        if self.batch_generator_train.step % self.summary_period != self.summary_period - 1 and (
                self.batch_generator_train.step % self.image_period != self.image_period - 1):
            return

        # Make sure the validation pass starts at the beginning of an epoch.
        if self.batch_generator_val.step != 0:
            self.batch_generator_val.skip_epoch()
        val_loss = 0.0
        val_count = 0
        self.network.train(False)
        with torch.no_grad():
            val_epoch = self.batch_generator_val.epoch
            while val_epoch == self.batch_generator_val.epoch:
                val_inputs = torch.as_tensor(
                    self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
                # Bug fix: labels were previously read from batch_data, so the
                # validation loss targeted the inputs even when data_is_label is False.
                val_labels = torch.as_tensor(
                    self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device)

                val_pre_process = self.pre_process(val_inputs)
                val_outputs = self.network(val_pre_process)
                loss = self.criterion(
                    val_outputs,
                    val_labels if not self.data_is_label else val_pre_process).item()
                val_loss += loss
                val_count += len(self.batch_generator_val.batch_data)
                self.val_step_callback(
                    val_inputs, val_pre_process, val_labels, val_outputs, loss)

                self.batch_generator_val.next_batch()
        self.network.train(True)

        # Add summaries
        if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
            self.writer_train.add_scalar(
                'loss', self.running_loss / self.running_count, global_step=global_step)
            self.writer_val.add_scalar(
                'loss', val_loss / val_count, global_step=global_step)
            self.summary_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
                val_inputs, val_pre_process, val_labels, val_outputs, val_count)

        # Add image
        if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
            self.image_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
                val_inputs, val_pre_process, val_labels, val_outputs)

        speed = self.benchmark_step / (time.time() - self.benchmark_time)
        print(f'Step {global_step}, '
              f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
              f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
        torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
        self.benchmark_time = time.time()
        self.benchmark_step = 0
        self.running_loss = 0.0
        self.running_count = 0
|
||||||
|
|
@ -31,7 +31,7 @@ class BatchGenerator:
|
||||||
if save is not None:
|
if save is not None:
|
||||||
if '.' not in os.path.basename(save_path):
|
if '.' not in os.path.basename(save_path):
|
||||||
save_path += '.hdf5'
|
save_path += '.hdf5'
|
||||||
if not os.path.exists(os.path.dirname(save_path)):
|
if os.path.dirname(save_path) and not os.path.exists(os.path.dirname(save_path)):
|
||||||
os.makedirs(os.path.dirname(save_path))
|
os.makedirs(os.path.dirname(save_path))
|
||||||
|
|
||||||
if save and os.path.exists(save_path):
|
if save and os.path.exists(save_path):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue