Trainer implementation

* Add Deconv2d
* Fix BatchGenerator save option when using current directory
Corentin 2020-12-25 15:50:38 +09:00
commit f2282e3216
3 changed files with 263 additions and 7 deletions

239
trainer.py Normal file

@@ -0,0 +1,239 @@
import glob
import os
import shutil
import time
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from .train import parameter_summary, resource_usage
from .utils.logger import DummyLogger
from .utils.batch_generator import BatchGenerator


class Trainer:
def __init__(
self,
batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
device: str, output_dir: str,
network: nn.Module, pre_process: nn.Module,
optimizer: torch.optim.Optimizer, criterion: nn.Module,
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
data_dtype=None, label_dtype=None,
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
logger=DummyLogger):
super().__init__()
self.device = device
self.output_dir = output_dir
self.data_is_label = data_is_label
self.logger = logger
self.batch_generator_train = batch_generator_train
self.batch_generator_val = batch_generator_val
self.pre_process = pre_process
self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
self.data_dtype = data_dtype
self.label_dtype = label_dtype
self.network = network
self.optimizer = optimizer
self.criterion = criterion
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)

        # Save network graph
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
processed_inputs = pre_process(batch_inputs)
self.writer_train.add_graph(network, (processed_inputs,))
self.writer_val.add_graph(network, (processed_inputs,))

        # Save parameters info
with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
param_file.write(
'\n'.join(
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))

        # Calculate summary periods
self.epoch_skip = epoch_skip
self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
if self.summary_period == 0:
self.summary_period = 1
self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
if self.image_period == 0:
self.image_period = 1
        # Keep the image period a multiple of the summary period when possible
        # (the integer divisions above can otherwise leave the two out of sync)
if summary_per_epoch % image_per_epoch == 0:
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
torch.save(network, os.path.join(output_dir, 'model_init.pt'))

        # Save source files
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
if not os.path.exists(dirname):
os.makedirs(dirname)
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))

        # Initialize training loop variables
self.batch_inputs = batch_inputs
self.batch_labels = batch_labels
self.processed_inputs = processed_inputs
self.network_outputs = processed_inputs # Placeholder
self.train_loss = 0.0
self.running_loss = 0.0
self.running_count = 0
self.benchmark_step = 0
        self.benchmark_time = time.time()

    def train_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
        pass

    def val_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
        pass

    def summary_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
        pass

    def image_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor):
        pass

    def fit(self, epochs: int) -> None:
train_start_time = time.time()
try:
while self.batch_generator_train.epoch < epochs:
epoch = self.batch_generator_train.epoch
print()
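                # Blank out the in-place progress line before printing the epoch header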
print(' ' * os.get_terminal_size()[0], end='\r')
print(f'Epoch {self.batch_generator_train.epoch}')
while epoch == self.batch_generator_train.epoch:
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
if self.benchmark_step > 1:
speed = self.benchmark_step / (time.time() - self.benchmark_time)
                        print(
                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
                            f', {speed * self.batch_generator_train.batch_size:0.02f} inputs/s', end='\r')
# Faster zero grad
for param in self.network.parameters():
param.grad = None
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
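                    # When data_is_label is set, score outputs against the pre-processed inputs instead of the labels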
loss = self.criterion(
self.network_outputs,
self.batch_labels if not self.data_is_label else self.processed_inputs)
loss.backward()
self.optimizer.step()
self.train_loss = loss.item()
self.running_loss += self.train_loss
self.running_count += len(self.batch_generator_train.batch_data)
self.train_step_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels,
self.network_outputs, self.train_loss)
self.benchmark_step += 1
self.save_summaries()
self.batch_generator_train.next_batch()
except KeyboardInterrupt:
print()
train_stop_time = time.time()
self.writer_train.close()
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
memory_peak, gpu_memory = resource_usage()
self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
train_stop_time - train_start_time,
memory_peak // 1024,
            gpu_memory))

    def save_summaries(self):
global_step = self.batch_generator_train.global_step
if self.batch_generator_train.epoch < self.epoch_skip:
return
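        # Only continue when this step closes a summary period or an image period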
if self.batch_generator_train.step % self.summary_period != self.summary_period - 1 and (
self.batch_generator_train.step % self.image_period != self.image_period - 1):
return
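        # Rewind the validation generator if a previous pass left it mid-epoch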
if self.batch_generator_val.step != 0:
self.batch_generator_val.skip_epoch()
val_loss = 0.0
val_count = 0
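        # Run one full validation epoch in eval mode, without gradient tracking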
self.network.train(False)
with torch.no_grad():
val_epoch = self.batch_generator_val.epoch
while val_epoch == self.batch_generator_val.epoch:
val_inputs = torch.as_tensor(
self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
                val_labels = torch.as_tensor(
                    self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device)
                val_processed = self.pre_process(val_inputs)
                val_outputs = self.network(val_processed)
loss = self.criterion(
val_outputs,
                    val_labels if not self.data_is_label else val_processed).item()
val_loss += loss
val_count += len(self.batch_generator_val.batch_data)
self.val_step_callback(
                    val_inputs, val_processed, val_labels, val_outputs, loss)
self.batch_generator_val.next_batch()
self.network.train(True)

        # Add summaries
if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
self.writer_train.add_scalar(
'loss', self.running_loss / self.running_count, global_step=global_step)
self.writer_val.add_scalar(
'loss', val_loss / val_count, global_step=global_step)
self.summary_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
                val_inputs, val_processed, val_labels, val_outputs, val_count)

        # Add image
if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
self.image_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
                val_inputs, val_processed, val_labels, val_outputs)
speed = self.benchmark_step / (time.time() - self.benchmark_time)
        print(f'Step {global_step}, '
              f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
              f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} inputs/s')
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0
self.running_count = 0
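
Usage note (not part of the diff): a minimal sketch of how this Trainer could be driven. The RandomBatchGenerator stub below is hypothetical; it only mimics the attributes and methods Trainer actually touches (batch_data, batch_label, epoch, step, global_step, step_per_epoch, batch_size, next_batch, skip_epoch). The real BatchGenerator from src.utils.batch_generator has its own constructor, and the import path, network, dtypes, and output directory are placeholder choices.

import numpy as np
import torch
import torch.nn as nn

from src.trainer import Trainer  # import path assumed from the repository layout


class RandomBatchGenerator:
    """Hypothetical stand-in exposing only the interface Trainer relies on."""

    def __init__(self, batch_size: int = 8, step_per_epoch: int = 10):
        self.batch_size = batch_size
        self.step_per_epoch = step_per_epoch
        self.epoch, self.step, self.global_step = 0, 0, 0
        self._load()

    def _load(self):
        # Random images and integer class labels, stand-ins for real data
        self.batch_data = np.random.rand(self.batch_size, 1, 28, 28).astype(np.float32)
        self.batch_label = np.random.randint(0, 10, size=self.batch_size)

    def next_batch(self):
        self.global_step += 1
        self.step = (self.step + 1) % self.step_per_epoch
        if self.step == 0:
            self.epoch += 1
        self._load()

    def skip_epoch(self):
        # Jump to the start of the next epoch, as save_summaries expects
        self.step = 0
        self.epoch += 1
        self._load()


device = 'cuda' if torch.cuda.is_available() else 'cpu'
network = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
trainer = Trainer(
    batch_generator_train=RandomBatchGenerator(), batch_generator_val=RandomBatchGenerator(),
    device=device, output_dir='output/example_run',
    network=network, pre_process=nn.Identity(),
    optimizer=torch.optim.Adam(network.parameters(), lr=1e-3), criterion=nn.CrossEntropyLoss(),
    epoch_skip=0, summary_per_epoch=4, image_per_epoch=2,
    data_dtype=torch.float32, label_dtype=torch.int64)
trainer.fit(epochs=2)

The stub's skip_epoch mirrors how save_summaries uses it: it only needs to leave the generator at step 0 of a fresh epoch before the validation pass begins.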