diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 81b330a60522eaccfa24af5bbed71101aea26f48..d0aca5eb1d94a36a66a2e1c146adb54851baef71 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -113,7 +113,7 @@ def main(): test_creator = paddle.dataset.mnist.test() test_data = [] for item in test_creator(): - test_data.append(item[0]) + test_data.append((item[0], )) if len(test_data) == 100: break diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 7c079a0d32df6f6ae5399b1bba86dcc5db9a19dc..7d7dc82de987cb23d12c411c08e0e529afefe58b 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -43,10 +43,7 @@ class Inference(object): def __reader_impl__(): for each_sample in input: - if len(reader_dict) == 1: - yield [each_sample] - else: - yield each_sample + yield each_sample reader = minibatch.batch(__reader_impl__, batch_size=batch_size) else: