Small fix, making h5py optional

This commit is contained in:
Corentin Risselin 2023-01-04 16:58:48 +09:00
commit 6359258061
2 changed files with 11 additions and 10 deletions

View file

@ -3,7 +3,6 @@ from multiprocessing import shared_memory
import os
from typing import Callable, Iterable, Optional, Tuple
import h5py
import numpy as np
@ -20,6 +19,7 @@ class BatchGenerator:
self.num_workers = num_workers
self.flip_data = flip_data
self.pipeline = pipeline
self.process_id = 'NA'
if not preload:
self.data_processor = data_processor
@ -37,6 +37,7 @@ class BatchGenerator:
os.makedirs(os.path.dirname(save_path))
if save and os.path.exists(save_path):
import h5py
with h5py.File(save_path, 'r') as h5_file:
self.data = np.asarray(h5_file['data'])
self.label = np.asarray(h5_file['label'])
@ -80,7 +81,6 @@ class BatchGenerator:
self.batch_data = first_data
self.batch_label = first_label
self.process_id = 'NA'
if self.prefetch or self.num_workers > 1:
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
self.cache_indices = np.ndarray(