diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
index 389643cf161139749918131ed529c149402a3f90..5f72a2efbc1d6c836644dec6f6e24a8f83476526 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -38,7 +38,11 @@ OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
     Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));"""
 
 OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
-    std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));"""
+    std::vector {output_grad_name};
+    for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{
+        {output_grad_name}.emplace_back(
+            std::make_shared(out_grads[{index}][idx]));
+    }}"""
 
 OP_VJP_ATTRIBUTE_TEMPLATE = """
     {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""
@@ -47,9 +51,10 @@
 OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
     {attr_type} {attr_name} = {default_value};"""
 
-OP_VJP_CALL_VJP_TEMPLATE = """    std::vector> tensor_res =
-        primitive::{op_phi_name}_vjp(
-        {inputs_list}stop_gradients);"""
+OP_VJP_CALL_VJP_TEMPLATE = """
+    std::vector> tensor_res =
+        primitive::{op_phi_name}_vjp(
+        {inputs_list}stop_gradients);"""
 
 OP_VJP_STOPGRADIENT_TEMPLATE = """
     std::vector> res(tensor_res.size());
@@ -73,10 +78,10 @@ std::vector> {op_class_name}::Vjp(ir::Operation* op, c
     VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
 {attribute_code}
 
-    VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
+    VLOG(6) << "Vjp prepare call {op_phi_name}'s vjp inteface";
 {call_vjp_code}
 
-    VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
+    VLOG(6) << "Vjp prepare stop gradient of {op_grad_name}";
 {stop_gradient_input_grad_code}
     return res;
 }}
@@ -122,11 +127,21 @@ def gen_op_vjp_str(
             )
         else:
             grad_idx += 1
-            forward_output_grad_code += (
-                OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
-                    output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0
+            input_type = input_types_map[op_grad_info.input_type_list[idx]]
+            if input_type == 'Tensor':
+                forward_output_grad_code += (
+                    OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
+                        output_grad_name=bw_input_list[idx],
+                        idx1=grad_idx,
+                        idx2=0,
+                    )
+                )
+            else:
+                forward_input_output_code += (
+                    OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format(
+                        output_grad_name=bw_input_list[idx], index=grad_idx
+                    )
                 )
-            )
     op_attribute_list = op_grad_info.attribute_name_list
     attribute_code = ''
     for idx in range(len(op_attribute_list)):
diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
index aa8d8d1c8e3e8ce737cbf29aeb43d3aa382ecf41..4fc85a07511f6d6640a93fe63088193d6751cc8b 100644
--- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -37,4 +37,5 @@ vjp_interface_implementation_gen_op_list = [
     "divide",
     "add",
     "concat",
+    "split",
 ]
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt
index 02bedd7fd8619d9bbcf850850040fce3a695a9f9..86ade99a3cc228f930a2e17ec606839913ccd7a5 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt
@@ -134,6 +134,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR})
 
 cc_library(
   pd_dialect
-  SRCS pd_dialect.cc pd_op_vjp_manual.cc ${op_vjp_source_file}
+  SRCS pd_dialect.cc pd_manual_op_vjp.cc ${op_vjp_source_file}
   DEPS pd_dialect_api param_to_variable primitive_vjp_experimental
        pd_dialect_utils op_yaml_info_parser)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc
similarity index 65%
rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
rename to paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc
index 9806fb4cf0ce2279bfb649f95f01be93534e01fb..a69a1af650b9822212dda6bc05be79fd3c09db76 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc
@@ -54,38 +54,5 @@ std::vector> SumOp::Vjp(
   return res;
 }
 
-std::vector> SplitOp::Vjp(
-    ir::Operation* op,
-    const std::vector>& out_grads,
-    const std::vector>& stop_gradients) {
-  SplitOp op_obj = op->dyn_cast();
-
-  Tensor axis(std::make_shared(op_obj.axis()));
-  std::vector out_grads_;
-  for (size_t idx = 0; idx < out_grads[0].size(); idx++) {
-    out_grads_.emplace_back(
-        std::make_shared(out_grads[0][idx]));
-  }
-
-  std::vector> tensor_res =
-      primitive::split_vjp(out_grads_, axis, stop_gradients);
-
-  std::vector> res(tensor_res.size(),
-                                          std::vector());
-
-  for (uint64_t i = 0; i < tensor_res.size(); i++) {
-    res[i].resize(tensor_res[i].size());
-    for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
-      if (tensor_res[i][j].defined()) {
-        res[i][j] = std::static_pointer_cast(
-            tensor_res[i][j].impl())
-                        ->getValue()
-                        .dyn_cast();
-      }
-    }
-  }
-  return res;
-}
-
 } // namespace dialect
 } // namespace paddle