Trainer verbose argument implementation

This commit is contained in:
Corentin 2021-01-23 04:18:05 +09:00
commit b43b8b14d6

View file

@@ -23,12 +23,13 @@ class Trainer:
                 epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                 data_dtype=None, label_dtype=None,
                 train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
                 logger=DummyLogger(), verbose=True):
        super().__init__()
        self.device = device
        self.output_dir = output_dir
        self.data_is_label = data_is_label
        self.logger = logger
        self.verbose = verbose
        self.should_stop = False
        self.batch_generator_train = batch_generator_train
@@ -129,6 +130,7 @@ class Trainer:
        try:
            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                epoch = self.batch_generator_train.epoch
                if self.verbose:
                    print()
                    print(' ' * os.get_terminal_size()[0], end='\r')
                    print(f'Epoch {self.batch_generator_train.epoch}')
@@ -138,7 +140,7 @@ class Trainer:
                self.batch_labels = torch.as_tensor(
                    self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
                if self.verbose and self.benchmark_step > 1:
                    speed = self.benchmark_step / (time.time() - self.benchmark_time)
                    print(
                        f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
@@ -167,6 +169,7 @@ class Trainer:
                    self.save_summaries()
                self.batch_generator_train.next_batch()
        except KeyboardInterrupt:
            if self.verbose:
                print()
        train_stop_time = time.time()
@@ -228,6 +231,7 @@ class Trainer:
            self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
            val_inputs, val_pre_process, val_labels, val_outputs)
        if self.verbose:
            speed = self.benchmark_step / (time.time() - self.benchmark_time)
            print(f'Step {global_step}, '
                  f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '