Small fix, making h5py optional

This commit is contained in:
Corentin Risselin 2023-01-04 16:58:48 +09:00
commit 6359258061
2 changed files with 11 additions and 10 deletions

View file

@@ -3,7 +3,6 @@ from multiprocessing import shared_memory
import os
from typing import Callable, Iterable, Optional, Tuple
import h5py
import numpy as np
@@ -20,6 +19,7 @@ class BatchGenerator:
self.num_workers = num_workers
self.flip_data = flip_data
self.pipeline = pipeline
self.process_id = 'NA'
if not preload:
self.data_processor = data_processor
@@ -37,6 +37,7 @@ class BatchGenerator:
os.makedirs(os.path.dirname(save_path))
if save and os.path.exists(save_path):
import h5py
with h5py.File(save_path, 'r') as h5_file:
self.data = np.asarray(h5_file['data'])
self.label = np.asarray(h5_file['label'])
@@ -80,7 +81,6 @@ class BatchGenerator:
self.batch_data = first_data
self.batch_label = first_label
self.process_id = 'NA'
if self.prefetch or self.num_workers > 1:
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
self.cache_indices = np.ndarray(

View file

@@ -3,7 +3,6 @@ from multiprocessing import shared_memory
import os
from typing import Callable, Iterable, Optional
import h5py
import numpy as np
try:
@@ -25,12 +24,13 @@ class SequenceGenerator(BatchGenerator):
self.prefetch = prefetch and not preload
self.num_workers = num_workers
self.pipeline = pipeline
self.process_id = 'NA'
if not preload:
self.data_processor = data_processor
self.label_processor = label_processor
self.data = np.asarray(data)
self.label = np.asarray(label)
self.data = np.asarray(data, dtype=np.object)
self.label = np.asarray(label, dtype=np.object)
else:
self.data_processor = None
self.label_processor = None
@@ -42,6 +42,7 @@ class SequenceGenerator(BatchGenerator):
os.makedirs(os.path.dirname(save_path))
if save and os.path.exists(save_path):
import h5py
with h5py.File(save_path, 'r') as h5_file:
data_len = np.asarray(h5_file['data_len'])
self.data = []
@@ -49,22 +50,23 @@ class SequenceGenerator(BatchGenerator):
for sequence_index in range(data_len):
self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
self.data = np.asarray(self.data)
self.label = np.asarray(self.label)
self.data = np.asarray(self.data, dtype=np.object)
self.label = np.asarray(self.label, dtype=np.object)
else:
if data_processor:
self.data = np.asarray(
[np.asarray([data_processor(entry) for entry in serie]) for serie in data],
dtype=np.object if len(data) > 1 else None)
else:
self.data = np.asarray(data)
self.data = np.asarray(data, dtype=np.object)
if label_processor:
self.label = np.asarray(
[np.asarray([label_processor(entry) for entry in serie]) for serie in label],
dtype=np.object if len(label) > 1 else None)
else:
self.label = np.asarray(label)
self.label = np.asarray(label, dtype=np.object)
if save:
import h5py
with h5py.File(save_path, 'w') as h5_file:
h5_file.create_dataset('data_len', data=len(self.data))
for sequence_index in range(len(self.data)):
@@ -133,7 +135,6 @@ class SequenceGenerator(BatchGenerator):
self.batch_data = first_data
self.batch_label = first_label
self.process_id = 'NA'
if self.prefetch or self.num_workers > 1:
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
self.cache_indices = np.ndarray(