Fix dummy logger in trainer + stop mechanism

Corentin 2021-01-20 13:16:01 +09:00
commit fa9188ad75

@@ -23,12 +23,13 @@ class Trainer:
                  epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                  data_dtype=None, label_dtype=None,
                  train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
-                 logger=DummyLogger):
+                 logger=DummyLogger()):
         super().__init__()
         self.device = device
         self.output_dir = output_dir
         self.data_is_label = data_is_label
         self.logger = logger
+        self.should_stop = False
         self.batch_generator_train = batch_generator_train
         self.batch_generator_val = batch_generator_val
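
Why the default had to change: `logger=DummyLogger` stores the class itself, so a later `self.logger.info(msg)` calls the unbound method with the message bound to `self` and no `msg` argument. A minimal, self-contained sketch of the failure mode (DummyLogger's real definition is not shown in this diff; a no-op `info` method is assumed):

    class DummyLogger:
        def info(self, msg):
            pass  # no-op: swallow log messages

    def broken(logger=DummyLogger):    # default is the class itself
        logger.info('hello')           # unbound call: 'hello' becomes self

    def fixed(logger=DummyLogger()):   # default is a ready-to-use instance
        logger.info('hello')           # fine: no-op

    fixed()
    try:
        broken()
    except TypeError as e:
        print(e)  # info() missing 1 required positional argument: 'msg'

Since the logger is stateless, evaluating the instance once at def time is harmless here; the usual mutable-default caveat does not bite.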
@@ -126,12 +127,12 @@ class Trainer:
     def fit(self, epochs: int) -> torch.Tensor:
         train_start_time = time.time()
         try:
-            while self.batch_generator_train.epoch < epochs:
+            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                 epoch = self.batch_generator_train.epoch
                 print()
                 print(' ' * os.get_terminal_size()[0], end='\r')
                 print(f'Epoch {self.batch_generator_train.epoch}')
-                while epoch == self.batch_generator_train.epoch:
+                while not self.should_stop and epoch == self.batch_generator_train.epoch:
                     self.batch_inputs = torch.as_tensor(
                         self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                     self.batch_labels = torch.as_tensor(
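
The flag is checked at both the epoch and the batch level, so setting it from outside `fit()` ends training after the current batch rather than mid-step. A self-contained sketch of the pattern with a SIGINT handler (the `Loop` class below is illustrative, not part of this repo):

    import signal
    import time

    class Loop:
        def __init__(self):
            self.should_stop = False

        def fit(self, epochs):
            for epoch in range(epochs):
                if self.should_stop:
                    break  # cooperative exit between epochs
                for step in range(10):
                    if self.should_stop:
                        break  # cooperative exit between batches
                    time.sleep(0.1)  # stands in for one training step

    loop = Loop()
    signal.signal(signal.SIGINT, lambda signum, frame: setattr(loop, 'should_stop', True))
    loop.fit(epochs=100)  # Ctrl-C now requests a clean stop instead of killing mid-step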
@@ -173,10 +174,9 @@ class Trainer:
         torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
         memory_peak, gpu_memory = resource_usage()
-        self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
-            train_stop_time - train_start_time,
-            memory_peak // 1024,
-            gpu_memory))
+        self.logger.info(
+            f'Training time : {train_stop_time - train_start_time:.03f}s\n'
+            f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')

     def save_summaries(self):
         global_step = self.batch_generator_train.global_step
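
The f-string rewrite is behavior-preserving; it only drops the `.format()` indirection. `resource_usage()` itself is not defined in this diff; given that `fit()` divides the first value by 1024 to print MB and prints the second verbatim, a plausible sketch looks like this (every name below is an assumption, not the repo's actual helper):

    import resource
    import torch

    def resource_usage():
        # ru_maxrss is the process RAM peak, in kilobytes on Linux
        # (bytes on macOS, where the // 1024 above would be off by 1024x)
        memory_peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        if torch.cuda.is_available():
            gpu_memory = f'{torch.cuda.max_memory_allocated() // 2**20} MB'
        else:
            gpu_memory = 'n/a'
        return memory_peak, gpu_memory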