From f2282e3216e3990b8538cd4a619f6a42b4a0c54b Mon Sep 17 00:00:00 2001
From: Corentin
Date: Fri, 25 Dec 2020 15:50:38 +0900
Subject: [PATCH] Trainer implementation

* Add Deconv2d
* Fix BatchGenerator save option when saving to the current directory
---
 layers.py                |  29 ++++-
 trainer.py               | 239 +++++++++++++++++++++++++++++++++++++++
 utils/batch_generator.py |   2 +-
 3 files changed, 263 insertions(+), 7 deletions(-)
 create mode 100644 trainer.py

diff --git a/layers.py b/layers.py
index 4a5296e..10df5f2 100644
--- a/layers.py
+++ b/layers.py
@@ -45,6 +45,20 @@ class Layer(nn.Module):
         return output
 
 
+class Linear(Layer):
+    def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
+        super().__init__(activation, batch_norm)
+
+        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
+        self.batch_norm = nn.BatchNorm1d(
+            out_channels,
+            momentum=Layer.BATCH_NORM_MOMENTUM,
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        return super().forward(self.fc(input_data))
+
+
 class Conv1d(Layer):
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
@@ -93,15 +107,18 @@ class Conv3d(Layer):
         return super().forward(self.conv(input_data))
 
 
-class Linear(Layer):
-    def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
+class Deconv2d(Layer):
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
+                 stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
 
-        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
-        self.batch_norm = nn.BatchNorm1d(
+        self.deconv = nn.ConvTranspose2d(
+            in_channels, out_channels, kernel_size, stride=stride,
+            bias=not self.batch_norm, **kwargs)
+        self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        return super().forward(self.fc(input_data))
+        return super().forward(self.deconv(input_data))
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..12abd73
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,239 @@
+import glob
+import os
+import shutil
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+
+from .train import parameter_summary, resource_usage
+from .utils.logger import DummyLogger
+from .utils.batch_generator import BatchGenerator
+
+
+class Trainer:
+    def __init__(
+            self,
+            batch_generator_train: BatchGenerator, batch_generator_val: BatchGenerator,
+            device: str, output_dir: str,
+            network: nn.Module, pre_process: nn.Module,
+            optimizer: torch.optim.Optimizer, criterion: nn.Module,
+            epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
+            data_dtype=None, label_dtype=None,
+            train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
+            logger=DummyLogger):
+        super().__init__()
+        self.device = device
+        self.output_dir = output_dir
+        self.data_is_label = data_is_label
+        self.logger = logger
+
+        self.batch_generator_train = batch_generator_train
+        self.batch_generator_val = batch_generator_val
+
+        self.pre_process = pre_process
+        self.train_pre_process = train_pre_process if train_pre_process is not None else pre_process
+        self.data_dtype = data_dtype
+        self.label_dtype = label_dtype
+
+        self.network = network
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
+        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
+
+        # Save network graph
+        batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
+        batch_labels = torch.as_tensor(batch_generator_train.batch_label[:2], dtype=label_dtype, device=device)
+        processed_inputs = pre_process(batch_inputs)
+        self.writer_train.add_graph(network, (processed_inputs,))
+        self.writer_val.add_graph(network, (processed_inputs,))
+
+        # Save parameters info
+        with open(os.path.join(output_dir, 'parameters.csv'), 'w') as param_file:
+            param_summary = parameter_summary(network)
+            names = [len(name) for name, _, _ in param_summary]
+            shapes = [len(str(shape)) for _, shape, _ in param_summary]
+            param_file.write(
+                '\n'.join(
+                    [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
+                     for name, shape, size in param_summary]))
+
+        # Calculate summary periods
+        self.epoch_skip = epoch_skip
+        self.summary_period = batch_generator_train.step_per_epoch // summary_per_epoch
+        if self.summary_period == 0:
+            self.summary_period = 1
+        self.image_period = batch_generator_train.step_per_epoch // image_per_epoch
+        if self.image_period == 0:
+            self.image_period = 1
+        # Keep image and summary periods aligned (integer division can desynchronize them)
+        if summary_per_epoch % image_per_epoch == 0:
+            self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
+
+        torch.save(network, os.path.join(output_dir, 'model_init.pt'))
+        # Save source files
+        for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
+                os.path.join('src', '**', '*.py'), recursive=True):
+            dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
+            if not os.path.exists(dirname):
+                os.makedirs(dirname)
+            shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
+
+        # Initialize training loop variables
+        self.batch_inputs = batch_inputs
+        self.batch_labels = batch_labels
+        self.processed_inputs = processed_inputs
+        self.network_outputs = processed_inputs  # Placeholder
+        self.train_loss = 0.0
+        self.running_loss = 0.0
+        self.running_count = 0
+        self.benchmark_step = 0
+        self.benchmark_time = time.time()
+
+    def train_step_callback(
+            self,
+            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
+            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
+            loss: float):
+        pass
+
+    def val_step_callback(
+            self,
+            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
+            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
+            loss: float):
+        pass
+
+    def summary_callback(
+            self,
+            train_inputs: torch.Tensor, train_processed: torch.Tensor,
+            train_labels: torch.Tensor, train_outputs: torch.Tensor, train_running_count: int,
+            val_inputs: torch.Tensor, val_processed: torch.Tensor,
+            val_labels: torch.Tensor, val_outputs: torch.Tensor, val_running_count: int):
+        pass
+
+    def image_callback(
+            self,
+            train_inputs: torch.Tensor, train_processed: torch.Tensor,
+            train_labels: torch.Tensor, train_outputs: torch.Tensor,
+            val_inputs: torch.Tensor, val_processed: torch.Tensor,
+            val_labels: torch.Tensor, val_outputs: torch.Tensor):
+        pass
+
+    def fit(self, epochs: int) -> None:
+        train_start_time = time.time()
+        try:
+            while self.batch_generator_train.epoch < epochs:
+                epoch = self.batch_generator_train.epoch
+                print()
+                print(' ' * os.get_terminal_size()[0], end='\r')
+                print(f'Epoch {self.batch_generator_train.epoch}')
+                while epoch == self.batch_generator_train.epoch:
+                    self.batch_inputs = torch.as_tensor(
+                        self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
+                    self.batch_labels = torch.as_tensor(
+                        self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
+
+                    if self.benchmark_step > 1:
+                        speed = self.benchmark_step / (time.time() - self.benchmark_time)
+                        print(
+                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
+                            f', {speed * self.batch_generator_train.batch_size:0.02f} inputs/s', end='\r')
+
+                    # Setting grads to None is faster than optimizer.zero_grad()
+                    for param in self.network.parameters():
+                        param.grad = None
+
+                    self.processed_inputs = self.train_pre_process(self.batch_inputs)
+                    self.network_outputs = self.network(self.processed_inputs)
+                    loss = self.criterion(
+                        self.network_outputs,
+                        self.batch_labels if not self.data_is_label else self.processed_inputs)
+                    loss.backward()
+                    self.optimizer.step()
+
+                    self.train_loss = loss.item()
+                    self.running_loss += self.train_loss
+                    self.running_count += len(self.batch_generator_train.batch_data)
+                    self.train_step_callback(
+                        self.batch_inputs, self.processed_inputs, self.batch_labels,
+                        self.network_outputs, self.train_loss)
+
+                    self.benchmark_step += 1
+                    self.save_summaries()
+                    self.batch_generator_train.next_batch()
+        except KeyboardInterrupt:
+            print()
+
+        train_stop_time = time.time()
+        self.writer_train.close()
+        torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
+
+        memory_peak, gpu_memory = resource_usage()
+        self.logger.info('Training time: {:.03f}s\n\tRAM peak: {} MB\n\tVRAM usage: {}'.format(
+            train_stop_time - train_start_time,
+            memory_peak // 1024,
+            gpu_memory))
+
+    def save_summaries(self):
+        global_step = self.batch_generator_train.global_step
+        if self.batch_generator_train.epoch < self.epoch_skip:
+            return
+        if (self.batch_generator_train.step % self.summary_period != self.summary_period - 1
+                and self.batch_generator_train.step % self.image_period != self.image_period - 1):
+            return
+
+        if self.batch_generator_val.step != 0:
+            self.batch_generator_val.skip_epoch()
+        val_loss = 0.0
+        val_count = 0
+        self.network.train(False)
+        with torch.no_grad():
+            val_epoch = self.batch_generator_val.epoch
+            while val_epoch == self.batch_generator_val.epoch:
+                val_inputs = torch.as_tensor(
+                    self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device)
+                val_labels = torch.as_tensor(
+                    self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device)
+
+                val_pre_process = self.pre_process(val_inputs)
+                val_outputs = self.network(val_pre_process)
+                loss = self.criterion(
+                    val_outputs,
+                    val_labels if not self.data_is_label else val_pre_process).item()
+                val_loss += loss
+                val_count += len(self.batch_generator_val.batch_data)
+                self.val_step_callback(
+                    val_inputs, val_pre_process, val_labels, val_outputs, loss)
+
+                self.batch_generator_val.next_batch()
+        self.network.train(True)
+
+        # Add summaries
+        if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
+            self.writer_train.add_scalar(
+                'loss', self.running_loss / self.running_count, global_step=global_step)
+            self.writer_val.add_scalar(
+                'loss', val_loss / val_count, global_step=global_step)
+            self.summary_callback(
+                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
+                val_inputs, val_pre_process, val_labels, val_outputs, val_count)
+
+        # Add images
+        if self.batch_generator_train.step % self.image_period == (self.image_period - 1):
+            self.image_callback(
+                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
+                val_inputs, val_pre_process, val_labels, val_outputs)
+
+        speed = self.benchmark_step / (time.time() - self.benchmark_time)
+        print(f'Step {global_step}, '
+              f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
+              f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} inputs/s')
+        torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
+        self.benchmark_time = time.time()
+        self.benchmark_step = 0
+        self.running_loss = 0.0
+        self.running_count = 0
diff --git a/utils/batch_generator.py b/utils/batch_generator.py
index 4049045..91b70e4 100644
--- a/utils/batch_generator.py
+++ b/utils/batch_generator.py
@@ -31,7 +31,7 @@ class BatchGenerator:
         if save is not None:
             if '.' not in os.path.basename(save_path):
                 save_path += '.hdf5'
-            if not os.path.exists(os.path.dirname(save_path)):
+            if os.path.dirname(save_path) and not os.path.exists(os.path.dirname(save_path)):
                 os.makedirs(os.path.dirname(save_path))
 
         if save and os.path.exists(save_path):
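
Usage note (not part of the patch): Deconv2d wraps nn.ConvTranspose2d, so output
sizes follow the usual transposed-convolution arithmetic,
out = (in - 1) * stride - 2 * padding + kernel_size, with dilation and
output_padding left at their defaults. A minimal shape-check sketch, assuming
layers.py is importable from the repository root and that the default
activation/batch_norm arguments are acceptable:

    import torch

    from layers import Deconv2d

    # kernel_size=4, stride=2, padding=1 doubles the spatial resolution:
    # out = (16 - 1) * 2 - 2 * 1 + 4 = 32
    upsample = Deconv2d(64, 32, kernel_size=4, stride=2, padding=1)
    x = torch.randn(8, 64, 16, 16)  # (batch, channels, height, width)
    print(upsample(x).shape)        # torch.Size([8, 32, 32, 32])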
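
The sketch below shows how the new Trainer might be driven end to end. The
Trainer signature comes from this patch; the BatchGenerator constructor
arguments, import paths, file paths and hyperparameters are hypothetical
placeholders:

    import torch
    import torch.nn as nn

    from layers import Deconv2d
    from trainer import Trainer  # adjust to the actual package layout
    from utils.batch_generator import BatchGenerator

    # Toy decoder built from the new layer
    network = nn.Sequential(
        Deconv2d(64, 32, kernel_size=4, stride=2, padding=1),
        Deconv2d(32, 1, kernel_size=4, stride=2, padding=1)).to('cuda')

    # Hypothetical BatchGenerator construction -- adapt to the real signature
    train_gen = BatchGenerator('data/train.hdf5', batch_size=32)
    val_gen = BatchGenerator('data/val.hdf5', batch_size=32)

    trainer = Trainer(
        train_gen, val_gen,
        device='cuda', output_dir='output/run_0',
        network=network, pre_process=nn.Identity(),
        optimizer=torch.optim.Adam(network.parameters(), lr=1e-3),
        criterion=nn.MSELoss(),
        epoch_skip=1, summary_per_epoch=4, image_per_epoch=2,
        data_dtype=torch.float32, label_dtype=torch.float32)
    trainer.fit(epochs=10)  # checkpoints and TensorBoard logs go to output_dir

Custom behaviour is added by subclassing Trainer and overriding
train_step_callback, val_step_callback, summary_callback or image_callback,
all of which default to no-ops.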