提交 d83fe716 编写于 作者: C cyber-pioneer

add input map check

上级 3e85dbb6
...@@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]): ...@@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]):
del block.vars[var_name] del block.vars[var_name]
block._sync_with_cpp() block._sync_with_cpp()
# composite ops may contain other ops, thus, call _lower_composite again. # composite ops may contain other composite ops, thus, call _lower_composite again.
if change: if change:
_lower_composite(block, blacklist) _lower_composite(block, blacklist)
return return
......
...@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name): ...@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
arg_type, arg_name = _solve_arg(item) arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type] op_content = op_map[op.type]
if arg_type in ("Tensor", "Tensor[]"): if arg_type in ("Tensor", "Tensor[]"):
# assume Tensor type must belong to inputs
if ( if (
"inputs" in op_content.keys() "inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys() and arg_name in op_content["inputs"].keys()
...@@ -182,7 +183,9 @@ def _get_args_values(op, phi_name): ...@@ -182,7 +183,9 @@ def _get_args_values(op, phi_name):
"attrs" in op_content.keys() "attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys() and arg_name in op_content["attrs"].keys()
): ):
attrs.append(op.attr(op_content["attrs"][arg_name])) arg_name = op_content["attrs"][arg_name]
if arg_name not in op.attr_names:
attrs.append(None)
else: else:
attrs.append(op.attr(arg_name)) attrs.append(op.attr(arg_name))
...@@ -203,7 +206,12 @@ def prepare_python_api_arguments(op): ...@@ -203,7 +206,12 @@ def prepare_python_api_arguments(op):
else: else:
phi_name = op.type phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name) inputs, attrs = _get_args_values(op, phi_name)
res = [get_var_block(op.block, op.input(n)) for n in inputs] res = []
for item in inputs:
if item in op.input_names:
res.append(get_var_block(op.block, op.input(item)))
else:
res.append(None)
if attrs: if attrs:
res.extend(attrs) res.extend(attrs)
return res return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册