Trainer save source parameter and batch generator fix
This commit is contained in:
parent
8ecf175265
commit
24dea68044
1 changed files with 4 additions and 7 deletions
|
|
@ -59,6 +59,8 @@ class BatchGenerator:
|
||||||
self.last_batch_size = len(self.index_list) % self.batch_size
|
self.last_batch_size = len(self.index_list) % self.batch_size
|
||||||
if self.last_batch_size == 0:
|
if self.last_batch_size == 0:
|
||||||
self.last_batch_size = self.batch_size
|
self.last_batch_size = self.batch_size
|
||||||
|
else:
|
||||||
|
self.step_per_epoch += 1
|
||||||
|
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
|
|
@ -194,10 +196,8 @@ class BatchGenerator:
|
||||||
self.current_cache = 0
|
self.current_cache = 0
|
||||||
parent_cache_data = self.cache_data
|
parent_cache_data = self.cache_data
|
||||||
parent_cache_label = self.cache_label
|
parent_cache_label = self.cache_label
|
||||||
cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
|
self.cache_data = [np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)]
|
||||||
cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
|
self.cache_label = [np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)]
|
||||||
self.cache_data = [cache_data]
|
|
||||||
self.cache_label = [cache_label]
|
|
||||||
self.index_list[:] = self.cache_indices
|
self.index_list[:] = self.cache_indices
|
||||||
pipe = self.worker_pipes[worker_index][1]
|
pipe = self.worker_pipes[worker_index][1]
|
||||||
|
|
||||||
|
|
@ -208,8 +208,6 @@ class BatchGenerator:
|
||||||
continue
|
continue
|
||||||
self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
|
self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
|
||||||
|
|
||||||
self.cache_data[0] = cache_data[:self.batch_size]
|
|
||||||
self.cache_label[0] = cache_label[:self.batch_size]
|
|
||||||
self._next_batch()
|
self._next_batch()
|
||||||
parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
|
parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
|
||||||
self.current_cache][:self.batch_size]
|
self.current_cache][:self.batch_size]
|
||||||
|
|
@ -338,7 +336,6 @@ if __name__ == '__main__':
|
||||||
for _ in range(19):
|
for _ in range(19):
|
||||||
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
||||||
batch_generator.next_batch()
|
batch_generator.next_batch()
|
||||||
raise KeyboardInterrupt
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
test()
|
test()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue