Small fix, making h5py optional

parent 1bac46219b
commit 6359258061

2 changed files with 11 additions and 10 deletions
@@ -3,7 +3,6 @@ from multiprocessing import shared_memory
 import os
 from typing import Callable, Iterable, Optional, Tuple
 
-import h5py
 import numpy as np
 
 
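The hunk above removes the module-level h5py import, which is the heart of this commit: importing the package no longer fails when h5py is not installed, because the import now happens inside the save/load code paths (see the later hunks). A minimal sketch of that deferred-import pattern, with a hypothetical helper name:

    def _load_cache(save_path):
        # Importing here means h5py is only required when an HDF5
        # cache file is actually read, not at module import time.
        import h5py
        import numpy as np

        with h5py.File(save_path, 'r') as h5_file:
            return np.asarray(h5_file['data']), np.asarray(h5_file['label'])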
@@ -20,6 +19,7 @@ class BatchGenerator:
         self.num_workers = num_workers
         self.flip_data = flip_data
         self.pipeline = pipeline
+        self.process_id = 'NA'
 
         if not preload:
             self.data_processor = data_processor
@@ -37,6 +37,7 @@ class BatchGenerator:
             os.makedirs(os.path.dirname(save_path))
 
         if save and os.path.exists(save_path):
+            import h5py
             with h5py.File(save_path, 'r') as h5_file:
                 self.data = np.asarray(h5_file['data'])
                 self.label = np.asarray(h5_file['label'])
@@ -80,7 +81,6 @@ class BatchGenerator:
         self.batch_data = first_data
         self.batch_label = first_label
 
-        self.process_id = 'NA'
         if self.prefetch or self.num_workers > 1:
             self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
             self.cache_indices = np.ndarray(
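That ends the first file's diff; the SequenceGenerator module follows. For context on the two shared-memory lines above: the hunk suggests that, when prefetching or multiple workers are enabled, the index list is shared across processes via multiprocessing.shared_memory so workers can read it without pickling. A self-contained sketch of that mechanism (the attribute names mirror the hunk; everything else is assumed):

    from multiprocessing import shared_memory

    import numpy as np

    index_list = np.arange(1024, dtype=np.int64)

    # Allocate a shared block the size of the index array...
    shm = shared_memory.SharedMemory(create=True, size=index_list.nbytes)
    # ...and view it as an ndarray backed by that block.
    cache_indices = np.ndarray(index_list.shape, dtype=index_list.dtype, buffer=shm.buf)
    cache_indices[:] = index_list  # one copy in; workers attach via shm.name

    # The owning process must eventually release the block.
    shm.close()
    shm.unlink()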
@@ -3,7 +3,6 @@ from multiprocessing import shared_memory
 import os
 from typing import Callable, Iterable, Optional
 
-import h5py
 import numpy as np
 
 try:
@@ -25,12 +24,13 @@ class SequenceGenerator(BatchGenerator):
         self.prefetch = prefetch and not preload
         self.num_workers = num_workers
         self.pipeline = pipeline
+        self.process_id = 'NA'
 
         if not preload:
             self.data_processor = data_processor
             self.label_processor = label_processor
-            self.data = np.asarray(data)
-            self.label = np.asarray(label)
+            self.data = np.asarray(data, dtype=np.object)
+            self.label = np.asarray(label, dtype=np.object)
         else:
             self.data_processor = None
             self.label_processor = None
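A note on dtype=np.object in this hunk: sequences of different lengths cannot form a rectangular array, so the change forces an object array whose elements are the individual sequences. Be aware that np.object was deprecated in NumPy 1.20 and removed in 1.24; plain object is the drop-in replacement. A small illustration:

    import numpy as np

    # Two sequences of different lengths: without dtype=object this
    # would raise (or warn) because the result is ragged.
    seqs = [np.zeros((5, 3)), np.zeros((7, 3))]
    ragged = np.asarray(seqs, dtype=object)  # on NumPy >= 1.24 use `object`

    print(ragged.shape)     # (2,)
    print(ragged[0].shape)  # (5, 3)
    print(ragged[1].shape)  # (7, 3)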
@@ -42,6 +42,7 @@ class SequenceGenerator(BatchGenerator):
             os.makedirs(os.path.dirname(save_path))
 
         if save and os.path.exists(save_path):
+            import h5py
             with h5py.File(save_path, 'r') as h5_file:
                 data_len = np.asarray(h5_file['data_len'])
                 self.data = []
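Because the import is now deferred, a missing h5py only surfaces at the moment the save option is used, as a bare ImportError from deep inside the loader. A common hardening, not part of this diff and purely hypothetical, wraps the import to fail with an actionable message:

    def _require_h5py():
        # Hypothetical helper: turn a bare ImportError into a hint
        # about the optional dependency.
        try:
            import h5py
        except ImportError as exc:
            raise ImportError(
                'Saving/loading cached data requires h5py; '
                'install it with `pip install h5py`.') from exc
        return h5py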
@@ -49,22 +50,23 @@ class SequenceGenerator(BatchGenerator):
                 for sequence_index in range(data_len):
                     self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
                     self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
-            self.data = np.asarray(self.data)
-            self.label = np.asarray(self.label)
+            self.data = np.asarray(self.data, dtype=np.object)
+            self.label = np.asarray(self.label, dtype=np.object)
         else:
             if data_processor:
                 self.data = np.asarray(
                     [np.asarray([data_processor(entry) for entry in serie]) for serie in data],
                     dtype=np.object if len(data) > 1 else None)
             else:
-                self.data = np.asarray(data)
+                self.data = np.asarray(data, dtype=np.object)
             if label_processor:
                 self.label = np.asarray(
                     [np.asarray([label_processor(entry) for entry in serie]) for serie in label],
                     dtype=np.object if len(label) > 1 else None)
             else:
-                self.label = np.asarray(label)
+                self.label = np.asarray(label, dtype=np.object)
             if save:
+                import h5py
                 with h5py.File(save_path, 'w') as h5_file:
                     h5_file.create_dataset('data_len', data=len(self.data))
                     for sequence_index in range(len(self.data)):
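The save branch above writes one dataset per sequence plus a data_len counter, because a ragged object array cannot be stored as a single HDF5 dataset. A sketch of the same roundtrip outside the class (file name and data are made up):

    import h5py
    import numpy as np

    sequences = [np.random.rand(5, 3), np.random.rand(7, 3)]

    # Write: one dataset per sequence, plus the count needed to read back.
    with h5py.File('cache.h5', 'w') as h5_file:
        h5_file.create_dataset('data_len', data=len(sequences))
        for i, seq in enumerate(sequences):
            h5_file.create_dataset(f'data_{i}', data=seq)

    # Read: mirrors the loading branch in the earlier hunk.
    with h5py.File('cache.h5', 'r') as h5_file:
        n = int(np.asarray(h5_file['data_len']))
        restored = [np.asarray(h5_file[f'data_{i}']) for i in range(n)]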
@@ -133,7 +135,6 @@ class SequenceGenerator(BatchGenerator):
         self.batch_data = first_data
         self.batch_label = first_label
 
-        self.process_id = 'NA'
         if self.prefetch or self.num_workers > 1:
             self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
             self.cache_indices = np.ndarray(