Trainer verbose argument implementation
This commit is contained in:
parent
fa9188ad75
commit
b43b8b14d6
1 changed file with 14 additions and 10 deletions
24
trainer.py
24
trainer.py
|
|
@@ -23,12 +23,13 @@ class Trainer:
|
||||||
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
||||||
data_dtype=None, label_dtype=None,
|
data_dtype=None, label_dtype=None,
|
||||||
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
||||||
logger=DummyLogger()):
|
logger=DummyLogger(), verbose=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.data_is_label = data_is_label
|
self.data_is_label = data_is_label
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.verbose = verbose
|
||||||
self.should_stop = False
|
self.should_stop = False
|
||||||
|
|
||||||
self.batch_generator_train = batch_generator_train
|
self.batch_generator_train = batch_generator_train
|
||||||
|
|
@@ -129,16 +130,17 @@ class Trainer:
|
||||||
try:
|
try:
|
||||||
while not self.should_stop and self.batch_generator_train.epoch < epochs:
|
while not self.should_stop and self.batch_generator_train.epoch < epochs:
|
||||||
epoch = self.batch_generator_train.epoch
|
epoch = self.batch_generator_train.epoch
|
||||||
print()
|
if self.verbose:
|
||||||
print(' ' * os.get_terminal_size()[0], end='\r')
|
print()
|
||||||
print(f'Epoch {self.batch_generator_train.epoch}')
|
print(' ' * os.get_terminal_size()[0], end='\r')
|
||||||
|
print(f'Epoch {self.batch_generator_train.epoch}')
|
||||||
while not self.should_stop and epoch == self.batch_generator_train.epoch:
|
while not self.should_stop and epoch == self.batch_generator_train.epoch:
|
||||||
self.batch_inputs = torch.as_tensor(
|
self.batch_inputs = torch.as_tensor(
|
||||||
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
|
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
|
||||||
self.batch_labels = torch.as_tensor(
|
self.batch_labels = torch.as_tensor(
|
||||||
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
|
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
|
||||||
|
|
||||||
if self.benchmark_step > 1:
|
if self.verbose and self.benchmark_step > 1:
|
||||||
speed = self.benchmark_step / (time.time() - self.benchmark_time)
|
speed = self.benchmark_step / (time.time() - self.benchmark_time)
|
||||||
print(
|
print(
|
||||||
f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
|
f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
|
||||||
|
|
@@ -167,7 +169,8 @@ class Trainer:
|
||||||
self.save_summaries()
|
self.save_summaries()
|
||||||
self.batch_generator_train.next_batch()
|
self.batch_generator_train.next_batch()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print()
|
if self.verbose:
|
||||||
|
print()
|
||||||
|
|
||||||
train_stop_time = time.time()
|
train_stop_time = time.time()
|
||||||
self.writer_train.close()
|
self.writer_train.close()
|
||||||
|
|
@@ -228,10 +231,11 @@ class Trainer:
|
||||||
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
|
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
|
||||||
val_inputs, val_pre_process, val_labels, val_outputs)
|
val_inputs, val_pre_process, val_labels, val_outputs)
|
||||||
|
|
||||||
speed = self.benchmark_step / (time.time() - self.benchmark_time)
|
if self.verbose:
|
||||||
print(f'Step {global_step}, '
|
speed = self.benchmark_step / (time.time() - self.benchmark_time)
|
||||||
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
print(f'Step {global_step}, '
|
||||||
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
||||||
|
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
||||||
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
|
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
|
||||||
self.benchmark_time = time.time()
|
self.benchmark_time = time.time()
|
||||||
self.benchmark_step = 0
|
self.benchmark_step = 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue