diff --git a/trainer.py b/trainer.py
index 92b8f8f..39742ce 100644
--- a/trainer.py
+++ b/trainer.py
@@ -23,12 +23,13 @@ class Trainer:
                  epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                  data_dtype=None, label_dtype=None, train_pre_process: Optional[nn.Module] = None,
                  data_is_label: bool = False,
-                 logger=DummyLogger()):
+                 logger=DummyLogger(), verbose=True):
         super().__init__()
         self.device = device
         self.output_dir = output_dir
         self.data_is_label = data_is_label
         self.logger = logger
+        self.verbose = verbose
         self.should_stop = False
 
         self.batch_generator_train = batch_generator_train
@@ -129,16 +130,17 @@ class Trainer:
         try:
             while not self.should_stop and self.batch_generator_train.epoch < epochs:
                 epoch = self.batch_generator_train.epoch
-                print()
-                print(' ' * os.get_terminal_size()[0], end='\r')
-                print(f'Epoch {self.batch_generator_train.epoch}')
+                if self.verbose:
+                    print()
+                    print(' ' * os.get_terminal_size()[0], end='\r')
+                    print(f'Epoch {self.batch_generator_train.epoch}')
 
                 while not self.should_stop and epoch == self.batch_generator_train.epoch:
                     self.batch_inputs = torch.as_tensor(
                         self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                     self.batch_labels = torch.as_tensor(
                         self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
-                    if self.benchmark_step > 1:
+                    if self.verbose and self.benchmark_step > 1:
                         speed = self.benchmark_step / (time.time() - self.benchmark_time)
                         print(
                             f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
@@ -167,7 +169,8 @@ class Trainer:
                         self.save_summaries()
                     self.batch_generator_train.next_batch()
         except KeyboardInterrupt:
-            print()
+            if self.verbose:
+                print()
 
         train_stop_time = time.time()
         self.writer_train.close()
@@ -228,10 +231,11 @@ class Trainer:
                 self.batch_inputs, self.processed_inputs,
                 self.batch_labels, self.network_outputs,
                 val_inputs, val_pre_process, val_labels, val_outputs)
-            speed = self.benchmark_step / (time.time() - self.benchmark_time)
-            print(f'Step {global_step}, '
-                  f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
-                  f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
+            if self.verbose:
+                speed = self.benchmark_step / (time.time() - self.benchmark_time)
+                print(f'Step {global_step}, '
+                      f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
+                      f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
             torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
             self.benchmark_time = time.time()
             self.benchmark_step = 0
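
A minimal call-site sketch for the new flag, assuming nothing beyond this diff; `existing_kwargs` is a hypothetical placeholder for the Trainer constructor arguments that are unchanged and not shown here:

    # Sketch only: `existing_kwargs` stands in for the constructor arguments not shown in this diff.
    quiet_trainer = Trainer(**existing_kwargs, verbose=False)  # suppress epoch/step/loss console output
    default_trainer = Trainer(**existing_kwargs)               # verbose defaults to True, so existing callers are unaffected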