diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index 8071edee33607c95df994ef6cf2959e4bdf00a64..fe2edb2b00ea57c75ba1ab4baa0a73f613cc2feb 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -172,7 +172,7 @@ scalar_type_maps = {
     'bool': 'ir::BoolAttribute',
 }
 
-_NO_NEED_GEN_OPS = {'add_n'}
+_NO_NEED_GEN_OPS = {'add_n', 'split_grad'}
 
 
 def to_phi_and_fluid_op_name(op_item):
diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
index 56991900b58572233ca708fe9e9de0f87ccca563..aa8d8d1c8e3e8ce737cbf29aeb43d3aa382ecf41 100644
--- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -29,6 +29,7 @@ vjp_interface_declare_gen_op_list = [
     "sum",
     "add",
     "concat",
+    "split",
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
index ddc117cb22c192a5f3ee5c024f05512f748651b1..19b8b133559b7469117ec83bf9bdca935b6b92c4 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc
@@ -48,7 +48,7 @@ void PaddleDialect::initialize() {
 #define GET_OP_LIST
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"  // NOLINT
       >();
-  RegisterOp<paddle::dialect::AddNOp>();
+  RegisterOps<paddle::dialect::AddNOp, paddle::dialect::SplitGradOp>();
   RegisterInterfaces<ParameterConvertInterface>();
 }
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
index 8866922e4aa34ee74d3f49389a82f31a41a4d9de..070c2e49ac6d5a8a29fe4e5743b420fafa97dccd 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
@@ -18,5 +18,16 @@
 #include "paddle/ir/core/builtin_op.h"
 
 namespace paddle {
-namespace dialect {}  // namespace dialect
+namespace dialect {
+ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
+                        ir::OpResult axis) {
+  auto combine_op =
+      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
+  paddle::dialect::SplitGradOp split_grad_op =
+      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
+          combine_op.out(), axis);
+
+  return split_grad_op.x_grad();
+}
+}  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
index de86758dddba8efe7b353f1679ab4b902c981d90..1d16bc079378829d858140ba8ef286fd55a77bd6 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
@@ -21,5 +21,9 @@
 #include "paddle/phi/common/place.h"
 
 namespace paddle {
-namespace dialect {}  // namespace dialect
+namespace dialect {
+
+ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, ir::OpResult axis);
+
+}  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
index 4a131bbf1dc506561b235699e764a5a78ed1e5cf..64cb1d69b210a7295752fed4fa662d9d04de5344 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
+#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_op.h"
@@ -145,7 +146,221 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
   fn(infer_meta);
 }
 
+const char *SplitGradOp::attributes_name[1] = {"axis"};
+
+OpInfoTuple SplitGradOp::GetOpInfo() {
+  std::vector<paddle::dialect::OpInputInfo> inputs = {
+      OpInputInfo("out_grad",
+                  "ir::VectorType<paddle::dialect::DenseTensorType>",
+                  false,
+                  false,
+                  false),
+      OpInputInfo(
+          "axis", "paddle::dialect::ScalarAttribute", false, false, true)};
+  std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
+  std::vector<paddle::dialect::OpOutputInfo> outputs = {
+      OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)};
+  paddle::dialect::OpRunTimeInfo run_time_info =
+      OpRunTimeInfo("ConcatInferMeta",
+                    {"out_grad", "axis"},
+                    {"concat"},
+                    {"out_grad", "axis"},
+                    {"out_grad"},
+                    {},
+                    {},
+                    {});
+
+  return std::make_tuple(
+      inputs, attributes, outputs, run_time_info, "split_grad");
+}
+
+void SplitGradOp::Build(ir::Builder &builder,
+                        ir::OperationArgument &argument,
+                        ir::OpResult out_grad_,
+                        float axis) {
+  // Generate scalar mutable attribute: axis
+  paddle::dialect::FullOp full_axis_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1}, axis, phi::DataType::FLOAT32, phi::CPUPlace());
+  ir::OpResult axis_ = full_axis_op->result(0);
+
+  VLOG(4) << "Builder construction inputs";
+  std::vector<ir::OpResult> argument_inputs = {out_grad_, axis_};
+  argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
+
+  VLOG(4) << "Builder construction attributes";
+
+  VLOG(4) << "Builder construction outputs";
+  ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
+  std::vector<phi::DenseTensor> vec_dense_out_grad;
+  for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
+    vec_dense_out_grad.push_back(phi::DenseTensor(
+        std::make_unique<paddle::experimental::DefaultAllocator>(
+            paddle::platform::CPUPlace())
+            .get(),
+        phi::DenseTensorMeta(
+            paddle::dialect::TransToPhiDataType(
+                out_grad[i]
+                    .dyn_cast<paddle::dialect::DenseTensorType>()
+                    .dtype()),
+            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
+            out_grad[i]
+                .dyn_cast<paddle::dialect::DenseTensorType>()
+                .data_layout(),
+            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
+            out_grad[i]
+                .dyn_cast<paddle::dialect::DenseTensorType>()
+                .offset())));
+  }
+  std::vector<phi::MetaTensor> vec_meta_out_grad;
+  for (size_t i = 0; i < vec_dense_out_grad.size(); i++) {
+    vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i]));
+  }
+
+  std::vector<const phi::MetaTensor *> meta_out_grad;
+  for (size_t i = 0; i < static_cast<size_t>(vec_meta_out_grad.size()); i++) {
+    meta_out_grad.push_back(&vec_meta_out_grad[i]);
+  }
+  phi::DenseTensor dense_x_grad;
+  phi::MetaTensor meta_x_grad(&dense_x_grad);
+
+  phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad);
+
+  std::vector<ir::Type> argument_outputs;
+  ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
+      ir::IrContext::Instance(),
+      paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
+      dense_x_grad.dims(),
+      dense_x_grad.layout(),
+      dense_x_grad.lod(),
+      dense_x_grad.offset());
+  argument_outputs.push_back(x_grad_dense_tensor_type);
+  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+}
+
+void SplitGradOp::Build(ir::Builder &builder,
+                        ir::OperationArgument &argument,
+                        ir::OpResult out_grad_,
+                        ir::OpResult axis_) {
+  VLOG(4) << "Builder construction inputs";
+  std::vector<ir::OpResult> argument_inputs = {out_grad_, axis_};
+  argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
+
+  VLOG(4) << "Builder construction attributes";
+
+  VLOG(4) << "Builder construction outputs";
+  ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
+  int axis = axis_.owner()
+                 ->dyn_cast<paddle::dialect::FullOp>()
+                 .attributes()
+                 .at("value")
+                 .dyn_cast<paddle::dialect::ScalarAttribute>()
+                 .data()
+                 .to<int>();
+
+  std::vector<phi::DenseTensor> vec_dense_out_grad;
+  for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
+    vec_dense_out_grad.push_back(phi::DenseTensor(
+        std::make_unique<paddle::experimental::DefaultAllocator>(
+            paddle::platform::CPUPlace())
+            .get(),
+        phi::DenseTensorMeta(
+            TransToPhiDataType(out_grad[i]
+                                   .dyn_cast<paddle::dialect::DenseTensorType>()
+                                   .dtype()),
+            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
+            out_grad[i]
+                .dyn_cast<paddle::dialect::DenseTensorType>()
+                .data_layout(),
+            out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
+            out_grad[i]
+                .dyn_cast<paddle::dialect::DenseTensorType>()
+                .offset())));
+  }
+  std::vector<phi::MetaTensor> vec_meta_out_grad;
+  for (size_t i = 0; i < vec_dense_out_grad.size(); i++) {
+    vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i]));
+  }
+
+  std::vector<const phi::MetaTensor *> meta_out_grad;
+  for (size_t i = 0; i < static_cast<size_t>(vec_meta_out_grad.size()); i++) {
+    meta_out_grad.push_back(&vec_meta_out_grad[i]);
+  }
+  phi::DenseTensor dense_x_grad;
+  phi::MetaTensor meta_x_grad(&dense_x_grad);
+
+  phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad);
+
+  std::vector<ir::Type> argument_outputs;
+  ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
+      ir::IrContext::Instance(),
+      TransToIrDataType(dense_x_grad.dtype()),
+      dense_x_grad.dims(),
+      dense_x_grad.layout(),
+      dense_x_grad.lod(),
+      dense_x_grad.offset());
+  argument_outputs.push_back(x_grad_dense_tensor_type);
+  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+}
+
+void SplitGradOp::Verify() {
+  VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
+  VLOG(4) << "Verifying inputs:";
+  {
+    auto input_size = num_operands();
+    PADDLE_ENFORCE_EQ(
+        input_size,
+        2u,
+        phi::errors::PreconditionNotMet(
+            "The size %d of inputs must be equal to 2.", input_size));
+    if (auto vec_type =
+            (*this)->operand_source(0).type().dyn_cast<ir::VectorType>()) {
+      for (size_t i = 0; i < vec_type.size(); ++i) {
+        PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>(),
+                       phi::errors::PreconditionNotMet(
+                           "Type validation failed for the 0th input."));
+      }
+    } else {
+      PADDLE_ENFORCE((*this)
+                         ->operand_source(0)
+                         .type()
+                         .isa<paddle::dialect::DenseTensorType>(),
+                     phi::errors::PreconditionNotMet(
+                         "Type validation failed for the 0th input."));
+    }
+    PADDLE_ENFORCE((*this)
+                       ->operand_source(1)
+                       .type()
+                       .isa<paddle::dialect::DenseTensorType>(),
+                   phi::errors::PreconditionNotMet(
+                       "Type validation failed for the 1th input."));
+  }
+  VLOG(4) << "Verifying attributes:";
+  {
+    // Attributes num is 0, not need to check attributes type.
+  }
+  VLOG(4) << "Verifying outputs:";
+  {
+    auto output_size = num_results();
+    PADDLE_ENFORCE_EQ(
+        output_size,
+        1u,
+        phi::errors::PreconditionNotMet(
+            "The size %d of outputs must be equal to 1.", output_size));
+    PADDLE_ENFORCE(
+        (*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
+        phi::errors::PreconditionNotMet(
+            "Type validation failed for the 0th output."));
+  }
+  VLOG(4) << "End Verifying for: SplitGradOp.";
+}
+
+void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
+  auto fn = PD_INFER_META(phi::ConcatInferMeta);
+  fn(infer_meta);
+}
+
 }  // namespace dialect
 }  // namespace paddle
 
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
index e3814a535ccf0c66d199739bcb6a7fa9347826d7..fe9beb46012edf06f8601fa5a0ef2178ed8ac0ff 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
@@ -14,7 +14,7 @@
 
 #ifdef GET_MANUAL_OP_LIST
 #undef GET_MANUAL_OP_LIST
