Fix batch generator's last batch size

This commit is contained in:
Corentin 2020-12-04 11:28:12 +09:00
commit 379dd4814f
2 changed files with 10 additions and 8 deletions

View file

@ -1,4 +1,3 @@
import math
import multiprocessing as mp
from multiprocessing import shared_memory
import os
@ -56,8 +55,10 @@ class BatchGenerator:
self.index_list = np.arange(len(self.data))
if shuffle or initial_shuffle:
np.random.shuffle(self.index_list)
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
self.step_per_epoch = len(self.index_list) // self.batch_size
self.last_batch_size = len(self.index_list) % self.batch_size
if self.last_batch_size == 0:
self.last_batch_size = self.batch_size
self.epoch = 0
self.global_step = 0
@ -318,9 +319,9 @@ class BatchGenerator:
if __name__ == '__main__':
def test():
data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=np.uint8)
label = np.array(
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19], dtype=np.uint8)
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)
for data_processor in [None, lambda x:x]:
for prefetch in [False, True]:
@ -328,7 +329,7 @@ if __name__ == '__main__':
print(f'{data_processor=} {prefetch=} {num_workers=}')
with BatchGenerator(data, label, 5, data_processor=data_processor,
prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9):
for _ in range(19):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
batch_generator.next_batch()
print()