From 63592580613c0fc15f97b727b64b40d3d869db44 Mon Sep 17 00:00:00 2001 From: Corentin Risselin Date: Wed, 4 Jan 2023 16:58:48 +0900 Subject: [PATCH] Small fix, making h5py optional --- utils/batch_generator.py | 4 ++-- utils/sequence_batch_generator.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/utils/batch_generator.py b/utils/batch_generator.py index b8257ef..ad02b55 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -3,7 +3,6 @@ from multiprocessing import shared_memory import os from typing import Callable, Iterable, Optional, Tuple -import h5py import numpy as np @@ -20,6 +19,7 @@ class BatchGenerator: self.num_workers = num_workers self.flip_data = flip_data self.pipeline = pipeline + self.process_id = 'NA' if not preload: self.data_processor = data_processor @@ -37,6 +37,7 @@ class BatchGenerator: os.makedirs(os.path.dirname(save_path)) if save and os.path.exists(save_path): + import h5py with h5py.File(save_path, 'r') as h5_file: self.data = np.asarray(h5_file['data']) self.label = np.asarray(h5_file['label']) @@ -80,7 +81,6 @@ class BatchGenerator: self.batch_data = first_data self.batch_label = first_label - self.process_id = 'NA' if self.prefetch or self.num_workers > 1: self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes) self.cache_indices = np.ndarray( diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index 005caf3..947a0b2 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -3,7 +3,6 @@ from multiprocessing import shared_memory import os from typing import Callable, Iterable, Optional -import h5py import numpy as np try: @@ -25,12 +24,13 @@ class SequenceGenerator(BatchGenerator): self.prefetch = prefetch and not preload self.num_workers = num_workers self.pipeline = pipeline + self.process_id = 'NA' if not preload: self.data_processor = data_processor self.label_processor = label_processor - self.data = np.asarray(data) - self.label = np.asarray(label) + self.data = np.asarray(data, dtype=np.object) + self.label = np.asarray(label, dtype=np.object) else: self.data_processor = None self.label_processor = None @@ -42,6 +42,7 @@ class SequenceGenerator(BatchGenerator): os.makedirs(os.path.dirname(save_path)) if save and os.path.exists(save_path): + import h5py with h5py.File(save_path, 'r') as h5_file: data_len = np.asarray(h5_file['data_len']) self.data = [] @@ -49,22 +50,23 @@ class SequenceGenerator(BatchGenerator): for sequence_index in range(data_len): self.data.append(np.asarray(h5_file[f'data_{sequence_index}'])) self.label.append(np.asarray(h5_file[f'label_{sequence_index}'])) - self.data = np.asarray(self.data) - self.label = np.asarray(self.label) + self.data = np.asarray(self.data, dtype=np.object) + self.label = np.asarray(self.label, dtype=np.object) else: if data_processor: self.data = np.asarray( [np.asarray([data_processor(entry) for entry in serie]) for serie in data], dtype=np.object if len(data) > 1 else None) else: - self.data = np.asarray(data) + self.data = np.asarray(data, dtype=np.object) if label_processor: self.label = np.asarray( [np.asarray([label_processor(entry) for entry in serie]) for serie in label], dtype=np.object if len(label) > 1 else None) else: - self.label = np.asarray(label) + self.label = np.asarray(label, dtype=np.object) if save: + import h5py with h5py.File(save_path, 'w') as h5_file: h5_file.create_dataset('data_len', data=len(self.data)) for sequence_index in range(len(self.data)): @@ -133,7 +135,6 @@ class SequenceGenerator(BatchGenerator): self.batch_data = first_data self.batch_label = first_label - self.process_id = 'NA' if self.prefetch or self.num_workers > 1: self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes) self.cache_indices = np.ndarray(