import glob
import os
import shutil
import time
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from .train import parameter_summary, resource_usage
from .utils.logger import DummyLogger
from .utils.batch_generator import BatchGenerator


class Trainer:
    """Supervised training loop with TensorBoard logging and checkpointing.

    Wraps a network, optimizer and criterion together with train/val batch
    generators.  On construction it records the model graph, an aligned
    parameter-summary table, a copy of the ``config``/``src`` source trees
    and an initial checkpoint in ``output_dir``.  During :meth:`fit` it
    periodically runs a full validation epoch and writes scalar and image
    summaries via the overridable ``*_callback`` hooks.
    """

    def __init__(
            self,
            batch_generator_train: BatchGenerator,
            batch_generator_val: BatchGenerator,
            device: str,
            output_dir: str,
            network: nn.Module,
            pre_process: nn.Module,
            optimizer: torch.optim.Optimizer,
            criterion: nn.Module,
            epoch_skip: int,
            summary_per_epoch: int,
            image_per_epoch: int,
            data_dtype=None,
            label_dtype=None,
            train_pre_process: Optional[nn.Module] = None,
            data_is_label: bool = False,
            logger=None):
        """Set up writers, periods and run artifacts.

        Args:
            batch_generator_train: generator feeding training batches.
            batch_generator_val: generator feeding validation batches.
            device: torch device string the tensors are moved to.
            output_dir: directory receiving logs, checkpoints and code copy.
            network: model to train.
            pre_process: pre-processing module applied to every batch.
            optimizer: optimizer stepping ``network``'s parameters.
            criterion: loss module.
            epoch_skip: number of initial epochs without any summaries.
            summary_per_epoch: scalar summaries per epoch.
            image_per_epoch: image summaries per epoch.
            data_dtype: dtype for input tensors (None keeps source dtype).
            label_dtype: dtype for label tensors (None keeps source dtype).
            train_pre_process: optional training-only pre-processing
                (e.g. augmentation); defaults to ``pre_process``.
            data_is_label: auto-encoder mode — the pre-processed input is
                used as the criterion target instead of the labels.
            logger: logger with an ``info`` method; a fresh DummyLogger by
                default.
        """
        super().__init__()
        self.device = device
        self.output_dir = output_dir
        self.data_is_label = data_is_label
        # BUGFIX: a `logger=DummyLogger()` default would create ONE instance
        # shared by every Trainer built without an explicit logger (mutable
        # default argument); build a fresh one per instance instead.
        self.logger = logger if logger is not None else DummyLogger()
        self.should_stop = False
        self.batch_generator_train = batch_generator_train
        self.batch_generator_val = batch_generator_val
        self.pre_process = pre_process
        self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
        self.data_dtype = data_dtype
        self.label_dtype = label_dtype
        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)

        # Save network graph, traced on a 2-sample slice of the first batch
        batch_inputs = torch.as_tensor(
            batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
        batch_labels = torch.as_tensor(
            batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
        processed_inputs = pre_process(batch_inputs)
        self.writer_train.add_graph(network, (processed_inputs,))
        self.writer_val.add_graph(network, (processed_inputs,))

        # Save parameters info as a column-aligned text table
        with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
            param_summary = parameter_summary(network)
            names = [len(name) for name, _, _ in param_summary]
            shapes = [len(str(shape)) for _, shape, _ in param_summary]
            param_file.write(
                '\n'.join(
                    [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                     for name, shape, size in param_summary]))

        # Calculate summary periods (in steps), at least 1 step each
        self.epoch_skip = epoch_skip
        self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
        if self.summary_period == 0:
            self.summary_period = 1
        self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
        if self.image_period == 0:
            self.image_period = 1
        # Avoid different period between image and summaries (due to odd number division):
        # when the counts divide evenly, realign the image period on the summary period.
        if summary_per_epoch % image_per_epoch == 0:
            self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)

        torch.save(network, os.path.join(output_dir, 'model_init.pt'))

        # Save source files under output_dir/code for reproducibility
        for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
                os.path.join('src', '**', '*.py'), recursive=True):
            dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
            # BUGFIX: exist_ok avoids the TOCTOU race of exists()+makedirs()
            os.makedirs(dirname, exist_ok=True)
            shutil.copy2(entry, os.path.join(output_dir, 'code', entry))

        # Initialize training loop variables
        self.batch_inputs = batch_inputs
        self.batch_labels = batch_labels
        self.processed_inputs = processed_inputs
        self.network_outputs = processed_inputs  # Placeholder until the first step
        self.train_loss = 0.0
        self.running_loss = 0.0
        self.running_count = 0
        self.benchmark_step = 0
        self.benchmark_time = time.time()

    def train_step_callback(
            self,
            batch_inputs: torch.Tensor,
            processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor,
            network_outputs: torch.Tensor,
            loss: float):
        """Hook called after every optimizer step; override to extend."""
        pass

    def val_step_callback(
            self,
            batch_inputs: torch.Tensor,
            processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor,
            network_outputs: torch.Tensor,
            loss: float):
        """Hook called after every validation batch; override to extend."""
        pass

    def summary_callback(
            self,
            train_inputs: torch.Tensor,
            train_processed: torch.Tensor,
            train_labels: torch.Tensor,
            train_outputs: torch.Tensor,
            train_running_count: int,
            val_inputs: torch.Tensor,
            val_processed: torch.Tensor,
            val_labels: torch.Tensor,
            val_outputs: torch.Tensor,
            val_running_count: int):
        """Hook called when scalar summaries are written; override to extend."""
        pass

    def image_callback(
            self,
            train_inputs: torch.Tensor,
            train_processed: torch.Tensor,
            train_labels: torch.Tensor,
            train_outputs: torch.Tensor,
            val_inputs: torch.Tensor,
            val_processed: torch.Tensor,
            val_labels: torch.Tensor,
            val_outputs: torch.Tensor):
        """Hook called when image summaries are due; override to extend."""
        pass

    def fit(self, epochs: int) -> None:
        """Run the training loop for up to ``epochs`` epochs.

        Stops early when ``self.should_stop`` is set or on Ctrl-C; in every
        case the writers are closed, the final model is saved and resource
        usage is logged.
        """
        train_start_time = time.time()
        try:
            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                epoch = self.batch_generator_train.epoch
                print()
                # BUGFIX: shutil.get_terminal_size falls back to (80, 24)
                # instead of raising OSError when stdout is not a terminal
                # (piped or redirected runs), unlike os.get_terminal_size.
                print(' ' * shutil.get_terminal_size().columns, end='\r')
                print(f'Epoch {self.batch_generator_train.epoch}')
                while not self.should_stop and epoch == self.batch_generator_train.epoch:
                    self.batch_inputs = torch.as_tensor(
                        self.batch_generator_train.batch_data, dtype=self.data_dtype,
                        device=self.device)
                    self.batch_labels = torch.as_tensor(
                        self.batch_generator_train.batch_label, dtype=self.label_dtype,
                        device=self.device)
                    if self.benchmark_step > 1:
                        speed = self.benchmark_step / (time.time() - self.benchmark_time)
                        print(
                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
                            f', {speed * self.batch_generator_train.batch_size:0.02f} input/sec',
                            end='\r')
                    # Faster zero grad than optimizer.zero_grad()
                    for param in self.network.parameters():
                        param.grad = None
                    self.processed_inputs = self.train_pre_process(self.batch_inputs)
                    self.network_outputs = self.network(self.processed_inputs)
                    # In auto-encoder mode the pre-processed input is the target
                    loss = self.criterion(
                        self.network_outputs,
                        self.batch_labels if not self.data_is_label else self.processed_inputs)
                    loss.backward()
                    self.optimizer.step()
                    self.train_loss = loss.item()
                    self.running_loss += self.train_loss
                    self.running_count += len(self.batch_generator_train.batch_data)
                    self.train_step_callback(
                        self.batch_inputs, self.processed_inputs, self.batch_labels,
                        self.network_outputs, self.train_loss)
                    self.benchmark_step += 1
                    self.save_summaries()
                    self.batch_generator_train.next_batch()
        except KeyboardInterrupt:
            # Ctrl-C stops training but still runs the teardown below
            print()
        train_stop_time = time.time()
        self.writer_train.close()
        # BUGFIX: the validation writer was never closed, so its last
        # buffered events could be lost; close it alongside the train writer.
        self.writer_val.close()
        torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
        memory_peak, gpu_memory = resource_usage()
        self.logger.info(
            f'Training time : {train_stop_time - train_start_time:.03f}s\n'
            f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')

    def save_summaries(self):
        """Run a full validation epoch and write due scalar/image summaries.

        Returns immediately while still inside the ``epoch_skip`` warm-up, or
        when the current step matches neither the summary nor the image
        period.  Otherwise evaluates the network on the whole validation set,
        emits the due summaries, checkpoints the model and resets the
        running-loss and speed accumulators.
        """
        global_step = self.batch_generator_train.global_step
        if self.batch_generator_train.epoch < self.epoch_skip:
            return
        if self.batch_generator_train.step % self.summary_period != self.summary_period - 1 and (
                self.batch_generator_train.step % self.image_period != self.image_period - 1):
            return

        # Make sure validation starts from the beginning of its epoch
        if self.batch_generator_val.step != 0:
            self.batch_generator_val.skip_epoch()
        val_loss = 0.0
        val_count = 0
        self.network.train(False)
        try:
            with torch.no_grad():
                val_epoch = self.batch_generator_val.epoch
                while val_epoch == self.batch_generator_val.epoch:
                    val_inputs = torch.as_tensor(
                        self.batch_generator_val.batch_data, dtype=self.data_dtype,
                        device=self.device)
                    val_labels = torch.as_tensor(
                        self.batch_generator_val.batch_label, dtype=self.label_dtype,
                        device=self.device)
                    val_pre_process = self.pre_process(val_inputs)
                    val_outputs = self.network(val_pre_process)
                    loss = self.criterion(
                        val_outputs,
                        val_labels if not self.data_is_label else val_pre_process).item()
                    val_loss += loss
                    val_count += len(self.batch_generator_val.batch_data)
                    self.val_step_callback(
                        val_inputs, val_pre_process, val_labels, val_outputs, loss)
                    self.batch_generator_val.next_batch()
        finally:
            # BUGFIX: restore training mode even if a validation step raises,
            # otherwise the network silently stays in eval mode.
            self.network.train(True)

        # Add summaries
        if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
            self.writer_train.add_scalar(
                'loss', self.running_loss / self.running_count, global_step=global_step)
            self.writer_val.add_scalar(
                'loss', val_loss / val_count, global_step=global_step)
            self.summary_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels,
                self.network_outputs, self.running_count,
                val_inputs, val_pre_process, val_labels, val_outputs, val_count)
        # Add image
        if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
            self.image_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels,
                self.network_outputs,
                val_inputs, val_pre_process, val_labels, val_outputs)

        speed = self.benchmark_step / (time.time() - self.benchmark_time)
        print(f'Step {global_step}, '
              f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
              f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
        torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))

        # Reset the speed and running-loss accumulators for the next window
        self.benchmark_time = time.time()
        self.benchmark_step = 0
        self.running_loss = 0.0
        self.running_count = 0