Fix batch generators last batch size
This commit is contained in:
parent
40ebd0010d
commit
379dd4814f
2 changed files with 10 additions and 8 deletions
|
|
@ -1,4 +1,3 @@
|
|||
import math
|
||||
import multiprocessing as mp
|
||||
from multiprocessing import shared_memory
|
||||
import os
|
||||
|
|
@ -56,8 +55,10 @@ class BatchGenerator:
|
|||
self.index_list = np.arange(len(self.data))
|
||||
if shuffle or initial_shuffle:
|
||||
np.random.shuffle(self.index_list)
|
||||
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
|
||||
self.step_per_epoch = len(self.index_list) // self.batch_size
|
||||
self.last_batch_size = len(self.index_list) % self.batch_size
|
||||
if self.last_batch_size == 0:
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
self.epoch = 0
|
||||
self.global_step = 0
|
||||
|
|
@ -318,9 +319,9 @@ class BatchGenerator:
|
|||
if __name__ == '__main__':
|
||||
def test():
|
||||
data = np.array(
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8)
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=np.uint8)
|
||||
label = np.array(
|
||||
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19], dtype=np.uint8)
|
||||
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)
|
||||
|
||||
for data_processor in [None, lambda x:x]:
|
||||
for prefetch in [False, True]:
|
||||
|
|
@ -328,7 +329,7 @@ if __name__ == '__main__':
|
|||
print(f'{data_processor=} {prefetch=} {num_workers=}')
|
||||
with BatchGenerator(data, label, 5, data_processor=data_processor,
|
||||
prefetch=prefetch, num_workers=num_workers) as batch_generator:
|
||||
for _ in range(9):
|
||||
for _ in range(19):
|
||||
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
||||
batch_generator.next_batch()
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import math
|
||||
import multiprocessing as mp
|
||||
from multiprocessing import shared_memory
|
||||
import os
|
||||
|
|
@ -81,8 +80,10 @@ class SequenceGenerator(BatchGenerator):
|
|||
self.index_list = np.concatenate(self.index_list, axis=0)
|
||||
if shuffle or initial_shuffle:
|
||||
np.random.shuffle(self.index_list)
|
||||
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
|
||||
self.step_per_epoch = len(self.index_list) // self.batch_size
|
||||
self.last_batch_size = len(self.index_list) % self.batch_size
|
||||
if self.last_batch_size == 0:
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
self.epoch = 0
|
||||
self.global_step = 0
|
||||
|
|
@ -202,7 +203,7 @@ if __name__ == '__main__':
|
|||
for prefetch in [False, True]:
|
||||
for num_workers in [1, 2]:
|
||||
print(f'{data_processor=} {prefetch=} {num_workers=}')
|
||||
with SequenceGenerator(data, label, 5, 3, data_processor=data_processor,
|
||||
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor,
|
||||
prefetch=prefetch, num_workers=num_workers) as batch_generator:
|
||||
for _ in range(9):
|
||||
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue