diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index a69bce3c37bc43f5c39f8b0ca560c17fecfcafa7..6a5e4ae6fc366ad5197b40cdc487e86ddce005a5 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]): del block.vars[var_name] block._sync_with_cpp() - # composite ops may contain other ops, thus, call _lower_composite again. + # composite ops may contain other composite ops, thus, call _lower_composite again. if change: _lower_composite(block, blacklist) return diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 90bdb78336dc4a257e39c7ae37f3babf2a6cca0f..c011c7495e6ac269ea3eccb3045b3fe21fe292ba 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -169,6 +169,7 @@ def _get_args_values(op, phi_name): arg_type, arg_name = _solve_arg(item) op_content = op_map[op.type] if arg_type in ("Tensor", "Tensor[]"): + # assume Tensor type must belong to inputs if ( "inputs" in op_content.keys() and arg_name in op_content["inputs"].keys() @@ -182,7 +183,9 @@ def _get_args_values(op, phi_name): "attrs" in op_content.keys() and arg_name in op_content["attrs"].keys() ): - attrs.append(op.attr(op_content["attrs"][arg_name])) + arg_name = op_content["attrs"][arg_name] + if arg_name not in op.attr_names: + attrs.append(None) else: attrs.append(op.attr(arg_name)) @@ -203,7 +206,12 @@ def prepare_python_api_arguments(op): else: phi_name = op.type inputs, attrs = _get_args_values(op, phi_name) - res = [get_var_block(op.block, op.input(n)) for n in inputs] + res = [] + for item in inputs: + if item in op.input_names: + res.append(get_var_block(op.block, op.input(item))) + else: + res.append(None) if attrs: res.extend(attrs) return res