-paddle::dialect::AddNOp
+paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
 
 #else
 
@@ -51,9 +51,33 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
   static void InferMeta(phi::InferMetaContext *infer_meta);
 };
 
+class SplitGradOp : public ir::Op<SplitGradOp,
+                                  OpYamlInfoInterface,
+                                  InferMetaInterface> {
+ public:
+  using Op::Op;
+  static const char *name() { return "pd.split_grad"; }
+  static const char *attributes_name[1];
+  static constexpr uint32_t attributes_num = 1;
+  static OpInfoTuple GetOpInfo();
+  static void Build(ir::Builder &builder,             // NOLINT
+                    ir::OperationArgument &argument,  // NOLINT
+                    ir::OpResult x_,
+                    float axis = 0);
+  static void Build(ir::Builder &builder,             // NOLINT
+                    ir::OperationArgument &argument,  // NOLINT
+                    ir::OpResult out_grad_,
+                    ir::OpResult axis_);
+
+  void Verify();
+  ir::Value out_grad() { return operand_source(0); }
+  ir::Value axis() { return operand_source(1); }
+  ir::OpResult x_grad() { return result(0); }
+  static void InferMeta(phi::InferMetaContext *infer_meta);
+};
+
 }  // namespace dialect
 }  // namespace paddle
 
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
 
 #endif
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
index d8b21ed96e9639e5663488a5a65135a9a5089356..9806fb4cf0ce2279bfb649f95f01be93534e01fb 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op_vjp_manual.cc
@@ -53,5 +53,39 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
   }
   return res;
 }
+
+std::vector<std::vector<ir::OpResult>> SplitOp::Vjp(
+    ir::Operation* op,
+    const std::vector<std::vector<ir::OpResult>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  SplitOp op_obj = op->dyn_cast<SplitOp>();
+
+  Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
+  std::vector<Tensor> out_grads_;
+  for (size_t idx = 0; idx < out_grads[0].size(); idx++) {
+    out_grads_.emplace_back(
+        std::make_shared<primitive::LazyTensor>(out_grads[0][idx]));
+  }
+
+  std::vector<std::vector<Tensor>> tensor_res =
+      primitive::split_vjp(out_grads_, axis, stop_gradients);
+
+  std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
+                                             std::vector<ir::OpResult>());
+
+  for (uint64_t i = 0; i < tensor_res.size(); i++) {
+    res[i].resize(tensor_res[i].size());
+    for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
+      if (tensor_res[i][j].defined()) {
+        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
+                        tensor_res[i][j].impl())
+                        ->getValue()
+                        .dyn_cast<ir::OpResult>();
+      }
+    }
+  }
+  return res;
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/manual/manual_backend.h b/paddle/fluid/primitive/backend/manual/manual_backend.h
index fba6c7b3bbee4d91bd1912d104c7fd5de71004b9..fe009bd8fecbe1431f2e1ff9ce32307c49b37ba1 100644
--- a/paddle/fluid/primitive/backend/manual/manual_backend.h
+++ b/paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -33,6 +33,9 @@ std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
                                 const Tensor& out_grad,
                                 const Tensor& axis);
 
+template <typename T>
+Tensor split_grad(const std::vector<Tensor>& out_grads, const Tensor& axis);
+
 }  // namespace backend
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
index 466258a73fd6fb993c37b58e937012c744405e9d..53a231d4cf9e361dd8c2574531b7dd8b122e3408 100644
--- a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
+++ b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -54,6 +54,23 @@ std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
   return op_result;
 }
 
+template <>
+Tensor split_grad<LazyTensor>(const std::vector<Tensor>& out_grads,
+                              const Tensor& axis) {
+  std::vector<ir::OpResult> out_grads_res;
+  for (uint64_t idx = 0; idx < out_grads.size(); idx++) {
+    out_grads_res.emplace_back(
+        std::static_pointer_cast<LazyTensor>(out_grads[idx].impl())
+            ->getValue()
+            .dyn_cast<ir::OpResult>());
+  }
+  ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
+                              ->getValue()
+                              .dyn_cast<ir::OpResult>();
+  ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res);
+  return Tensor(std::make_shared<LazyTensor>(op_res));
+}
+
 }  // namespace backend
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
index f30a5984417c648f7f27fbde8807c672a5e5026b..86d83dbee249d135793d4d7beae15ed893f20d64 100644
--- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
+++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -48,5 +48,25 @@ std::vector<std::vector<paddle::Tensor>> concat_vjp(
   return vjp_res;
 }
 
+std::vector<std::vector<paddle::Tensor>> split_vjp(
+    const std::vector<Tensor>& out_grads,
+    const Tensor& axis,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  std::vector<std::vector<Tensor>> vjp_res(3, std::vector<Tensor>(1));
+  // get concat_grad res.
+  Tensor op_res = backend::split_grad<LazyTensor>(out_grads, axis);
+
+  // construct vjp result by op result and stop_gradients info
+  if (!stop_gradients[0][0]) {
+    vjp_res[0][0] = op_res;
+  }
+
+  // vjp_res[1] is sections's grad which is attribute (no grad).
+  // vjp_res[2] is axis's grad which is attribute (no grad).
+  vjp_res[1].resize(stop_gradients[1].size());
+  vjp_res[2].resize(stop_gradients[2].size());
+  return vjp_res;
+}
+
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
index 49510b80d452635cf8e1f8660889a744731df578..87e1f33bb9ebdd50d365f341b7227aa6a9a7d99c 100644
--- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
+++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
@@ -30,5 +30,10 @@ std::vector<std::vector<paddle::Tensor>> concat_vjp(
     const Tensor& axis,
     const std::vector<std::vector<bool>>& stop_gradients);
 
+std::vector<std::vector<paddle::Tensor>> split_vjp(
+    const std::vector<Tensor>& out_grads,
+    const Tensor& axis,
+    const std::vector<std::vector<bool>>& stop_gradients);
+
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc
index 506b932870bf14ead747c80780145f98e2af20b4..9efe49a97c8c9e6b4764586a963aa269959f0142 100644
--- a/paddle/fluid/pybind/ops_api.cc
+++ b/paddle/fluid/pybind/ops_api.cc
@@ -52,6 +52,10 @@ static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) {
   return static_api_concat(self, args, kwargs);
 }
 
