Trainer last summary fix + memory utils

This commit is contained in:
Corentin 2021-02-25 02:18:02 +09:00
commit 50c395a07f
3 changed files with 54 additions and 3 deletions

View file

@ -28,11 +28,26 @@ def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int], s
def resource_usage() -> Tuple[int, str]:
    """Report the current process's peak memory and the GPU memory in use.

    Returns:
        A ``(memory_peak, gpu_memory)`` tuple: the ``ru_maxrss`` field of
        ``resource.getrusage(resource.RUSAGE_SELF)`` cast to ``int``, and the
        string returned by ``gpu_used_memory()``.
    """
    # NOTE(review): the unit of ru_maxrss is platform-dependent — confirm
    # before comparing values across operating systems.
    self_usage = resource.getrusage(resource.RUSAGE_SELF)
    return int(self_usage.ru_maxrss), gpu_used_memory()
def gpu_used_memory() -> str:
    """Return the used GPU memory reported by ``nvidia-smi``.

    ``nvidia-smi`` is queried for ``memory.used`` in headerless CSV form, which
    yields one line per GPU.

    Returns:
        If ``CUDA_VISIBLE_DEVICES`` is unset, the values of every GPU joined
        with ``','``. Otherwise only the values of the GPUs whose indices are
        listed in ``CUDA_VISIBLE_DEVICES`` (a single index or a comma-separated
        list), joined with ``','``. An empty ``CUDA_VISIBLE_DEVICES`` gives an
        empty string.

    Raises:
        subprocess.CalledProcessError: if the ``nvidia-smi`` command fails.
        ValueError: if ``CUDA_VISIBLE_DEVICES`` holds a non-integer entry.
        IndexError: if a listed index has no matching ``nvidia-smi`` line.
    """
    gpu_memory = subprocess.check_output(
        'nvidia-smi --query-gpu=memory.used --format=csv,noheader', shell=True).decode().strip()
    per_gpu = gpu_memory.split('\n')

    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is None:
        return ','.join(per_gpu)

    # CUDA_VISIBLE_DEVICES may list several devices ("0,1"); a bare int() on
    # the whole value would fail on anything but a single index.
    # NOTE(review): assumes nvidia-smi lists GPUs in the same order CUDA
    # enumerates them — confirm if CUDA_DEVICE_ORDER is not PCI_BUS_ID.
    indices = [int(entry) for entry in visible_devices.split(',') if entry.strip()]
    return ','.join(per_gpu[index] for index in indices)
def gpu_total_memory() -> str:
    """Return the total GPU memory reported by ``nvidia-smi``.

    ``nvidia-smi`` is queried for ``memory.total`` in headerless CSV form,
    which yields one line per GPU.

    Returns:
        If ``CUDA_VISIBLE_DEVICES`` is unset, the values of every GPU joined
        with ``','``. Otherwise only the values of the GPUs whose indices are
        listed in ``CUDA_VISIBLE_DEVICES`` (a single index or a comma-separated
        list), joined with ``','``. An empty ``CUDA_VISIBLE_DEVICES`` gives an
        empty string.

    Raises:
        subprocess.CalledProcessError: if the ``nvidia-smi`` command fails.
        ValueError: if ``CUDA_VISIBLE_DEVICES`` holds a non-integer entry.
        IndexError: if a listed index has no matching ``nvidia-smi`` line.
    """
    gpu_memory = subprocess.check_output(
        'nvidia-smi --query-gpu=memory.total --format=csv,noheader', shell=True).decode().strip()
    per_gpu = gpu_memory.split('\n')

    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is None:
        return ','.join(per_gpu)

    # CUDA_VISIBLE_DEVICES may list several devices ("0,1"); a bare int() on
    # the whole value would fail on anything but a single index.
    # NOTE(review): assumes nvidia-smi lists GPUs in the same order CUDA
    # enumerates them — confirm if CUDA_DEVICE_ORDER is not PCI_BUS_ID.
    indices = [int(entry) for entry in visible_devices.split(',') if entry.strip()]
    return ','.join(per_gpu[index] for index in indices)

View file

@ -178,6 +178,23 @@ class Trainer:
if self.verbose:
print()
# Small training loop for last metrics
for _ in range(20):
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.benchmark_step += 1
self.save_summaries(force_summary=True)
train_stop_time = time.time()
self.writer_train.close()

View file

@ -6,3 +6,22 @@ def human_size(byte_count: int) -> str:
break
amount /= 1024.0
return f'{amount:.2f}{unit}B'
def human_to_bytes(text: str) -> float:
    """Convert a human-readable size such as ``'11019 MiB'`` to a byte count.

    The text is a non-negative decimal amount optionally followed by a binary
    unit (``B``, ``KiB``, ``MiB``, ... ``YiB``). Whitespace around the amount
    and the unit is ignored, and a missing unit means the amount is already
    in bytes. Fractional amounts (``'1.50GiB'``) are accepted so that strings
    formatted with two decimals, like those of ``human_size``, parse as well.

    Args:
        text: The size to parse.

    Returns:
        The size in bytes as a float.

    Raises:
        ValueError: if ``text`` does not start with a numeric amount.
        RuntimeError: if the unit is not one of the recognized binary units.
    """
    # Position in this tuple is the power of 1024 the unit stands for.
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB')
    numeric_chars = '0123456789.'

    stripped = text.strip()
    split_index = 0
    # Bounds check first: the amount may run to the end of the string.
    while split_index < len(stripped) and stripped[split_index] in numeric_chars:
        split_index += 1

    try:
        amount = float(stripped[:split_index])
    except ValueError:
        raise ValueError(f'No numeric amount found in : {text!r}') from None

    unit = stripped[split_index:].strip()
    if not unit:
        return amount
    if unit not in units:
        raise RuntimeError(f'Unrecognized unit : {unit}')
    # Scaling by a power of two is exact, so this matches repeated *= 1024.0.
    return amount * 1024.0 ** units.index(unit)