Trainer save source parameter and batch generator fix

commit 24dea68044
Author: Corentin
Date:   2021-02-19 13:35:51 +09:00


@@ -59,6 +59,8 @@ class BatchGenerator:
 self.last_batch_size = len(self.index_list) % self.batch_size
 if self.last_batch_size == 0:
     self.last_batch_size = self.batch_size
+else:
+    self.step_per_epoch += 1
 self.epoch = 0
 self.global_step = 0
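The added else branch counts a trailing partial batch as its own step, so step_per_epoch becomes the ceiling of the dataset size divided by the batch size rather than the floor. A minimal sketch of that bookkeeping, assuming step_per_epoch starts from the floor division computed just above this hunk (the standalone helper and its name are illustrative, not part of the repository):

import math

def steps_per_epoch(num_samples, batch_size):
    steps = num_samples // batch_size           # number of full batches
    last_batch_size = num_samples % batch_size
    if last_batch_size == 0:
        last_batch_size = batch_size            # final batch is already a full one
    else:
        steps += 1                              # trailing partial batch adds one step
    assert steps == math.ceil(num_samples / batch_size)
    return steps

print(steps_per_epoch(100, 16))  # 7: six full batches plus one batch of 4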
@@ -194,10 +196,8 @@ class BatchGenerator:
 self.current_cache = 0
 parent_cache_data = self.cache_data
 parent_cache_label = self.cache_label
-cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
-cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
-self.cache_data = [cache_data]
-self.cache_label = [cache_label]
+self.cache_data = [np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)]
+self.cache_label = [np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)]
 self.index_list[:] = self.cache_indices
 pipe = self.worker_pipes[worker_index][1]
@@ -208,8 +208,6 @@
 continue
 self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
-self.cache_data[0] = cache_data[:self.batch_size]
-self.cache_label[0] = cache_label[:self.batch_size]
 self._next_batch()
 parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
     self.current_cache][:self.batch_size]
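Together, these two hunks allocate the generator's scratch cache as fresh ndarrays stored directly in self.cache_data / self.cache_label and drop the per-batch reassignment from the old cache_data / cache_label locals, so _next_batch() writes into the generator's own buffer and only the produced slice is copied back into the parent cache. A rough sketch of that fill-then-copy-back pattern, with hypothetical shapes and a random fill standing in for _next_batch():

import numpy as np

batch_size = 4
cache_batches = 8  # hypothetical cache capacity

# Parent cache owned by the main process; one row per sample.
parent_cache = np.zeros((cache_batches * batch_size, 3), dtype=np.float32)

# Scratch buffer allocated once with the parent cache's shape and dtype,
# instead of re-slicing a separately held local array on every batch.
scratch = np.ndarray(parent_cache.shape, dtype=parent_cache.dtype)

for batch_index in range(0, len(parent_cache), batch_size):
    scratch[:batch_size] = np.random.rand(batch_size, 3)  # stand-in for _next_batch()
    # Copy only the freshly produced slice back into the parent cache.
    parent_cache[batch_index:batch_index + batch_size] = scratch[:batch_size]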
@@ -338,7 +336,6 @@ if __name__ == '__main__':
 for _ in range(19):
     print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
     batch_generator.next_batch()
-raise KeyboardInterrupt
 print()
 test()