From 164692da9a75744db770836111ae4c63ac6ed7c3 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 31 May 2018 11:00:40 +0800 Subject: [PATCH] drop the last batch, if the size of last batch is not equal to batch_size --- python/paddle/batch.py | 6 ++++-- python/paddle/v2/minibatch.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/paddle/batch.py b/python/paddle/batch.py index 317cf037c..d48c54fcb 100644 --- a/python/paddle/batch.py +++ b/python/paddle/batch.py @@ -15,7 +15,7 @@ __all__ = ['batch'] -def batch(reader, batch_size): +def batch(reader, batch_size, drop_last=False): """ Create a batched reader. @@ -23,6 +23,8 @@ def batch(reader, batch_size): :type reader: callable :param batch_size: size of each mini-batch :type batch_size: int + :param drop_last: drop the last batch, if the size of last batch is not equal to batch_size. + :type drop_last: bool :return: the batched reader. :rtype: callable """ @@ -35,7 +37,7 @@ def batch(reader, batch_size): if len(b) == batch_size: yield b b = [] - if b: + if drop_last == False and len(b) != 0: yield b return batch_reader diff --git a/python/paddle/v2/minibatch.py b/python/paddle/v2/minibatch.py index 317cf037c..d48c54fcb 100644 --- a/python/paddle/v2/minibatch.py +++ b/python/paddle/v2/minibatch.py @@ -15,7 +15,7 @@ __all__ = ['batch'] -def batch(reader, batch_size): +def batch(reader, batch_size, drop_last=False): """ Create a batched reader. @@ -23,6 +23,8 @@ def batch(reader, batch_size): :type reader: callable :param batch_size: size of each mini-batch :type batch_size: int + :param drop_last: drop the last batch, if the size of last batch is not equal to batch_size. + :type drop_last: bool :return: the batched reader. :rtype: callable """ @@ -35,7 +37,7 @@ def batch(reader, batch_size): if len(b) == batch_size: yield b b = [] - if b: + if drop_last == False and len(b) != 0: yield b return batch_reader -- GitLab