diff --git a/trainer.py b/trainer.py
index 0412c56..92b8f8f 100644
--- a/trainer.py
+++ b/trainer.py
@@ -23,12 +23,13 @@ class Trainer:
                  epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                  data_dtype=None, label_dtype=None,
                  train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
-                 logger=DummyLogger):
+                 logger=DummyLogger()):
         super().__init__()
         self.device = device
         self.output_dir = output_dir
         self.data_is_label = data_is_label
         self.logger = logger
+        self.should_stop = False
 
         self.batch_generator_train = batch_generator_train
         self.batch_generator_val = batch_generator_val
@@ -126,12 +127,12 @@ class Trainer:
     def fit(self, epochs: int) -> torch.Tensor:
         train_start_time = time.time()
         try:
-            while self.batch_generator_train.epoch < epochs:
+            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                 epoch = self.batch_generator_train.epoch
                 print()
                 print(' ' * os.get_terminal_size()[0], end='\r')
                 print(f'Epoch {self.batch_generator_train.epoch}')
-                while epoch == self.batch_generator_train.epoch:
+                while not self.should_stop and epoch == self.batch_generator_train.epoch:
                     self.batch_inputs = torch.as_tensor(
                         self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                     self.batch_labels = torch.as_tensor(
@@ -173,10 +174,9 @@ class Trainer:
         torch.save(self.network,
                    os.path.join(self.output_dir, 'model_final.pt'))
         memory_peak, gpu_memory = resource_usage()
-        self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
-            train_stop_time - train_start_time,
-            memory_peak // 1024,
-            gpu_memory))
+        self.logger.info(
+            f'Training time : {train_stop_time - train_start_time:.03f}s\n'
+            f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
 
     def save_summaries(self):
         global_step = self.batch_generator_train.global_step