From 2590527f9efffd64f60017c29c4b132b303651c4 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Thu, 9 Feb 2023 17:16:08 +0000 Subject: [PATCH] fix composite mean op map --- paddle/phi/api/yaml/op_compat.yaml | 8 +++++++- python/paddle/incubate/autograd/utils.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 8734750400..d893484a9f 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1138,8 +1138,14 @@ extra : attrs : [bool use_mkldnn = false] -- op : reduce_mean +- op : mean (reduce_mean) backward : reduce_mean_grad + inputs : + x : X + outputs : + out : Out + attrs : + {axis : dim, keepdim : keep_dim} extra : attrs : [bool use_mkldnn = false] diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 211851160b..70537a3c8b 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -183,7 +183,8 @@ def _get_args_values(op, phi_name): and arg_name in op_content["attrs"].keys() ): attrs.append(op.attr(op_content["attrs"][arg_name])) - attrs.append(op.attr(arg_name)) + else: + attrs.append(op.attr(arg_name)) return inputs, attrs -- GitLab