From d315f342a49f96a2919199a1ddb24b65d4770e7e Mon Sep 17 00:00:00 2001 From: Corentin Date: Tue, 26 Jan 2021 01:15:02 +0900 Subject: [PATCH] Accuracy function implemented in Trainer --- trainer.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/trainer.py b/trainer.py index 39742ce..772cef3 100644 --- a/trainer.py +++ b/trainer.py @@ -43,6 +43,7 @@ class Trainer: self.network = network self.optimizer = optimizer self.criterion = criterion + self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30) self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30) @@ -90,7 +91,9 @@ class Trainer: self.processed_inputs = processed_inputs self.network_outputs = processed_inputs # Placeholder self.train_loss = 0.0 + self.train_accuracy = 0.0 self.running_loss = 0.0 + self.running_accuracy = 0.0 self.running_count = 0 self.benchmark_step = 0 self.benchmark_time = time.time() @@ -99,14 +102,14 @@ class Trainer: self, batch_inputs: torch.Tensor, processed_inputs: torch.Tensor, batch_labels: torch.Tensor, network_outputs: torch.Tensor, - loss: float): + loss: float, accuracy: float): pass def val_step_callback( self, batch_inputs: torch.Tensor, processed_inputs: torch.Tensor, batch_labels: torch.Tensor, network_outputs: torch.Tensor, - loss: float): + loss: float, accuracy: float): pass def summary_callback( @@ -152,18 +155,20 @@ class Trainer: self.processed_inputs = self.train_pre_process(self.batch_inputs) self.network_outputs = self.network(self.processed_inputs) - loss = self.criterion( - self.network_outputs, - self.batch_labels if not self.data_is_label else self.processed_inputs) + labels = self.batch_labels if not self.data_is_label else self.processed_inputs + loss = self.criterion(self.network_outputs, labels) loss.backward() self.optimizer.step() self.train_loss = loss.item() + self.train_accuracy = self.accuracy_fn( + self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0 self.running_loss += self.train_loss + self.running_accuracy += self.train_accuracy self.running_count += len(self.batch_generator_train.batch_data) self.train_step_callback( - self.batch_inputs, self.processed_inputs, self.batch_labels, - self.network_outputs, self.train_loss) + self.batch_inputs, self.processed_inputs, labels, + self.network_outputs, self.train_loss, self.train_accuracy) self.benchmark_step += 1 self.save_summaries() @@ -192,6 +197,7 @@ class Trainer: if self.batch_generator_val.step != 0: self.batch_generator_val.skip_epoch() val_loss = 0.0 + val_accuracy = 0.0 val_count = 0 self.network.train(False) with torch.no_grad(): @@ -204,13 +210,15 @@ class Trainer: val_pre_process = self.pre_process(val_inputs) val_outputs = self.network(val_pre_process) - loss = self.criterion( - val_outputs, - val_labels if not self.data_is_label else val_pre_process).item() + val_labels = val_labels if not self.data_is_label else val_pre_process + loss = self.criterion(val_outputs, val_labels).item() + accuracy = self.accuracy_fn( + val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0 val_loss += loss + val_accuracy += accuracy val_count += len(self.batch_generator_val.batch_data) self.val_step_callback( - val_inputs, val_pre_process, val_labels, val_outputs, loss) + val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy) self.batch_generator_val.next_batch() self.network.train(True) @@ -221,6 +229,11 @@ class Trainer: 'loss', self.running_loss / self.running_count, global_step=global_step) self.writer_val.add_scalar( 'loss', val_loss / val_count, global_step=global_step) + if self.accuracy_fn is not None: + self.writer_train.add_scalar( + 'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step) + self.writer_val.add_scalar( + 'error', 1 - (val_accuracy / val_count), global_step=global_step) self.summary_callback( self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count, val_inputs, val_pre_process, val_labels, val_outputs, val_count) @@ -240,4 +253,5 @@ class Trainer: self.benchmark_time = time.time() self.benchmark_step = 0 self.running_loss = 0.0 + self.running_accuracy = 0.0 self.running_count = 0