未验证 提交 6915d51c 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #11062 from chengduoZH/refine_batch_py

Drop the last batch, if the size of last batch is not equal to batch_size.
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
__all__ = ['batch'] __all__ = ['batch']
def batch(reader, batch_size): def batch(reader, batch_size, drop_last=False):
""" """
Create a batched reader. Create a batched reader.
...@@ -23,6 +23,8 @@ def batch(reader, batch_size): ...@@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable :type reader: callable
:param batch_size: size of each mini-batch :param batch_size: size of each mini-batch
:type batch_size: int :type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader. :return: the batched reader.
:rtype: callable :rtype: callable
""" """
...@@ -35,7 +37,7 @@ def batch(reader, batch_size): ...@@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size: if len(b) == batch_size:
yield b yield b
b = [] b = []
if b: if drop_last == False and len(b) != 0:
yield b yield b
return batch_reader return batch_reader
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
__all__ = ['batch'] __all__ = ['batch']
def batch(reader, batch_size): def batch(reader, batch_size, drop_last=False):
""" """
Create a batched reader. Create a batched reader.
...@@ -23,6 +23,8 @@ def batch(reader, batch_size): ...@@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable :type reader: callable
:param batch_size: size of each mini-batch :param batch_size: size of each mini-batch
:type batch_size: int :type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader. :return: the batched reader.
:rtype: callable :rtype: callable
""" """
...@@ -35,7 +37,7 @@ def batch(reader, batch_size): ...@@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size: if len(b) == batch_size:
yield b yield b
b = [] b = []
if b: if drop_last == False and len(b) != 0:
yield b yield b
return batch_reader return batch_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册