未验证 提交 31c2b9dc 编写于 作者: C Chen Zhiyang 提交者: GitHub

Vjp auto gen StrAttribute bug fixed (#56971)

* fix StrAttribute vjp gen bug

* add dropout to list

* fix bug

* fix bug
上级 d78cbee7
...@@ -45,7 +45,7 @@ OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ ...@@ -45,7 +45,7 @@ OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
}}""" }}"""
OP_VJP_ATTRIBUTE_TEMPLATE = """ OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();"""
OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
{attr_type} {attr_name} = {default_value};""" {attr_type} {attr_name} = {default_value};"""
...@@ -92,6 +92,10 @@ input_types_map = { ...@@ -92,6 +92,10 @@ input_types_map = {
'ir::VectorType<paddle::dialect::DenseTensorType>': 'Tensor[]', 'ir::VectorType<paddle::dialect::DenseTensorType>': 'Tensor[]',
} }
attr_data_map = {
'ir::StrAttribute': 'AsString',
}
def gen_op_vjp_str( def gen_op_vjp_str(
op_class_name, op_class_name,
...@@ -155,10 +159,17 @@ def gen_op_vjp_str( ...@@ -155,10 +159,17 @@ def gen_op_vjp_str(
) )
) )
else: else:
func = 'data'
if (
op_grad_info.attribute_type_list[idx]
in attr_data_map.keys()
):
func = attr_data_map[op_grad_info.attribute_type_list[idx]]
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx], attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx], attr_name=op_attribute_list[idx],
attr_parse_type=op_grad_info.attribute_type_list[idx], attr_parse_type=op_grad_info.attribute_type_list[idx],
func=func,
) )
else: else:
......
...@@ -37,6 +37,7 @@ vjp_interface_declare_gen_op_list = [ ...@@ -37,6 +37,7 @@ vjp_interface_declare_gen_op_list = [
"subtract", "subtract",
"pow", "pow",
"rsqrt", "rsqrt",
"dropout",
] ]
vjp_interface_implementation_gen_op_list = [ vjp_interface_implementation_gen_op_list = [
"tanh", "tanh",
...@@ -52,4 +53,5 @@ vjp_interface_implementation_gen_op_list = [ ...@@ -52,4 +53,5 @@ vjp_interface_implementation_gen_op_list = [
"subtract", "subtract",
"pow", "pow",
"rsqrt", "rsqrt",
"dropout",
] ]
...@@ -61,6 +61,7 @@ VJPS = [ ...@@ -61,6 +61,7 @@ VJPS = [
'rsqrt_grad', 'rsqrt_grad',
'slice_grad', 'slice_grad',
'transpose_grad', 'transpose_grad',
'dropout_grad',
] ]
VJP_COMPS = ['divide_grad', 'sum_grad'] VJP_COMPS = ['divide_grad', 'sum_grad']
BACKENDS = [ BACKENDS = [
...@@ -127,6 +128,7 @@ BACKENDS = [ ...@@ -127,6 +128,7 @@ BACKENDS = [
'roll', 'roll',
'scatter', 'scatter',
'scatter_nd_add', 'scatter_nd_add',
'dropout_grad',
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册