未验证 提交 b4dfd080 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #11402 from chengduoZH/refine_batch_func

Change drop_last defalut value
......@@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size, drop_last=False):
def batch(reader, batch_size, drop_last=True):
"""
Create a batched reader.
......
......@@ -96,10 +96,11 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
......
......@@ -73,10 +73,11 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
......
......@@ -87,7 +87,9 @@ def train(use_cuda, train_program, params_dirname):
def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch(
paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)
paddle.dataset.imdb.test(word_dict),
batch_size=BATCH_SIZE,
drop_last=False)
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['words', 'label'])
......@@ -113,7 +115,8 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
trainer.train(
num_epochs=1,
......
......@@ -56,7 +56,7 @@ BATCH_SIZE = 200
# fix the order of training data
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE)
paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE, drop_last=False)
# train_reader = paddle.batch(
# paddle.reader.shuffle(
......
......@@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size, drop_last=False):
def batch(reader, batch_size, drop_last=True):
"""
Create a batched reader.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册