提交 164692da 编写于 作者: C chengduoZH

drop the last batch, if the size of last batch is not equal to batch_size

上级 376c948e
......@@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size):
def batch(reader, batch_size, drop_last=False):
"""
Create a batched reader.
......@@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
......@@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size:
yield b
b = []
if b:
if drop_last == False and len(b) != 0:
yield b
return batch_reader
......@@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size):
def batch(reader, batch_size, drop_last=False):
"""
Create a batched reader.
......@@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
......@@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size:
yield b
b = []
if b:
if drop_last == False and len(b) != 0:
yield b
return batch_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册