import glob
import os
import shutil
import time
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from .train import parameter_summary, resource_usage
from .utils.logger import DummyLogger
from .utils.batch_generator import BatchGenerator


class Trainer:
    """Generic training loop: runs the network on batches from a BatchGenerator,
    writes TensorBoard summaries, checkpoints the model, and exposes overridable
    ``*_callback`` hooks for custom logging."""

    def __init__(
            self,
            batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
            device: str, output_dir: str,
            network: nn.Module, pre_process: nn.Module,
            optimizer: torch.optim.Optimizer, criterion: nn.Module,
            epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
            data_dtype=None, label_dtype=None,
            train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
            logger=DummyLogger(), verbose=True, save_src=True):
        super().__init__()
        self.device = device
        self.output_dir = output_dir
        self.data_is_label = data_is_label
        self.logger = logger
        self.verbose = verbose
        self.should_stop = False

        self.batch_generator_train = batch_generator_train
        self.batch_generator_val = batch_generator_val

        self.pre_process = pre_process
        self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
        self.data_dtype = data_dtype
        self.label_dtype = label_dtype

        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
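        # `accuracy_fn` is left unset and may be assigned by the caller. Since the
        # training loop divides the accumulated value by the number of *inputs*
        # seen, the function should return a per-batch sum. A minimal sketch for
        # classification (illustrative, not part of this class):
        #
        #     trainer.accuracy_fn = lambda outputs, labels: (
        #         outputs.argmax(dim=1) == labels).float().sum()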
        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)

        # Save network graph
        batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
        batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
        processed_inputs = pre_process(batch_inputs)
        self.writer_train.add_graph(network, (processed_inputs,))
        self.writer_val.add_graph(network, (processed_inputs,))

        # Save parameter summary (written as space-aligned columns, despite the .csv extension)
        with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
            param_summary = parameter_summary(network)
            names = [len(name) for name, _, _ in param_summary]
            shapes = [len(str(shape)) for _, shape, _ in param_summary]
            param_file.write(
                '\n'.join(
                    [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                     for name, shape, size in param_summary]))
        # Calculate summary and image periods (in steps), each at least 1
        self.epoch_skip = epoch_skip
        self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
        if self.summary_period == 0:
            self.summary_period = 1
        self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
        if self.image_period == 0:
            self.image_period = 1
        # Keep the image period an exact multiple of the summary period (the integer
        # divisions above can otherwise desynchronize the two)
        if summary_per_epoch % image_per_epoch == 0:
            self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
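        # Example with assumed numbers: step_per_epoch=1000, summary_per_epoch=10 and
        # image_per_epoch=2 give summary_period=100 and image_period=1000 // 2 = 500,
        # realigned above to 100 * (10 // 2) = 500, so every fifth summary also emits images.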

        torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
        # Save a copy of the source files alongside the run for reproducibility
        if save_src:
            for entry in (glob.glob(os.path.join('config', '**', '*.py'), recursive=True)
                          + glob.glob(os.path.join('src', '**', '*.py'), recursive=True)
                          + glob.glob('*.py')):
                dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
                os.makedirs(dirname, exist_ok=True)
                shutil.copy2(entry, os.path.join(output_dir, 'code', entry))

        # Initialize training loop variables
        self.batch_inputs = batch_inputs
        self.batch_labels = batch_labels
        self.processed_inputs = processed_inputs
        self.network_outputs = processed_inputs  # Placeholder until the first forward pass
        self.train_loss = 0.0
        self.train_accuracy = 0.0
        self.running_loss = 0.0
        self.running_accuracy = 0.0
        self.running_count = 0
        self.benchmark_step = 0
        self.benchmark_time = time.time()

    def end_epoch_callback(self):
        """Hook called at the end of every training epoch; override in subclasses."""
        pass
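
    # Sketch: early stopping can be implemented in this hook by setting
    # `self.should_stop`; the helper name below is hypothetical, not part of this API:
    #
    #     def end_epoch_callback(self):
    #         if self.val_loss_stopped_improving():
    #             self.should_stop = True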

    def train_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float, accuracy: float):
        """Hook called after every optimizer step; override in subclasses."""
        pass

    def val_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float, accuracy: float):
        """Hook called after every validation step; override in subclasses."""
        pass

    def pre_summary_callback(self):
        """Hook called just before each validation/summary pass; override in subclasses."""
        pass

    def summary_callback(
            self,
            train_inputs: torch.Tensor, train_processed: torch.Tensor,
            train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
            val_inputs: torch.Tensor, val_processed: torch.Tensor,
            val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
        """Hook called whenever scalar summaries are written; override in subclasses."""
        pass

    def image_callback(
            self,
            train_inputs: torch.Tensor, train_processed: torch.Tensor,
            train_labels: torch.Tensor, train_outputs: torch.Tensor,
            val_inputs: torch.Tensor, val_processed: torch.Tensor,
            val_labels: torch.Tensor, val_outputs: torch.Tensor):
        """Hook called whenever image summaries are due; override in subclasses."""
        pass
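
    # Sketch of a subclass using the image hook; the subclass name and the
    # 'val_outputs' tag are illustrative assumptions, not part of this API:
    #
    #     class ImageTrainer(Trainer):
    #         def image_callback(self, train_inputs, train_processed, train_labels,
    #                            train_outputs, val_inputs, val_processed,
    #                            val_labels, val_outputs):
    #             self.writer_val.add_images('val_outputs', val_outputs,
    #                                        global_step=self.batch_generator_train.global_step)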

    def fit(self, epochs: int) -> None:
        """Run the training loop for `epochs` epochs (interruptible with Ctrl-C)."""
        train_start_time = time.time()
        try:
            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                epoch = self.batch_generator_train.epoch
                if self.verbose:
                    print()
                    print(' ' * os.get_terminal_size()[0], end='\r')  # Clear the status line
                    print(f'Epoch {self.batch_generator_train.epoch}')
                while not self.should_stop and epoch == self.batch_generator_train.epoch:
                    self.batch_inputs = torch.as_tensor(
                        self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                    self.batch_labels = torch.as_tensor(
                        self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)

                    if self.verbose and self.benchmark_step > 1:
                        speed = self.benchmark_step / (time.time() - self.benchmark_time)
                        print(
                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
                            f', {speed * self.batch_generator_train.batch_size:0.02f} inputs/s', end='\r')

                    # Faster zero grad: setting grads to None skips the per-parameter fill kernel
                    for param in self.network.parameters():
                        param.grad = None
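                    # (On PyTorch >= 1.7, `self.optimizer.zero_grad(set_to_none=True)`
                    # is an equivalent one-liner.)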

                    self.processed_inputs = self.train_pre_process(self.batch_inputs)
                    self.network_outputs = self.network(self.processed_inputs)
                    # With data_is_label, the pre-processed input doubles as the target (autoencoder-style)
                    labels = self.batch_labels if not self.data_is_label else self.processed_inputs
                    loss = self.criterion(self.network_outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    self.train_loss = loss.item()
                    self.train_accuracy = self.accuracy_fn(
                        self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
                    self.running_loss += self.train_loss
                    self.running_accuracy += self.train_accuracy
                    self.running_count += len(self.batch_generator_train.batch_data)
                    self.train_step_callback(
                        self.batch_inputs, self.processed_inputs, labels,
                        self.network_outputs, self.train_loss, self.train_accuracy)

                    self.benchmark_step += 1
                    self.save_summaries()
                    self.batch_generator_train.next_batch()
                self.end_epoch_callback()
        except KeyboardInterrupt:
            if self.verbose:
                print()

        # Short extra loop (forward passes only) to refresh the final metrics
        for _ in range(20):
            self.batch_inputs = torch.as_tensor(
                self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
            self.batch_labels = torch.as_tensor(
                self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)

            self.processed_inputs = self.train_pre_process(self.batch_inputs)
            self.network_outputs = self.network(self.processed_inputs)
            labels = self.batch_labels if not self.data_is_label else self.processed_inputs
            loss = self.criterion(self.network_outputs, labels)

            self.train_loss = loss.item()
            self.train_accuracy = self.accuracy_fn(
                self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
            self.running_loss += self.train_loss
            self.running_accuracy += self.train_accuracy
            self.running_count += len(self.batch_generator_train.batch_data)
            self.benchmark_step += 1
            self.batch_generator_train.next_batch()
        self.save_summaries(force_summary=True)
        train_stop_time = time.time()
        self.writer_train.close()
        self.writer_val.close()

        memory_peak, gpu_memory = resource_usage()
        self.logger.info(
            f'Training time : {train_stop_time - train_start_time:.03f}s\n'
            f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')

    def save_summaries(self, force_summary=False):
        """Run a validation pass, then write scalar/image summaries and a checkpoint when due."""
        global_step = self.batch_generator_train.global_step
        if self.batch_generator_train.epoch < self.epoch_skip:
            return
        # Skip unless a summary or image period ends at this step (or the summary is forced)
        if (not force_summary
                and self.batch_generator_train.step % self.summary_period != self.summary_period - 1
                and self.batch_generator_train.step % self.image_period != self.image_period - 1):
            return

        # Full pass over the validation set, in eval mode with gradients disabled
        if self.batch_generator_val.step != 0:
            self.batch_generator_val.skip_epoch()
        self.pre_summary_callback()
        val_loss = 0.0
        val_accuracy = 0.0
        val_count = 0
        self.network.train(False)
        with torch.no_grad():
            val_epoch = self.batch_generator_val.epoch
            while val_epoch == self.batch_generator_val.epoch:
                val_inputs = torch.as_tensor(
                    self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
                val_labels = torch.as_tensor(
                    self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device)

                val_pre_process = self.pre_process(val_inputs)
                val_outputs = self.network(val_pre_process)
                val_labels = val_labels if not self.data_is_label else val_pre_process
                loss = self.criterion(val_outputs, val_labels).item()
                accuracy = self.accuracy_fn(
                    val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
                val_loss += loss
                val_accuracy += accuracy
                val_count += len(self.batch_generator_val.batch_data)
                self.val_step_callback(
                    val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)

                self.batch_generator_val.next_batch()
        self.network.train(True)

        # Add summaries
        if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
            self.writer_train.add_scalar(
                'loss', self.running_loss / self.running_count, global_step=global_step)
            self.writer_val.add_scalar(
                'loss', val_loss / val_count, global_step=global_step)
            if self.accuracy_fn is not None:
                self.writer_train.add_scalar(
                    'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
                self.writer_val.add_scalar(
                    'error', 1 - (val_accuracy / val_count), global_step=global_step)
            self.summary_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
                val_inputs, val_pre_process, val_labels, val_outputs, val_count)

        # Add image
        if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
            self.image_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
                val_inputs, val_pre_process, val_labels, val_outputs)

        if self.verbose:
            speed = self.benchmark_step / (time.time() - self.benchmark_time)
            print(f'Step {global_step}, '
                  f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
                  f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} inputs/s')
        # Save a checkpoint, embedding the validation accuracy in the name when available
        if self.accuracy_fn is not None:
            torch.save(self.network.state_dict(), os.path.join(
                self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
        else:
            torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
        # Reset the running counters for the next summary window
        self.benchmark_time = time.time()
        self.benchmark_step = 0
        self.running_loss = 0.0
        self.running_accuracy = 0.0
        self.running_count = 0
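

# Minimal usage sketch (everything below is illustrative: the generator arguments,
# network and hyper-parameters are assumptions about the surrounding project):
#
#     network = MyNetwork().to('cuda')
#     trainer = Trainer(
#         batch_generator_train=BatchGenerator(...), batch_generator_val=BatchGenerator(...),
#         device='cuda', output_dir='output/run_1',
#         network=network, pre_process=nn.Identity().to('cuda'),
#         optimizer=torch.optim.Adam(network.parameters()), criterion=nn.MSELoss(),
#         epoch_skip=0, summary_per_epoch=10, image_per_epoch=2)
#     trainer.fit(epochs=100)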