torch_utils/trainer.py
Corentin 86787f6517 Pipeline implementation for BatchGenerator
* Add end_epoch_callback in Trainer
* Fix Layer.ACTIVATION in cas of nn.Module
2021-03-18 21:30:59 +09:00

291 lines
14 KiB
Python

import glob
import os
import shutil
import time
from typing import Callable, Optional
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from .train import parameter_summary, resource_usage
from .utils.logger import DummyLogger
from .utils.batch_generator import BatchGenerator
class Trainer:
def __init__(
self,
batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
device: str, output_dir: str,
network: nn.Module, pre_process: nn.Module,
optimizer: torch.optim.Optimizer, criterion: nn.Module,
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
data_dtype=None, label_dtype=None,
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
logger=DummyLogger(), verbose=True, save_src=True):
super().__init__()
self.device = device
self.output_dir = output_dir
self.data_is_label = data_is_label
self.logger = logger
self.verbose = verbose
self.should_stop = False
self.batch_generator_train = batch_generator_train
self.batch_generator_val = batch_generator_val
self.pre_process = pre_process
self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
self.data_dtype = data_dtype
self.label_dtype = label_dtype
self.network = network
self.optimizer = optimizer
self.criterion = criterion
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
# Save network graph
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
processed_inputs = pre_process(batch_inputs)
self.writer_train.add_graph(network, (processed_inputs,))
self.writer_val.add_graph(network, (processed_inputs,))
# Save parameters info
with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
param_file.write(
'\n'.join(
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
# Calculate summary periods
self.epoch_skip = epoch_skip
self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
if self.summary_period == 0:
self.summary_period = 1
self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
if self.image_period == 0:
self.image_period = 1
# Avoid different period between image and summaries (due to odd number division)
if summary_per_epoch % image_per_epoch == 0:
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
# Save source files
if save_src:
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
if not os.path.exists(dirname):
os.makedirs(dirname)
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
# Initialize training loop variables
self.batch_inputs = batch_inputs
self.batch_labels = batch_labels
self.processed_inputs = processed_inputs
self.network_outputs = processed_inputs # Placeholder
self.train_loss = 0.0
self.train_accuracy = 0.0
self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0
self.benchmark_step = 0
self.benchmark_time = time.time()
def end_epoch_callback(self):
pass
def train_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float, accuracy: float):
pass
def val_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float, accuracy: float):
pass
def pre_summary_callback(self):
pass
def summary_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
pass
def image_callback(
self,
train_inputs: torch.Tensor, train_processed: torch.Tensor,
train_labels: torch.Tensor, train_outputs: torch.Tensor,
val_inputs: torch.Tensor, val_processed: torch.Tensor,
val_labels: torch.Tensor, val_outputs: torch.Tensor):
pass
def fit(self, epochs: int) -> torch.Tensor:
train_start_time = time.time()
try:
while not self.should_stop and self.batch_generator_train.epoch < epochs:
epoch = self.batch_generator_train.epoch
if self.verbose:
print()
print(' ' * os.get_terminal_size()[0], end='\r')
print(f'Epoch {self.batch_generator_train.epoch}')
while not self.should_stop and epoch == self.batch_generator_train.epoch:
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
if self.verbose and self.benchmark_step > 1:
speed = self.benchmark_step / (time.time() - self.benchmark_time)
print(
f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
f', {speed * self.batch_generator_train.batch_size:0.02f} input/sec', end='\r')
# Faster zero grad
for param in self.network.parameters():
param.grad = None
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
loss = self.criterion(self.network_outputs, labels)
loss.backward()
self.optimizer.step()
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.train_step_callback(
self.batch_inputs, self.processed_inputs, labels,
self.network_outputs, self.train_loss, self.train_accuracy)
self.benchmark_step += 1
self.save_summaries()
self.batch_generator_train.next_batch()
self.end_epoch_callback()
except KeyboardInterrupt:
if self.verbose:
print()
# Small training loop for last metrics
for _ in range(20):
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
loss = self.criterion(self.network_outputs, labels)
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.benchmark_step += 1
self.batch_generator_train.next_batch()
self.save_summaries(force_summary=True)
train_stop_time = time.time()
self.writer_train.close()
self.writer_val.close()
memory_peak, gpu_memory = resource_usage()
self.logger.info(
f'Training time : {train_stop_time - train_start_time:.03f}s\n'
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
def save_summaries(self, force_summary=False):
global_step = self.batch_generator_train.global_step
if self.batch_generator_train.epoch < self.epoch_skip:
return
if self.batch_generator_train.step % self.summary_period != self.summary_period - 1 and (
self.batch_generator_train.step % self.image_period != self.image_period - 1):
return
if self.batch_generator_val.step != 0:
self.batch_generator_val.skip_epoch()
self.pre_summary_callback()
val_loss = 0.0
val_accuracy = 0.0
val_count = 0
self.network.train(False)
with torch.no_grad():
val_epoch = self.batch_generator_val.epoch
while val_epoch == self.batch_generator_val.epoch:
val_inputs = torch.as_tensor(
self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
val_labels = torch.as_tensor(
self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device)
val_pre_process = self.pre_process(val_inputs)
val_outputs = self.network(val_pre_process)
val_labels = val_labels if not self.data_is_label else val_pre_process
loss = self.criterion(val_outputs, val_labels).item()
accuracy = self.accuracy_fn(
val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
val_loss += loss
val_accuracy += accuracy
val_count += len(self.batch_generator_val.batch_data)
self.val_step_callback(
val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)
self.batch_generator_val.next_batch()
self.network.train(True)
# Add summaries
if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
self.writer_train.add_scalar(
'loss', self.running_loss / self.running_count, global_step=global_step)
self.writer_val.add_scalar(
'loss', val_loss / val_count, global_step=global_step)
if self.accuracy_fn is not None:
self.writer_train.add_scalar(
'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
self.writer_val.add_scalar(
'error', 1 - (val_accuracy / val_count), global_step=global_step)
self.summary_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
val_inputs, val_pre_process, val_labels, val_outputs, val_count)
# Add image
if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
self.image_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
val_inputs, val_pre_process, val_labels, val_outputs)
if self.verbose:
speed = self.benchmark_step / (time.time() - self.benchmark_time)
print(f'Step {global_step}, '
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
if self.accuracy_fn is not None:
torch.save(self.network.state_dict(), os.path.join(
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
else:
torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0