+static PyObject *split(PyObject *self, PyObject *args, PyObject *kwargs) {
+  return static_api_split(self, args, kwargs);
+}
+
 static PyMethodDef OpsAPI[] = {{"add_n",
                                 (PyCFunction)(void (*)(void))add_n,
                                 METH_VARARGS | METH_KEYWORDS,
                                 "C++ interface function for add_n."},
@@ -76,6 +80,10 @@ static PyMethodDef OpsAPI[] = {{"add_n",
                                 (PyCFunction)(void (*)(void))full,
                                 METH_VARARGS | METH_KEYWORDS,
                                 "C++ interface function for full."},
+                               {"split",
+                                (PyCFunction)(void (*)(void))split,
+                                METH_VARARGS | METH_KEYWORDS,
+                                "C++ interface function for split."},
                                {"data",
                                 (PyCFunction)(void (*)(void))data,
                                 METH_VARARGS | METH_KEYWORDS,
diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py
index 13a5acd46a91758761b007d51ced684b70fc6ce0..c227577d0606cfdad90bba97931c446e58533384 100644
--- a/python/paddle/autograd/backward.py
+++ b/python/paddle/autograd/backward.py
@@ -201,7 +201,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
                 outputs_set.add(operand)
             else:
                 relevant_op_flags[i] = False
-
     # recover full op or full_Intarray op created by mutable attribute.
     total_ops_list = list(total_ops)
     for i, op in enumerate(total_ops_list):
@@ -354,12 +353,16 @@ def append_backward_ops(
 
     def make_output_grad(op):
         zero_flag = [False] * op.num_results()
+        output_grads = []
         for i, value in enumerate(op.results()):
             if (
                 value not in state.value_to_valuegrad
                 or state.value_to_valuegrad[value] is None
             ):
-                if value.first_use().owner().name() == "builtin.split":
+                if (
+                    not value.use_empty()
+                    and value.first_use().owner().name() == "builtin.split"
+                ):
                     # pattern case:
                     # this fwd_op's output is vectorType, it will split to
                     # Type by builtin.split op, so need get from split op's ouput
@@ -367,7 +370,7 @@ def append_backward_ops(
                         value.first_use().owner()
                     )
                     zero_flag[i] = all(split_zero_flag)
-                    grad_value = [op_list[0] for op_list in split_output_grad]
+                    state.value_to_valuegrad[value] = [split_output_grad]
                 else:
                     # first case:
                     # this fwd_op's output didn't used by other fwd_op,
@@ -388,7 +391,7 @@ def append_backward_ops(
                     )
                     zero_flag[i] = True
 
-                state.value_to_valuegrad[value] = [[grad_value]]
+                    state.value_to_valuegrad[value] = [[grad_value]]
 
             if len(state.value_to_valuegrad[value]) > 1:
                 # one value is input of more than one fwd_op,
@@ -411,8 +414,8 @@ def append_backward_ops(
                     value
                 ]
 
-            output_grad = state.value_to_valuegrad[value][0]
-            return zero_flag, output_grad
+            output_grads.append(state.value_to_valuegrad[value][0][0])
+        return zero_flag, output_grads
 
     def make_input_stopgradient(op):
         input_grad_stopgradients = []
@@ -529,8 +532,9 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
            inputs_set.add(state.value_to_valuegrad[output][0][0])
    inputs_set_tmp = set()
    for out_grad in inputs_set:
-        for item in out_grad.first_use().owner().operands_source():
-            inputs_set_tmp.add(item)
+        if not out_grad.use_empty():
+            for item in out_grad.first_use().owner().operands_source():
+                inputs_set_tmp.add(item)
    inputs_set.update(inputs_set_tmp)
    no_gradvar_set = set()  # grad_value of value in no_grad_set
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 4a4e1dedf4a13dd5f561bb6aa785f7af8bbcad75..59a2f6f0cc4131f3e9e4e93a669f5e34e7dfefa4 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1976,6 +1976,14 @@ def split(x, num_or_sections, axis=0, name=None):
         else:
             return _C_ops.split(input, num_or_sections, dim)
     else:
+        if paddle.ir.core._use_new_ir_api():
+            if not isinstance(num_or_sections, int):
+                return paddle._ir_ops.split(input, num_or_sections, dim)
+            else:
+                raise NotImplementedError(
+                    "_ir_ops.split_with_num is not implemented, please change sections as list"
+                )
+
         check_variable_and_dtype(
             input,
             'input',
diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc
index 667690472c2706c27e8d029fb02bc51f3df1b35e..496bb1de1891ab1671a3402ca3d8637f57a7df1f 100644
--- a/test/cpp/prim/test_vjp.cc
+++ b/test/cpp/prim/test_vjp.cc
@@ -402,5 +402,74 @@ TEST(VJP, Add_BackwardTest) {
   ASSERT_EQ(dx.data<float>()[0], 1.0);
   ASSERT_EQ(dy.data<float>()[0], 1.0);
 }
+
+TEST(VJP, SplitBackwardTest) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
+
+  std::shared_ptr<ir::Builder> builder =
+      paddle::dialect::APIBuilder::Instance().GetBuilder();
+  paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{2, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::SplitOp op2 = builder->Build<paddle::dialect::SplitOp>(
+      op1.out(), std::vector<int64_t>{1, 1}, 0);
+
+  ir::SplitOp op3 = builder->Build<ir::SplitOp>(op2.out());
+
+  paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  std::vector<std::vector<bool>> stop_gradients{{false}, {true}, {true}};
+  std::vector<std::vector<ir::OpResult>> out_grads{{op3.result(0), op4.out()}};
+  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.split");
+
+  auto concat_vjp_interface_impl =
+      op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
+
+  concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+  ProgramDesc prog_desc;
+  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
+  std::stringstream os;
+  os << reinterpret_cast<NewIRInterpreter*>(
+      const_cast<InterpreterBaseImpl*>(test_core.Impl()));
+  std::string prefix_str = os.str();
+  test_core.SetSkipGcVars({prefix_str + "_inner_var_4",
+                           prefix_str + "_inner_var_5",
+                           prefix_str + "_inner_var_8"});
+  test_core.Run({});
+  auto out_tensor_0 =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_4")
+                ->Get<phi::DenseTensor>();
+  auto out_tensor_1 =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_5")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_5")
+                ->Get<phi::DenseTensor>();
+  auto grad_out_tensor_0 =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar(prefix_str + "_inner_var_8")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar(prefix_str + "_inner_var_8")
+                ->Get<phi::DenseTensor>();
+  ASSERT_EQ(out_tensor_0.data<float>()[0], 2.0);
+  ASSERT_EQ(out_tensor_0.data<float>()[1], 2.0);
+  ASSERT_EQ(out_tensor_1.data<float>()[0], 2.0);
+  ASSERT_EQ(out_tensor_1.data<float>()[1], 2.0);
+  ASSERT_EQ(grad_out_tensor_0.data<float>()[0], 2.0);
+  ASSERT_EQ(grad_out_tensor_0.data<float>()[1], 2.0);
+  ASSERT_EQ(grad_out_tensor_0.data<float>()[2], 1.0);
+  ASSERT_EQ(grad_out_tensor_0.data<float>()[3], 1.0);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py
index e54e493b99a773de9fa5c3fe3bc3f6c070472691..16bc1adb0628ed95e1b4b4b42e1ac44173ea1f91 100644
--- a/test/ir/new_ir/test_build_op.py
+++ b/test/ir/new_ir/test_build_op.py
@@ -121,5 +121,25 @@ class TestBuildOp4(unittest.TestCase):
         paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
 
 
+class TestBuildOp5(unittest.TestCase):
+    def test_build_split_op(self):
+        newir_program = get_ir_program()
+        tanh_out = newir_program.block().ops[-1].result(0)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
+        with paddle.ir.core.program_guard(newir_program):
+            out = paddle.split(tanh_out, [2, 2], 0)
+        self.assertEqual(out[0].get_defining_op().name(), "builtin.split")
+        self.assertEqual(
+            out[0]
+            .get_defining_op()
+            .operands()[0]
+            .source()
+            .get_defining_op()
+            .name(),
+            "pd.split",
+        )
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py
index 7adbbca86cf6116af7052602678bf89bb2259c5e..be29baa1069d2f0b4a60632bf010ff7cdfb9bb13 100644
--- a/test/ir/new_ir/test_ir_backward.py
+++ b/test/ir/new_ir/test_ir_backward.py
@@ -94,6 +94,32 @@ class TesBackward_1(unittest.TestCase):
         self.assertEqual(newir_program.block().ops[-1].name(), "pd.mean")
         paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
 
+    def test_split(self):
+        # test create output_grad in backward use full op
+        newir_program = get_ir_program_0()
+        input = newir_program.block().ops[-1].operand(0).source()
+        tanh_out = newir_program.block().ops[-1].result(0)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
+        with paddle.ir.core.program_guard(newir_program):
+            out = paddle.split(tanh_out, [2, 2], 0)
+            input_grad = grad(out, input)
+
+        ops_name = [
+            "pd.data",
+            "pd.tanh",
+            "pd.full_int_array",
+            "pd.full",
+            "pd.split",
+            "builtin.split",
+            "pd.full",
+            "builtin.combine",
+            "pd.split_grad",
+            "pd.tanh_grad",
+        ]
+        for i, op in enumerate(newir_program.block().ops):
+            self.assertEqual(op.name(), ops_name[i])
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
+
 
 def get_ir_program_1():
     x = paddle.randn([2, 2])