Accuracy function implemented in Trainer

parent b43b8b14d6
commit d315f342a4

1 changed file with 25 additions and 11 deletions

trainer.py  +25 -11 (36 lines changed)
@@ -43,6 +43,7 @@ class Trainer:
         self.network = network
         self.optimizer = optimizer
         self.criterion = criterion
+        self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
         self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
         self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)

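The new accuracy_fn hook is typed Callable[[torch.Tensor, torch.Tensor], torch.Tensor] and stays None until the caller sets it. Because save_summaries() (see the later hunks) divides the accumulated accuracy by a sample count, the function is presumably expected to return the number of correct predictions in the batch as a scalar tensor rather than a per-batch mean. A minimal sketch under that assumption (the function name and the argmax-based reduction are illustrative, not part of this commit):

    import torch

    def top1_correct(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # outputs: (batch, num_classes) scores; labels: (batch,) class indices.
        # Returns a scalar tensor holding the count of correct predictions;
        # the Trainer calls .item() on it and accumulates the float.
        return (outputs.argmax(dim=1) == labels).sum()

    trainer.accuracy_fn = top1_correct  # `trainer` is an assumed, already constructed Trainer instance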
@@ -90,7 +91,9 @@ class Trainer:
         self.processed_inputs = processed_inputs
         self.network_outputs = processed_inputs  # Placeholder
         self.train_loss = 0.0
+        self.train_accuracy = 0.0
         self.running_loss = 0.0
+        self.running_accuracy = 0.0
         self.running_count = 0
         self.benchmark_step = 0
         self.benchmark_time = time.time()
@@ -99,14 +102,14 @@ class Trainer:
             self,
             batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
             batch_labels: torch.Tensor, network_outputs: torch.Tensor,
-            loss: float):
+            loss: float, accuracy: float):
         pass

     def val_step_callback(
             self,
             batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
             batch_labels: torch.Tensor, network_outputs: torch.Tensor,
-            loss: float):
+            loss: float, accuracy: float):
         pass

     def summary_callback(
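Both callback hooks now receive the per-step accuracy alongside the loss. A hypothetical subclass that consumes the new argument (the subclass name and print format are illustrative only; accuracy arrives as 0.0 whenever accuracy_fn is unset):

    class LoggingTrainer(Trainer):
        def train_step_callback(
                self,
                batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
                batch_labels: torch.Tensor, network_outputs: torch.Tensor,
                loss: float, accuracy: float):
            # Invoked once per training step with the just-computed batch metrics.
            print(f'train step: loss={loss:.4f} accuracy={accuracy}')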
@@ -152,18 +155,20 @@ class Trainer:

         self.processed_inputs = self.train_pre_process(self.batch_inputs)
         self.network_outputs = self.network(self.processed_inputs)
-        loss = self.criterion(
-            self.network_outputs,
-            self.batch_labels if not self.data_is_label else self.processed_inputs)
+        labels = self.batch_labels if not self.data_is_label else self.processed_inputs
+        loss = self.criterion(self.network_outputs, labels)
         loss.backward()
         self.optimizer.step()

         self.train_loss = loss.item()
+        self.train_accuracy = self.accuracy_fn(
+            self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
         self.running_loss += self.train_loss
+        self.running_accuracy += self.train_accuracy
         self.running_count += len(self.batch_generator_train.batch_data)
         self.train_step_callback(
-            self.batch_inputs, self.processed_inputs, self.batch_labels,
-            self.network_outputs, self.train_loss)
+            self.batch_inputs, self.processed_inputs, labels,
+            self.network_outputs, self.train_loss, self.train_accuracy)

         self.benchmark_step += 1
         self.save_summaries()
@@ -192,6 +197,7 @@ class Trainer:
         if self.batch_generator_val.step != 0:
             self.batch_generator_val.skip_epoch()
         val_loss = 0.0
+        val_accuracy = 0.0
         val_count = 0
         self.network.train(False)
         with torch.no_grad():
@@ -204,13 +210,15 @@ class Trainer:

                 val_pre_process = self.pre_process(val_inputs)
                 val_outputs = self.network(val_pre_process)
-                loss = self.criterion(
-                    val_outputs,
-                    val_labels if not self.data_is_label else val_pre_process).item()
+                val_labels = val_labels if not self.data_is_label else val_pre_process
+                loss = self.criterion(val_outputs, val_labels).item()
+                accuracy = self.accuracy_fn(
+                    val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
                 val_loss += loss
+                val_accuracy += accuracy
                 val_count += len(self.batch_generator_val.batch_data)
                 self.val_step_callback(
-                    val_inputs, val_pre_process, val_labels, val_outputs, loss)
+                    val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)

                 self.batch_generator_val.next_batch()
         self.network.train(True)
@@ -221,6 +229,11 @@ class Trainer:
             'loss', self.running_loss / self.running_count, global_step=global_step)
         self.writer_val.add_scalar(
             'loss', val_loss / val_count, global_step=global_step)
+        if self.accuracy_fn is not None:
+            self.writer_train.add_scalar(
+                'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
+            self.writer_val.add_scalar(
+                'error', 1 - (val_accuracy / val_count), global_step=global_step)
         self.summary_callback(
             self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
             val_inputs, val_pre_process, val_labels, val_outputs, val_count)
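A worked example of the new 'error' scalars, under the same correct-count assumption as above: if a summary interval accumulates running_accuracy = 24 correct predictions over running_count = 32 samples, the training curve logs 1 - 24/32 = 0.25; the validation curve is computed the same way from val_accuracy / val_count. When accuracy_fn is None the whole block is skipped, so runs without an accuracy function keep exactly the scalars they had before this commit.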
@@ -240,4 +253,5 @@ class Trainer:
         self.benchmark_time = time.time()
         self.benchmark_step = 0
         self.running_loss = 0.0
+        self.running_accuracy = 0.0
         self.running_count = 0