From 0d2d419a598ceb9deb488e5e3d2f1384635ee9be Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 7 Mar 2017 16:25:09 +0800 Subject: [PATCH] Follow comments --- demo/mnist/api_train_v2.py | 2 +- python/paddle/v2/inference.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 81b330a6052..d0aca5eb1d9 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -113,7 +113,7 @@ def main(): test_creator = paddle.dataset.mnist.test() test_data = [] for item in test_creator(): - test_data.append(item[0]) + test_data.append((item[0], )) if len(test_data) == 100: break diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 7c079a0d32d..7d7dc82de98 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -43,10 +43,7 @@ class Inference(object): def __reader_impl__(): for each_sample in input: - if len(reader_dict) == 1: - yield [each_sample] - else: - yield each_sample + yield each_sample reader = minibatch.batch(__reader_impl__, batch_size=batch_size) else: -- GitLab