Fix dummy logger in trainer + stop mechanism
This commit is contained in:
parent
9ac6fb64e8
commit
fa9188ad75
1 changed files with 7 additions and 7 deletions
14
trainer.py
14
trainer.py
|
|
@ -23,12 +23,13 @@ class Trainer:
|
||||||
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
||||||
data_dtype=None, label_dtype=None,
|
data_dtype=None, label_dtype=None,
|
||||||
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
||||||
logger=DummyLogger):
|
logger=DummyLogger()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.data_is_label = data_is_label
|
self.data_is_label = data_is_label
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.should_stop = False
|
||||||
|
|
||||||
self.batch_generator_train = batch_generator_train
|
self.batch_generator_train = batch_generator_train
|
||||||
self.batch_generator_val = batch_generator_val
|
self.batch_generator_val = batch_generator_val
|
||||||
|
|
@ -126,12 +127,12 @@ class Trainer:
|
||||||
def fit(self, epochs: int) -> torch.Tensor:
|
def fit(self, epochs: int) -> torch.Tensor:
|
||||||
train_start_time = time.time()
|
train_start_time = time.time()
|
||||||
try:
|
try:
|
||||||
while self.batch_generator_train.epoch < epochs:
|
while not self.should_stop and self.batch_generator_train.epoch < epochs:
|
||||||
epoch = self.batch_generator_train.epoch
|
epoch = self.batch_generator_train.epoch
|
||||||
print()
|
print()
|
||||||
print(' ' * os.get_terminal_size()[0], end='\r')
|
print(' ' * os.get_terminal_size()[0], end='\r')
|
||||||
print(f'Epoch {self.batch_generator_train.epoch}')
|
print(f'Epoch {self.batch_generator_train.epoch}')
|
||||||
while epoch == self.batch_generator_train.epoch:
|
while not self.should_stop and epoch == self.batch_generator_train.epoch:
|
||||||
self.batch_inputs = torch.as_tensor(
|
self.batch_inputs = torch.as_tensor(
|
||||||
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
|
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
|
||||||
self.batch_labels = torch.as_tensor(
|
self.batch_labels = torch.as_tensor(
|
||||||
|
|
@ -173,10 +174,9 @@ class Trainer:
|
||||||
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
||||||
|
|
||||||
memory_peak, gpu_memory = resource_usage()
|
memory_peak, gpu_memory = resource_usage()
|
||||||
self.logger.info('Training time : {:.03f}s\n\tRAM peak : {} MB\n\tVRAM usage : {}'.format(
|
self.logger.info(
|
||||||
train_stop_time - train_start_time,
|
f'Training time : {train_stop_time - train_start_time:.03f}s\n'
|
||||||
memory_peak // 1024,
|
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
|
||||||
gpu_memory))
|
|
||||||
|
|
||||||
def save_summaries(self):
|
def save_summaries(self):
|
||||||
global_step = self.batch_generator_train.global_step
|
global_step = self.batch_generator_train.global_step
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue