diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
index 48078e8c432f57a7edac419ecc1068bbf6e61f62..389643cf161139749918131ed529c149402a3f90 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -25,6 +25,15 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
 OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
     {input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""
 
+OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
+    ir::CombineOp combine_op_obj =
+        op_obj.{input_name}().GetDefiningOp()->dyn_cast<ir::CombineOp>();
+    std::vector<Tensor> {input_name};
+    for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
+        {input_name}.emplace_back(
+            std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
+    }}"""
+
 OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
     Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""
 
@@ -44,21 +53,21 @@ OP_VJP_CALL_VJP_TEMPLATE = """  std::vector<std::vector<Tensor>> tensor_res =
 
 OP_VJP_STOPGRADIENT_TEMPLATE = """
   std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
-  for (size_t i = 0; i < tensor_res.size(); ++i) {{
+  for (size_t i = 0; i < tensor_res.size(); ++i) {
     res[i].resize(tensor_res[i].size());
-    for (size_t j = 0; j < tensor_res[i].size(); ++j) {{
-      if(tensor_res[i][j].defined()){{
+    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
+      if(tensor_res[i][j].defined()){
         res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
-      }}
-    }}
-  }}"""
+      }
+    }
+  }"""
 
 OP_VJP_DEFINE_TEMPLATE = """
 std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
   {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
 
   VLOG(6) << "Prepare inputs of {op_grad_name}";
-{forward_input_code}
+{forward_input_output_code}
 {forward_output_grad_code}
 
   VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
@@ -87,7 +96,7 @@ def gen_op_vjp_str(
     op_grad_info,
 ):
     bw_input_list = op_grad_info.input_name_list
-    forward_input_code = ''
+    forward_input_output_code = ''
     forward_output_grad_code = ''
     build_args_str = ''
     grad_idx = -1
@@ -97,14 +106,20 @@ def gen_op_vjp_str(
             bw_input_list[idx] in op_info.input_name_list
             or bw_input_list[idx] in op_info.output_name_list
         ):
-            forward_input_code += (
-                OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
-                    input_type=input_types_map[
-                        op_grad_info.input_type_list[idx]
-                    ],
-                    input_name=bw_input_list[idx],
+            input_type = input_types_map[op_grad_info.input_type_list[idx]]
+            if input_type == 'Tensor':
+                forward_input_output_code += (
+                    OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
+                        input_type=input_type,
+                        input_name=bw_input_list[idx],
+                    )
+                )
+            else:
+                forward_input_output_code += (
+                    OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format(
+                        input_name=bw_input_list[idx],
+                    )
                 )
-            )
         else:
             grad_idx += 1
             forward_output_grad_code += (
@@ -117,21 +132,31 @@ def gen_op_vjp_str(
     for idx in range(len(op_attribute_list)):
         build_args_str += op_attribute_list[idx] + ", "
         if op_attribute_list[idx] in op_info.attribute_name_list:
-            attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
-                attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
-                attr_name=op_attribute_list[idx],
-                attr_parse_type=op_grad_info.attribute_type_list[idx],
-            )
+            if op_attribute_list[idx] in op_info.mutable_attribute_name_list:
+                attribute_code += (
+                    OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
+                        input_type="Tensor",
+                        input_name=op_attribute_list[idx],
+                    )
+                )
+            else:
+                attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
+                    attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
+                    attr_name=op_attribute_list[idx],
+                    attr_parse_type=op_grad_info.attribute_type_list[idx],
+                )
+
         else:
             attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format(
                 attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
                 attr_name=op_attribute_list[idx],
                 default_value=op_grad_info.attribute_default_value_list[idx],
             )
+    op_phi_name_format = op_phi_name
     if op_phi_name[-1] == '_':
-        op_phi_name = op_phi_name[:-1]
+        op_phi_name_format = op_phi_name[:-1]
     call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format(
-        op_phi_name=op_phi_name,
+        op_phi_name=op_phi_name_format,
         inputs_list=build_args_str,
     )
     stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE
@@ -141,7 +166,7 @@ def gen_op_vjp_str(
         op_grad_name=op_grad_name,
         op_phi_name=op_phi_name,
         res_size=len(op_info.input_name_list),
-        forward_input_code=forward_input_code,
+        forward_input_output_code=forward_input_output_code,
         forward_output_grad_code=forward_output_grad_code,
         attribute_code=attribute_code,
         call_vjp_code=call_vjp_code,
diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
index 8077801bf235ff5485003e363200138f4a9a1d4f..56991900b58572233ca708fe9e9de0f87ccca563 100644
--- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -30,4 +30,10 @@ vjp_interface_declare_gen_op_list = [
     "add",
     "concat",
 ]
-vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"]
+vjp_interface_implementation_gen_op_list = [
+    "tanh",
+    "mean",
+    "divide",
+    "add",
+    "concat",
+]
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
index 9f9060e4cf4ca90968fc13efffbc5bfb466fc210..d8b21ed96e9639e5663488a5a65135a9a5089356 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
@@ -27,40 +27,6 @@ namespace paddle {
 namespace dialect {
 using IntArray = paddle::experimental::IntArray;
 
-std::vector<std::vector<ir::OpResult>> ConcatOp::Vjp(
-    ir::Operation* op,
-    const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<bool>>& stop_gradients) {
-  ConcatOp op_obj = op->dyn_cast<ConcatOp>();
-  ir::CombineOp combine_op_obj =
-      op_obj.x().GetDefiningOp()->dyn_cast<ir::CombineOp>();
-  std::vector<Tensor> x;
-  for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
-    x.emplace_back(
-        std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
-  }
-
-  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
-  Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
-
-  std::vector<std::vector<Tensor>> tensor_res =
-      primitive::concat_vjp(x, out_grad, axis, stop_gradients);
-  std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
-                                             std::vector<ir::OpResult>());
-  for (uint64_t i = 0; i < tensor_res.size(); i++) {
-    res[i].resize(tensor_res[i].size());
-    for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
-      if (tensor_res[i][j].defined()) {
-        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
-                        tensor_res[i][j].impl())
-                        ->getValue()
-                        .dyn_cast<ir::OpResult>();
-      }
-    }
-  }
-  return res;
-}
-
 std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,