diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index 5f72a2efbc1d6c836644dec6f6e24a8f83476526..2490335f6c3fb0e97523961a5416dd15428d6388 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -45,7 +45,7 @@ OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ }}""" OP_VJP_ATTRIBUTE_TEMPLATE = """ - {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" + {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ {attr_type} {attr_name} = {default_value};""" @@ -92,6 +92,10 @@ input_types_map = { 'ir::VectorType': 'Tensor[]', } +attr_data_map = { + 'ir::StrAttribute': 'AsString', +} + def gen_op_vjp_str( op_class_name, @@ -155,10 +159,17 @@ def gen_op_vjp_str( ) ) else: + func = 'data' + if ( + op_grad_info.attribute_type_list[idx] + in attr_data_map.keys() + ): + func = attr_data_map[op_grad_info.attribute_type_list[idx]] attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( attr_type=op_grad_info.attribute_gen_arg_type_list[idx], attr_name=op_attribute_list[idx], attr_parse_type=op_grad_info.attribute_type_list[idx], + func=func, ) else: diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index cb130ae0b236541916cd2b7eb71b8abfa8ad0ea2..9707d6fb5f9a2090ea2cb0385084b0438128295a 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -37,6 +37,7 @@ vjp_interface_declare_gen_op_list = [ "subtract", "pow", "rsqrt", + "dropout", ] vjp_interface_implementation_gen_op_list = [ "tanh", @@ -52,4 +53,5 @@ vjp_interface_implementation_gen_op_list = [ "subtract", "pow", "rsqrt", + "dropout", ] diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 89ba4fe53cdf0e0b2df7b484f9a23a6488beafc6..722ed94953d19de789243737c97eab7d78661bf4 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -61,6 +61,7 @@ VJPS = [ 'rsqrt_grad', 'slice_grad', 'transpose_grad', + 'dropout_grad', ] VJP_COMPS = ['divide_grad', 'sum_grad'] BACKENDS = [ @@ -127,6 +128,7 @@ BACKENDS = [ 'roll', 'scatter', 'scatter_nd_add', + 'dropout_grad', ]