From 31c2b9dcc720f8e6377682f6e754646e091514fc Mon Sep 17 00:00:00 2001 From: Chen Zhiyang <1792266893@qq.com> Date: Thu, 7 Sep 2023 14:04:22 +0800 Subject: [PATCH] Vjp auto gen StrAttribute bug fixed (#56971) * fix StrAttribute vjp gen bug * add dropout to list * fix bug * fix bug --- .../ir/dialect/op_generator/op_interface_gen.py | 13 ++++++++++++- .../op_generator/vjp_interface_gen_op_list.py | 2 ++ paddle/fluid/primitive/codegen/gen.py | 2 ++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index 5f72a2efbc1..2490335f6c3 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -45,7 +45,7 @@ OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ }}""" OP_VJP_ATTRIBUTE_TEMPLATE = """ - {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" + {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ {attr_type} {attr_name} = {default_value};""" @@ -92,6 +92,10 @@ input_types_map = { 'ir::VectorType': 'Tensor[]', } +attr_data_map = { + 'ir::StrAttribute': 'AsString', +} + def gen_op_vjp_str( op_class_name, @@ -155,10 +159,17 @@ def gen_op_vjp_str( ) ) else: + func = 'data' + if ( + op_grad_info.attribute_type_list[idx] + in attr_data_map.keys() + ): + func = attr_data_map[op_grad_info.attribute_type_list[idx]] attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( attr_type=op_grad_info.attribute_gen_arg_type_list[idx], attr_name=op_attribute_list[idx], attr_parse_type=op_grad_info.attribute_type_list[idx], + func=func, ) else: diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index cb130ae0b23..9707d6fb5f9 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -37,6 +37,7 @@ vjp_interface_declare_gen_op_list = [ "subtract", "pow", "rsqrt", + "dropout", ] vjp_interface_implementation_gen_op_list = [ "tanh", @@ -52,4 +53,5 @@ vjp_interface_implementation_gen_op_list = [ "subtract", "pow", "rsqrt", + "dropout", ] diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 89ba4fe53cd..722ed94953d 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -61,6 +61,7 @@ VJPS = [ 'rsqrt_grad', 'slice_grad', 'transpose_grad', + 'dropout_grad', ] VJP_COMPS = ['divide_grad', 'sum_grad'] BACKENDS = [ @@ -127,6 +128,7 @@ BACKENDS = [ 'roll', 'scatter', 'scatter_nd_add', + 'dropout_grad', ] -- GitLab