Fix batch generators last batch size

This commit is contained in:
Corentin 2020-12-04 11:28:12 +09:00
commit 379dd4814f
2 changed files with 10 additions and 8 deletions

View file

@ -1,4 +1,3 @@
import math
import multiprocessing as mp import multiprocessing as mp
from multiprocessing import shared_memory from multiprocessing import shared_memory
import os import os
@ -56,8 +55,10 @@ class BatchGenerator:
self.index_list = np.arange(len(self.data)) self.index_list = np.arange(len(self.data))
if shuffle or initial_shuffle: if shuffle or initial_shuffle:
np.random.shuffle(self.index_list) np.random.shuffle(self.index_list)
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size) self.step_per_epoch = len(self.index_list) // self.batch_size
self.last_batch_size = len(self.index_list) % self.batch_size self.last_batch_size = len(self.index_list) % self.batch_size
if self.last_batch_size == 0:
self.last_batch_size = self.batch_size
self.epoch = 0 self.epoch = 0
self.global_step = 0 self.global_step = 0
@ -318,9 +319,9 @@ class BatchGenerator:
if __name__ == '__main__': if __name__ == '__main__':
def test(): def test():
data = np.array( data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=np.uint8)
label = np.array( label = np.array(
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19], dtype=np.uint8) [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)
for data_processor in [None, lambda x:x]: for data_processor in [None, lambda x:x]:
for prefetch in [False, True]: for prefetch in [False, True]:
@ -328,7 +329,7 @@ if __name__ == '__main__':
print(f'{data_processor=} {prefetch=} {num_workers=}') print(f'{data_processor=} {prefetch=} {num_workers=}')
with BatchGenerator(data, label, 5, data_processor=data_processor, with BatchGenerator(data, label, 5, data_processor=data_processor,
prefetch=prefetch, num_workers=num_workers) as batch_generator: prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9): for _ in range(19):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
batch_generator.next_batch() batch_generator.next_batch()
print() print()

View file

@ -1,4 +1,3 @@
import math
import multiprocessing as mp import multiprocessing as mp
from multiprocessing import shared_memory from multiprocessing import shared_memory
import os import os
@ -81,8 +80,10 @@ class SequenceGenerator(BatchGenerator):
self.index_list = np.concatenate(self.index_list, axis=0) self.index_list = np.concatenate(self.index_list, axis=0)
if shuffle or initial_shuffle: if shuffle or initial_shuffle:
np.random.shuffle(self.index_list) np.random.shuffle(self.index_list)
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size) self.step_per_epoch = len(self.index_list) // self.batch_size
self.last_batch_size = len(self.index_list) % self.batch_size self.last_batch_size = len(self.index_list) % self.batch_size
if self.last_batch_size == 0:
self.last_batch_size = self.batch_size
self.epoch = 0 self.epoch = 0
self.global_step = 0 self.global_step = 0
@ -202,7 +203,7 @@ if __name__ == '__main__':
for prefetch in [False, True]: for prefetch in [False, True]:
for num_workers in [1, 2]: for num_workers in [1, 2]:
print(f'{data_processor=} {prefetch=} {num_workers=}') print(f'{data_processor=} {prefetch=} {num_workers=}')
with SequenceGenerator(data, label, 5, 3, data_processor=data_processor, with SequenceGenerator(data, label, 5, 2, data_processor=data_processor,
prefetch=prefetch, num_workers=num_workers) as batch_generator: prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9): for _ in range(9):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)