提交 649ae270 编写于 作者: F fengjiayi

fix bugs

上级 44d5f42a
...@@ -259,17 +259,27 @@ def _copy_reader_var_(block, var): ...@@ -259,17 +259,27 @@ def _copy_reader_var_(block, var):
def _copy_reader_create_op_(block, op): def _copy_reader_create_op_(block, op):
def _find_vars_(block, name_list): input_param_names = op.input_names
res = {} new_input_map = {}
for n in name_list: for param_name in input_param_names:
var = block.var(n) new_input_map[param_name] = []
res[n] = var arg_names = op.input(param_name)
return res for arg_name in arg_names:
new_input_map[param_name].append(block.var(arg_name))
input_map = _find_vars_(block, op.input_names)
output_map = _find_vars_(block, op.output_names) output_param_names = op.output_names
new_output_map = {}
for param_name in output_param_names:
new_output_map[param_name] = []
arg_names = op.output(param_name)
for arg_name in arg_names:
new_output_map[param_name].append(block.var(arg_name))
new_op = block.append_op( new_op = block.append_op(
type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs) type=op.type,
inputs=new_input_map,
outputs=new_output_map,
attrs=op.attrs)
return new_op return new_op
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle.v2 as paddle
import paddle.dataset.mnist as mnist import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase): class TestRecordIO(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册