提交 d83fe716 编写于 作者: C cyber-pioneer

add input map check

上级 3e85dbb6
......@@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]):
del block.vars[var_name]
block._sync_with_cpp()
# composite ops may contain other ops, thus, call _lower_composite again.
# composite ops may contain other composite ops, thus, call _lower_composite again.
if change:
_lower_composite(block, blacklist)
return
......
......@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type]
if arg_type in ("Tensor", "Tensor[]"):
# assume Tensor type must belong to inputs
if (
"inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys()
......@@ -182,7 +183,9 @@ def _get_args_values(op, phi_name):
"attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys()
):
attrs.append(op.attr(op_content["attrs"][arg_name]))
arg_name = op_content["attrs"][arg_name]
if arg_name not in op.attr_names:
attrs.append(None)
else:
attrs.append(op.attr(arg_name))
......@@ -203,7 +206,12 @@ def prepare_python_api_arguments(op):
else:
phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name)
res = [get_var_block(op.block, op.input(n)) for n in inputs]
res = []
for item in inputs:
if item in op.input_names:
res.append(get_var_block(op.block, op.input(item)))
else:
res.append(None)
if attrs:
res.extend(attrs)
return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册