diff --git a/utils/batch_generator.py b/utils/batch_generator.py index 2edaed6..aeada53 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -1,4 +1,3 @@ -import math import multiprocessing as mp from multiprocessing import shared_memory import os @@ -56,8 +55,10 @@ class BatchGenerator: self.index_list = np.arange(len(self.data)) if shuffle or initial_shuffle: np.random.shuffle(self.index_list) - self.step_per_epoch = math.ceil(len(self.index_list) / batch_size) + self.step_per_epoch = len(self.index_list) // self.batch_size self.last_batch_size = len(self.index_list) % self.batch_size + if self.last_batch_size == 0: + self.last_batch_size = self.batch_size self.epoch = 0 self.global_step = 0 @@ -318,9 +319,9 @@ class BatchGenerator: if __name__ == '__main__': def test(): data = np.array( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=np.uint8) label = np.array( - [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19], dtype=np.uint8) + [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8) for data_processor in [None, lambda x:x]: for prefetch in [False, True]: @@ -328,7 +329,7 @@ if __name__ == '__main__': print(f'{data_processor=} {prefetch=} {num_workers=}') with BatchGenerator(data, label, 5, data_processor=data_processor, prefetch=prefetch, num_workers=num_workers) as batch_generator: - for _ in range(9): + for _ in range(19): print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) batch_generator.next_batch() print() diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index abae375..2a30aab 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -1,4 +1,3 @@ -import math import multiprocessing as mp from multiprocessing import shared_memory import os @@ -81,8 +80,10 @@ class SequenceGenerator(BatchGenerator): self.index_list = np.concatenate(self.index_list, axis=0) if shuffle or initial_shuffle: np.random.shuffle(self.index_list) - self.step_per_epoch = math.ceil(len(self.index_list) / batch_size) + self.step_per_epoch = len(self.index_list) // self.batch_size self.last_batch_size = len(self.index_list) % self.batch_size + if self.last_batch_size == 0: + self.last_batch_size = self.batch_size self.epoch = 0 self.global_step = 0 @@ -202,7 +203,7 @@ if __name__ == '__main__': for prefetch in [False, True]: for num_workers in [1, 2]: print(f'{data_processor=} {prefetch=} {num_workers=}') - with SequenceGenerator(data, label, 5, 3, data_processor=data_processor, + with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, prefetch=prefetch, num_workers=num_workers) as batch_generator: for _ in range(9): print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)