From 24dea680440c59093ad7723f9efbadf8765c028f Mon Sep 17 00:00:00 2001 From: Corentin Date: Fri, 19 Feb 2021 13:35:51 +0900 Subject: [PATCH] Trainer save source parameter and batch generator fix --- utils/batch_generator.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/utils/batch_generator.py b/utils/batch_generator.py index 91b70e4..a39e396 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -59,6 +59,8 @@ class BatchGenerator: self.last_batch_size = len(self.index_list) % self.batch_size if self.last_batch_size == 0: self.last_batch_size = self.batch_size + else: + self.step_per_epoch += 1 self.epoch = 0 self.global_step = 0 @@ -194,10 +196,8 @@ class BatchGenerator: self.current_cache = 0 parent_cache_data = self.cache_data parent_cache_label = self.cache_label - cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype) - cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype) - self.cache_data = [cache_data] - self.cache_label = [cache_label] + self.cache_data = [np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)] + self.cache_label = [np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)] self.index_list[:] = self.cache_indices pipe = self.worker_pipes[worker_index][1] @@ -208,8 +208,6 @@ class BatchGenerator: continue self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy() - self.cache_data[0] = cache_data[:self.batch_size] - self.cache_label[0] = cache_label[:self.batch_size] self._next_batch() parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[ self.current_cache][:self.batch_size] @@ -338,7 +336,6 @@ if __name__ == '__main__': for _ in range(19): print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) batch_generator.next_batch() - raise KeyboardInterrupt print() test()