提交 0d2d419a 编写于 作者: Y Yu Yang

Follow comments

上级 d5365bb7
...@@ -113,7 +113,7 @@ def main(): ...@@ -113,7 +113,7 @@ def main():
test_creator = paddle.dataset.mnist.test() test_creator = paddle.dataset.mnist.test()
test_data = [] test_data = []
for item in test_creator(): for item in test_creator():
test_data.append(item[0]) test_data.append((item[0], ))
if len(test_data) == 100: if len(test_data) == 100:
break break
......
...@@ -43,10 +43,7 @@ class Inference(object): ...@@ -43,10 +43,7 @@ class Inference(object):
def __reader_impl__(): def __reader_impl__():
for each_sample in input: for each_sample in input:
if len(reader_dict) == 1: yield each_sample
yield [each_sample]
else:
yield each_sample
reader = minibatch.batch(__reader_impl__, batch_size=batch_size) reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册