From 649ae2700e94233fe58a2fcdc18a8ee59f40f335 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 3 Apr 2018 14:48:14 +0800 Subject: [PATCH] fix bugs --- python/paddle/fluid/layers/io.py | 30 ++++++++++++------- .../tests/unittests/test_recordio_reader.py | 4 +-- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f6bd3c7d0a..fb5bb6bcbc 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -259,17 +259,27 @@ def _copy_reader_var_(block, var): def _copy_reader_create_op_(block, op): - def _find_vars_(block, name_list): - res = {} - for n in name_list: - var = block.var(n) - res[n] = var - return res - - input_map = _find_vars_(block, op.input_names) - output_map = _find_vars_(block, op.output_names) + input_param_names = op.input_names + new_input_map = {} + for param_name in input_param_names: + new_input_map[param_name] = [] + arg_names = op.input(param_name) + for arg_name in arg_names: + new_input_map[param_name].append(block.var(arg_name)) + + output_param_names = op.output_names + new_output_map = {} + for param_name in output_param_names: + new_output_map[param_name] = [] + arg_names = op.output(param_name) + for arg_name in arg_names: + new_output_map[param_name].append(block.var(arg_name)) + new_op = block.append_op( - type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs) + type=op.type, + inputs=new_input_map, + outputs=new_output_map, + attrs=op.attrs) return new_op diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 640264d82f..24a0074d9b 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -15,8 +15,8 @@ import unittest import paddle.fluid as fluid -import paddle -import paddle.dataset.mnist as mnist +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist class TestRecordIO(unittest.TestCase): -- GitLab