Small fix, making h5py optional
This commit is contained in:
parent
1bac46219b
commit
6359258061
2 changed files with 11 additions and 10 deletions
|
|
@ -3,7 +3,6 @@ from multiprocessing import shared_memory
|
|||
import os
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
@ -20,6 +19,7 @@ class BatchGenerator:
|
|||
self.num_workers = num_workers
|
||||
self.flip_data = flip_data
|
||||
self.pipeline = pipeline
|
||||
self.process_id = 'NA'
|
||||
|
||||
if not preload:
|
||||
self.data_processor = data_processor
|
||||
|
|
@ -37,6 +37,7 @@ class BatchGenerator:
|
|||
os.makedirs(os.path.dirname(save_path))
|
||||
|
||||
if save and os.path.exists(save_path):
|
||||
import h5py
|
||||
with h5py.File(save_path, 'r') as h5_file:
|
||||
self.data = np.asarray(h5_file['data'])
|
||||
self.label = np.asarray(h5_file['label'])
|
||||
|
|
@ -80,7 +81,6 @@ class BatchGenerator:
|
|||
self.batch_data = first_data
|
||||
self.batch_label = first_label
|
||||
|
||||
self.process_id = 'NA'
|
||||
if self.prefetch or self.num_workers > 1:
|
||||
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
|
||||
self.cache_indices = np.ndarray(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue