From d83fe716042174418a58d01049de6b42f5d33bd9 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Sun, 12 Feb 2023 03:17:28 +0000 Subject: [PATCH] add input map check --- python/paddle/incubate/autograd/primx.py | 2 +- python/paddle/incubate/autograd/utils.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index a69bce3c37..6a5e4ae6fc 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]): del block.vars[var_name] block._sync_with_cpp() - # composite ops may contain other ops, thus, call _lower_composite again. + # composite ops may contain other composite ops, thus, call _lower_composite again. if change: _lower_composite(block, blacklist) return diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 90bdb78336..c011c7495e 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -169,6 +169,7 @@ def _get_args_values(op, phi_name): arg_type, arg_name = _solve_arg(item) op_content = op_map[op.type] if arg_type in ("Tensor", "Tensor[]"): + # assume Tensor type must belong to inputs if ( "inputs" in op_content.keys() and arg_name in op_content["inputs"].keys() @@ -182,7 +183,9 @@ def _get_args_values(op, phi_name): "attrs" in op_content.keys() and arg_name in op_content["attrs"].keys() ): - attrs.append(op.attr(op_content["attrs"][arg_name])) + arg_name = op_content["attrs"][arg_name] + if arg_name not in op.attr_names: + attrs.append(None) else: attrs.append(op.attr(arg_name)) @@ -203,7 +206,12 @@ def prepare_python_api_arguments(op): else: phi_name = op.type inputs, attrs = _get_args_values(op, phi_name) - res = [get_var_block(op.block, op.input(n)) for n in inputs] + res = [] + for item in inputs: + if item in op.input_names: + res.append(get_var_block(op.block, op.input(item))) + else: + res.append(None) if attrs: res.extend(attrs) return res -- GitLab