未验证 提交 31c2b9dc 编写于 作者: C Chen Zhiyang 提交者: GitHub

Vjp auto gen StrAttribute bug fixed (#56971)

* fix StrAttribute vjp gen bug

* add dropout to list

* fix bug

* fix bug
上级 d78cbee7
......@@ -45,7 +45,7 @@ OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
}}"""
OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();"""
OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
{attr_type} {attr_name} = {default_value};"""
......@@ -92,6 +92,10 @@ input_types_map = {
'ir::VectorType<paddle::dialect::DenseTensorType>': 'Tensor[]',
}
attr_data_map = {
'ir::StrAttribute': 'AsString',
}
def gen_op_vjp_str(
op_class_name,
......@@ -155,10 +159,17 @@ def gen_op_vjp_str(
)
)
else:
func = 'data'
if (
op_grad_info.attribute_type_list[idx]
in attr_data_map.keys()
):
func = attr_data_map[op_grad_info.attribute_type_list[idx]]
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
attr_parse_type=op_grad_info.attribute_type_list[idx],
func=func,
)
else:
......
......@@ -37,6 +37,7 @@ vjp_interface_declare_gen_op_list = [
"subtract",
"pow",
"rsqrt",
"dropout",
]
vjp_interface_implementation_gen_op_list = [
"tanh",
......@@ -52,4 +53,5 @@ vjp_interface_implementation_gen_op_list = [
"subtract",
"pow",
"rsqrt",
"dropout",
]
......@@ -61,6 +61,7 @@ VJPS = [
'rsqrt_grad',
'slice_grad',
'transpose_grad',
'dropout_grad',
]
VJP_COMPS = ['divide_grad', 'sum_grad']
BACKENDS = [
......@@ -127,6 +128,7 @@ BACKENDS = [
'roll',
'scatter',
'scatter_nd_add',
'dropout_grad',
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册