Accuracy function implemented in Trainer

This commit is contained in:
Corentin 2021-01-26 01:15:02 +09:00
commit d315f342a4

View file

@@ -43,6 +43,7 @@ class Trainer:
self.network = network
self.optimizer = optimizer
self.criterion = criterion
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
@@ -90,7 +91,9 @@ class Trainer:
self.processed_inputs = processed_inputs
self.network_outputs = processed_inputs # Placeholder
self.train_loss = 0.0
self.train_accuracy = 0.0
self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0
self.benchmark_step = 0
self.benchmark_time = time.time()
@@ -99,14 +102,14 @@ class Trainer:
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
loss: float, accuracy: float):
pass
def val_step_callback(
self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float):
loss: float, accuracy: float):
pass
def summary_callback(
@@ -152,18 +155,20 @@ class Trainer:
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
loss = self.criterion(
self.network_outputs,
self.batch_labels if not self.data_is_label else self.processed_inputs)
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
loss = self.criterion(self.network_outputs, labels)
loss.backward()
self.optimizer.step()
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.train_step_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels,
self.network_outputs, self.train_loss)
self.batch_inputs, self.processed_inputs, labels,
self.network_outputs, self.train_loss, self.train_accuracy)
self.benchmark_step += 1
self.save_summaries()
@@ -192,6 +197,7 @@ class Trainer:
if self.batch_generator_val.step != 0:
self.batch_generator_val.skip_epoch()
val_loss = 0.0
val_accuracy = 0.0
val_count = 0
self.network.train(False)
with torch.no_grad():
@@ -204,13 +210,15 @@ class Trainer:
val_pre_process = self.pre_process(val_inputs)
val_outputs = self.network(val_pre_process)
loss = self.criterion(
val_outputs,
val_labels if not self.data_is_label else val_pre_process).item()
val_labels = val_labels if not self.data_is_label else val_pre_process
loss = self.criterion(val_outputs, val_labels).item()
accuracy = self.accuracy_fn(
val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
val_loss += loss
val_accuracy += accuracy
val_count += len(self.batch_generator_val.batch_data)
self.val_step_callback(
val_inputs, val_pre_process, val_labels, val_outputs, loss)
val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)
self.batch_generator_val.next_batch()
self.network.train(True)
@@ -221,6 +229,11 @@ class Trainer:
'loss', self.running_loss / self.running_count, global_step=global_step)
self.writer_val.add_scalar(
'loss', val_loss / val_count, global_step=global_step)
if self.accuracy_fn is not None:
self.writer_train.add_scalar(
'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
self.writer_val.add_scalar(
'error', 1 - (val_accuracy / val_count), global_step=global_step)
self.summary_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
val_inputs, val_pre_process, val_labels, val_outputs, val_count)
@@ -240,4 +253,5 @@ class Trainer:
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0