diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 6e313267059666bf250b8ee6653e79a5bf8477d0..8180723b3c85e9a163621f836e8269e17663e7f9 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -29,6 +29,7 @@ H_FILE_TEMPLATE = """ #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" +#include "paddle/fluid/ir/dialect/pd_manual_api.h" {body} diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index fd7d61897d858533a0b8ed49d3f65fe6a2e66665..8077801bf235ff5485003e363200138f4a9a1d4f 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -21,5 +21,13 @@ # TODO(wanghao107) # remove this file and support Vjp methods # code gen. -vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"] + +vjp_interface_declare_gen_op_list = [ + "tanh", + "mean", + "divide", + "sum", + "add", + "concat", +] vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"] diff --git a/paddle/fluid/ir/dialect/pd_manual_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..10e25ea883b91445c8ab2c24d8b916e8e3a825ac --- /dev/null +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/ir/dialect/pd_manual_api.h" +#include "paddle/fluid/ir/dialect/pd_dialect.h" +#include "paddle/fluid/ir/dialect/pd_op.h" +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_op.h" + +namespace paddle { +namespace dialect { +std::vector concat_grad(std::vector x, + ir::OpResult out_grad, + ir::OpResult axis) { + auto combine_op = + APIBuilder::Instance().GetBuilder()->Build(x); + paddle::dialect::ConcatGradOp concat_grad_op = + APIBuilder::Instance().GetBuilder()->Build( + combine_op.out(), out_grad, axis); + auto split_op = APIBuilder::Instance().GetBuilder()->Build( + concat_grad_op.result(0)); + return split_op.outputs(); +} +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_manual_api.h b/paddle/fluid/ir/dialect/pd_manual_api.h new file mode 100644 index 0000000000000000000000000000000000000000..dff38ef565cb2d8c479970fb61264baaa67510fc --- /dev/null +++ b/paddle/fluid/ir/dialect/pd_manual_api.h @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/ir/core/value.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" + +namespace paddle { +namespace dialect { + +std::vector concat_grad(std::vector x, + ir::OpResult out_grad, + ir::OpResult axis); +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index b41cbdab51991ea12d63735ca8388652470d190b..a68d0ee505816bc51d83d715820e42648677384c 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/op_base.h" #include "paddle/phi/common/int_array.h" @@ -24,6 +25,42 @@ namespace paddle { namespace dialect { +using IntArray = paddle::experimental::IntArray; + +std::vector> ConcatOp::Vjp( + ir::Operation* op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + ConcatOp op_obj = op->dyn_cast(); + ir::CombineOp combine_op_obj = + op_obj.x().GetDefiningOp()->dyn_cast(); + std::vector x; + for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + x.emplace_back( + std::make_shared(combine_op_obj.inputs()[idx])); + } + + Tensor out_grad(std::make_shared(out_grads[0][0])); + Tensor axis(std::make_shared(op_obj.axis())); + + std::vector> tensor_res = + primitive::concat_vjp(x, out_grad, axis, stop_gradients); + std::vector> res(tensor_res.size(), + std::vector()); + for (uint64_t i = 0; i < tensor_res.size(); i++) { + res[i].resize(tensor_res[i].size()); + for (uint64_t j = 0; j < tensor_res[i].size(); j++) { + if (tensor_res[i][j].defined()) { + res[i][j] = std::static_pointer_cast( + tensor_res[i][j].impl()) + ->getValue() + .dyn_cast(); + } + } + } + return res; +} + std::vector> SumOp::Vjp( ir::Operation* op, const std::vector>& out_grads, diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 539237c22775cebf189f6cc7fe0dcbb4dd655baf..3bb9c616a781999d672d3d2f927275f61e606d41 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -242,6 +242,38 @@ Tensor sum_grad(const Tensor& x, x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); return Tensor(std::make_shared(op_res)); } + +template <> +std::vector concat_grad(const std::vector& x, + const Tensor& out_grad, + const Tensor& axis) { + std::vector x_res; + for (uint64_t idx = 0; idx < x.size(); idx++) { + x_res.emplace_back(std::static_pointer_cast(x[idx].impl()) + ->getValue() + .dyn_cast()); + } + + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + + ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) + ->getValue() + .dyn_cast(); + + std::vector op_res = + paddle::dialect::concat_grad(x_res, out_grad_res, axis_res); + + std::vector op_result; + for 
(uint64_t idx = 0; idx < op_res.size(); idx++) { + op_result.emplace_back( + std::make_shared(op_res[idx])); + } + return op_result; +} + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 1e484aa35e67603f198fa10001b760efd9a4ba9c..ba608a01eeab367731d4774647879df2e11112ac 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -37,6 +37,11 @@ Tensor mean_grad(const Tensor& x, bool keepdim = false, bool reduce_all = false); +template +std::vector concat_grad(const std::vector& x, + const Tensor& out_grad, + const Tensor& axis); + template std::tuple add_grad(const Tensor& x, const Tensor& y, diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 59fabfc87cfa23fa231405bf5ea04f36f9b94a69..364c590a6edbb251d9a75f98b6171e1b8a75b5ff 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -148,6 +148,28 @@ std::vector> add_vjp( return vjp_res; } +std::vector> concat_vjp( + const std::vector& x, + const Tensor& out_grad, + const Tensor& axis, + const std::vector>& stop_gradients) { + std::vector> vjp_res(2, std::vector()); + // get concat_grad res. + std::vector op_res = + backend::concat_grad(x, out_grad, axis); + + // construct vjp result by op result and stop_gradients info + vjp_res[0].resize(op_res.size()); + for (uint64_t idx = 0; idx < op_res.size(); idx++) { + if (!stop_gradients[0][idx]) { + vjp_res[0][idx] = op_res[idx]; + } + } + // vjp_res[1] is axis's grad which is attribute (no grad). + vjp_res[1].resize(1); + return vjp_res; +} + std::vector> divide_vjp( const Tensor& x, const Tensor& y, diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h index 94b1f9d67ccb6a13b3d6c51f722149110947bce9..eace3d3cb5bdf22778554d24f8e7dcff1762e4da 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.h +++ b/paddle/fluid/primitive/rule/vjp/vjp.h @@ -38,6 +38,12 @@ std::vector> mean_vjp( bool reduce_all, const std::vector>& stop_gradients); +std::vector> concat_vjp( + const std::vector& x, + const Tensor& out_grad, + const Tensor& axis, + const std::vector>& stop_gradients); + std::vector> add_vjp( const Tensor& x, const Tensor& y, diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index a6da23bc78e0f04e7dd3586e021659c26d315719..05295243b124c16ff691b021c187e1e0033f6cf2 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -249,6 +249,7 @@ void BindValue(py::module *m) { .def("get_defining_op", &Value::GetDefiningOp, return_value_policy::reference) + .def("first_use", &Value::first_use, return_value_policy::reference) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -272,9 +273,11 @@ void BindOpOperand(py::module *m) { op_operand .def("source", [](OpOperand &self) { return self.source().dyn_cast(); }) - .def("set_source", [](OpOperand &self, const OpResult &result) { - self.set_source(result); - }); + .def("set_source", + [](OpOperand &self, const OpResult &result) { + self.set_source(result); + }) + .def("owner", &OpOperand::owner, return_value_policy::reference); } bool GetStopGradient(const OpResult &self) { @@ -331,6 +334,7 @@ void BindOpResult(py::module *m) { .def("get_defining_op", &OpResult::GetDefiningOp, return_value_policy::reference) + .def("first_use", &OpResult::first_use, return_value_policy::reference) 
.def("use_empty", &OpResult::use_empty) .def("type", &OpResult::type) .def_property( diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc index ea6c010ec04abbc65e0d9e561a3cec041a418434..27cbf36aecbcb52e8acfbaaffa0d08f4d1cf9c9f 100644 --- a/paddle/fluid/pybind/ops_api.cc +++ b/paddle/fluid/pybind/ops_api.cc @@ -40,7 +40,11 @@ static PyObject *divide(PyObject *self, PyObject *args, PyObject *kwargs) { return static_api_divide(self, args, kwargs); } -static PyMethodDef OpsAPI[] = {{"add_n", // NOLINT +static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) { + return static_api_concat(self, args, kwargs); +} + +static PyMethodDef OpsAPI[] = {{"add_n", (PyCFunction)(void (*)(void))add_n, METH_VARARGS | METH_KEYWORDS, "C++ interface function for add_n."}, @@ -56,6 +60,10 @@ static PyMethodDef OpsAPI[] = {{"add_n", // NOLINT (PyCFunction)(void (*)(void))divide, METH_VARARGS | METH_KEYWORDS, "C++ interface function for divide."}, + {"concat", + (PyCFunction)(void (*)(void))concat, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for concat."}, {"full", (PyCFunction)(void (*)(void))full, METH_VARARGS | METH_KEYWORDS, diff --git a/paddle/fluid/pybind/static_op_function.cc b/paddle/fluid/pybind/static_op_function.cc index ad992fab4972fbdb84556cf53a689e0e4131424e..632f7044c46170865f568d72096c58ac6056ea06 100644 --- a/paddle/fluid/pybind/static_op_function.cc +++ b/paddle/fluid/pybind/static_op_function.cc @@ -109,6 +109,26 @@ PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs) { } } +PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs) { + try { + VLOG(6) << "Add concat op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + // Get OpResult from args + PyObject *x_obj = PyTuple_GET_ITEM(args, 0); + auto x = CastPyArg2VectorOfOpResult("concat", x_obj, 0); + + PyObject *axis_obj = PyTuple_GET_ITEM(args, 1); + paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "concat", 1); + + // Call ir static api + auto out = paddle::dialect::concat(x, axis.to()); + return ToPyObject(out); + } catch (...) 
{ + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { try { VLOG(6) << "Add full op into program"; diff --git a/paddle/fluid/pybind/static_op_function.h b/paddle/fluid/pybind/static_op_function.h index 22bee5c344837d24acecb4f6ef99d865ab09adec..02d4777eeef052a3b4fe772f5bc705b4318cc479 100644 --- a/paddle/fluid/pybind/static_op_function.h +++ b/paddle/fluid/pybind/static_op_function.h @@ -28,6 +28,7 @@ PyObject *static_api_add_n(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs); +PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs); } // namespace pybind diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index 671182f7c3040d163886f6d324cf50752db8a04b..67fafbae389c5e42bf63c492d81bd8caddb7762b 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -216,11 +216,11 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): return effective_ops, uneffective_ops -def update_no_grad_set_after_purne( +def update_no_grad_set_after_prune( block, effective_forward_op, no_grad_set, inputs, outputs ): ''' - update no_grad_set after forward purne + update no_grad_set after forward prune from inputs to outputs add value not in the path to no_grad_set, from outputs to inputs add value not in the path to no_grad_set, @@ -338,19 +338,19 @@ def append_backward_ops( else continue to next op. 
''' - def make_output_grad(op, split_op): + def make_output_grad(op): zero_flag = [False] * op.num_results() for i, value in enumerate(op.results()): if ( value not in state.value_to_valuegrad or state.value_to_valuegrad[value] is None ): - if split_op is not None and value == split_op.operand_source(0): + if value.first_use().owner().name() == "builtin.split": # pattern case: # this fwd_op's output is vectorType, it will split to # Type by builtin.split op, so need get from split op's ouput split_zero_flag, split_output_grad = make_output_grad( - split_op, None + value.first_use().owner() ) zero_flag[i] = all(split_zero_flag) grad_value = [op_list[0] for op_list in split_output_grad] @@ -400,11 +400,11 @@ def append_backward_ops( output_grad = state.value_to_valuegrad[value][0] return zero_flag, output_grad - def make_input_stopgradient(combine_op, op): + def make_input_stopgradient(op): input_grad_stopgradient_list = [] for input in op.operands_source(): - if combine_op is not None and input == combine_op.result(0): - stop_gradient = make_input_stopgradient(None, combine_op) + if input.get_defining_op().name() == "builtin.combine": + stop_gradient = make_input_stopgradient(input.get_defining_op()) input_grad_stopgradient_list.append( [info[0] for info in stop_gradient] ) @@ -413,13 +413,14 @@ def append_backward_ops( input_grad_stopgradient_list.append([True]) else: input_grad_stopgradient_list.append([False]) - return input_grad_stopgradient_list - def update_input_grad_map(combine_op, op, input_grad_list): + def update_input_grad_map(op, input_grad_list): for i, input in enumerate(op.operands_source()): - if combine_op is not None and input == combine_op.reslut(0): - update_input_grad_map(None, combine_op, input_grad_list[i]) + if input.get_defining_op().name() == "builtin.combine": + update_input_grad_map( + input.get_defining_op(), input_grad_list[i] + ) else: input_grad = input_grad_list[i] if isinstance(input_grad, list): @@ -427,48 +428,24 @@ def append_backward_ops( else: state.value_to_valuegrad[input].append([input_grad]) - # make op to op pattern, there are four patterns: + # there are four patterns: # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType) # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType) # [op4] (op4's inputs and outputs are not vectorType) # einsum has twp vectorType outputs, special pattern - pattern_effective_op_list = [] - for idx, op in enumerate(effective_forward_op): - if op.name() == "builtin.combine": - pattern_effective_op_list.append([op]) - pattern_effective_op_list[-1].append(effective_forward_op[idx + 1]) - elif op.name() == "builtin.split": - pattern_effective_op_list[-1].append(op) - else: - if ( - not pattern_effective_op_list - or op not in pattern_effective_op_list[-1] - ): - pattern_effective_op_list.append([op]) - - for op_pattern in pattern_effective_op_list: - combine_op = None - split_op = None - if len(op_pattern) == 1: - op = op_pattern[0] - elif len(op_pattern) == 2: - if op_pattern[0] == 'builtin.combine': - combine_op = op_pattern[0] - op = op_pattern[1] - else: - op = op_pattern[0] - split_op = op_pattern[1] - else: - combine_op = op_pattern[0] - op = op_pattern[1] - split_op = op_pattern[2] + clear_effective_forward_op = [] + + for op in effective_forward_op: + if op.name() != "builtin.combine" and op.name() != "builtin.split": + clear_effective_forward_op.append(op) + for op in 
clear_effective_forward_op: if paddle.framework.core.has_vjp(op): # prepare output_grad output_grad_list = [] # (opresult) - zero_flag, output_grad = make_output_grad(op, split_op) + zero_flag, output_grad = make_output_grad(op) output_grad_list.append(output_grad) # all(zero_flag) support this op has no contribution for grad @@ -477,9 +454,7 @@ def append_backward_ops( continue # prepare input_grad stop_gradient info. - input_grad_stopgradient_list = make_input_stopgradient( - combine_op, op - ) + input_grad_stopgradient_list = make_input_stopgradient(op) # create grad_op before_ops_num = len(block.ops) @@ -495,7 +470,7 @@ def append_backward_ops( ) # update input_grad map - update_input_grad_map(combine_op, op, input_grad_list) + update_input_grad_map(op, input_grad_list) else: if op.num_operands() == 0 and op.num_results() != 0: @@ -526,17 +501,23 @@ def append_backward_ops( state.op_to_opgrad[op] = [] -def create_backward_purne_set(inputs, outputs, no_grad_set, state): +def create_backward_prune_set(inputs, outputs, no_grad_set, state): outputs_set = set() for input in inputs: - if state.value_to_valuegrad[input] != []: - outputs_set.add(state.value_to_valuegrad[input][0][0]) - + for item in input.first_use().owner().operands_source(): + if state.value_to_valuegrad[item] != []: + outputs_set.add(state.value_to_valuegrad[item][0][0]) inputs_set = set() for output in outputs: if state.value_to_valuegrad[output] != []: inputs_set.add(state.value_to_valuegrad[output][0][0]) + inputs_set_tmp = set() + for out_grad in inputs_set: + for item in out_grad.first_use().owner().operands_source(): + inputs_set_tmp.add(item) + inputs_set.update(inputs_set_tmp) + no_gradvar_set = set() # grad_value of value in no_grad_set for key in state.value_to_valuegrad: if key in no_grad_set: @@ -590,31 +571,31 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): effective_forward_op, _ = prune_ops( block.ops, inputs_set, outputs_set, no_grad_set ) - update_no_grad_set_after_purne( + update_no_grad_set_after_prune( block, effective_forward_op, no_grad_set, inputs, complete_outputs ) - sorted_effective_forward_op = inverse_sort_op(effective_forward_op) + inverse_effective_forward_op = inverse_sort_op(effective_forward_op) append_backward_ops( - block, sorted_effective_forward_op, no_grad_set, backward_ops, state + block, inverse_effective_forward_op, no_grad_set, backward_ops, state ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) - outputs_set, inputs_set, no_gradvar_set = create_backward_purne_set( + outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( inputs, complete_outputs, no_grad_set, state ) _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) - state.turn_map() + state.turn_map() for bwd_op in inverse_sort_op(remove_ops): remove_op(block, bwd_op, state) + state.turn_map() input_grad_map = state.value_to_valuegrad - state.turn_map() return input_grad_map diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 8053c86cba9de98f66c148564d7edd157b774c37..98da3330238a1bf9da5fc411f1484f1340930d58 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1120,6 +1120,10 @@ def concat(x, axis=0, name=None): input = [t for t in input if t.shape.count(0) == 0] return _C_ops.concat(input, axis) else: + if paddle.ir.core._use_new_ir_api(): + if not isinstance(input, Variable): + input = [t for t in input if t.shape.count(0) == 
0] + return paddle._ir_ops.concat(input, axis) check_type(input, 'input', (list, tuple, Variable), 'concat') if not isinstance(input, Variable): for id, x in enumerate(input): diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index c4409e5dd2a0cd8a67382d975dfc9bc7dac1a6d4..7ceb38ffcbfb36bf580b0a94130044ee01b724a5 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -205,6 +205,68 @@ TEST(VJP, MeanBackwardTest) { ASSERT_EQ(grad_out_tensor.data()[3], 0.25); } +TEST(VJP, ConcatBackwardTest) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ir::Program program((ctx)); + paddle::dialect::APIBuilder::Instance().SetProgram(&program); + + std::shared_ptr builder = + paddle::dialect::APIBuilder::Instance().GetBuilder(); + paddle::dialect::FullOp op1 = builder->Build( + std::vector{1, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); + std::vector combine_input{{op1.out(), op1.out()}}; + ir::CombineOp op2 = builder->Build(combine_input); + paddle::dialect::ConcatOp op3 = + builder->Build(op2.out(), 0); + + paddle::dialect::FullOp op4 = builder->Build( + std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); + std::vector> stop_gradients{{false, false}}; + std::vector> out_grads{{op4.out()}}; + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.concat"); + auto concat_vjp_interface_impl = + op2_info.GetInterfaceImpl(); + concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); + + auto place = platform::CPUPlace(); + Scope scope; + + ProgramDesc prog_desc; + InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + std::stringstream os; + os << reinterpret_cast( + const_cast(test_core.Impl())); + std::string prefix_str = os.str(); + test_core.SetSkipGcVars({prefix_str + "_inner_var_3", + prefix_str + "_inner_var_7", + prefix_str + "_inner_var_8"}); + test_core.Run({}); + auto out_tensor = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_3") + ->Get(); + auto grad_out_tensor_0 = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_7")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_7") + ->Get(); + auto grad_out_tensor_1 = + test_core.local_scope() == nullptr + ? 
scope.FindVar(prefix_str + "_inner_var_8")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_8") + ->Get(); + ASSERT_EQ(out_tensor.data()[0], 2.0); + ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[1], 1.0); + ASSERT_EQ(grad_out_tensor_1.data()[0], 1.0); + ASSERT_EQ(grad_out_tensor_1.data()[1], 1.0); +} + TEST(VJP, AddBackwardTest) { ir::IrContext* ctx = ir::IrContext::Instance(); ir::Program program((ctx)); diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py index c49b0ae14939cefd5d95a914b696ea0f99edf4f5..e54e493b99a773de9fa5c3fe3bc3f6c070472691 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/new_ir/test_build_op.py @@ -102,5 +102,24 @@ class TestBuildOp3(unittest.TestCase): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) +class TestBuildOp4(unittest.TestCase): + def test_build_concat_op(self): + newir_program = get_ir_program() + tanh_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.concat([tanh_out, tanh_out], 0) + self.assertEqual(out.get_defining_op().name(), "pd.concat") + self.assertEqual( + out.get_defining_op() + .operands()[0] + .source() + .get_defining_op() + .name(), + "builtin.combine", + ) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + if __name__ == "__main__": unittest.main() diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index e6b47bbcd106ed54d61d861a6bdf254b9333eac8..63e5bdbc9e4c703993b818a59745e6f722bc0a53 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -116,7 +116,6 @@ def get_ir_program_1(): class TesBackward_2(unittest.TestCase): def test_add_n(self): - # test add_n op newir_program = get_ir_program_1() input_x = newir_program.block().ops[-3].operand(0).source() @@ -130,6 +129,43 @@ class TesBackward_2(unittest.TestCase): self.assertEqual( newir_program.block().ops[-2].name(), "builtin.combine" ) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + def test_concat(self): + newir_program = get_ir_program_1() + input_x = newir_program.block().ops[-3].operand(0).source() + + add_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.concat([add_out, add_out]) + input_grad = grad(out, input_x) + + ops_name = [ + "pd.data", + "pd.data", + "pd.tanh", + "pd.tanh", + "pd.add", + "builtin.combine", + "pd.full", + "pd.concat", + "pd.full", + "builtin.combine", + "pd.concat_grad", + "builtin.split", + "builtin.combine", + "pd.add_n", + "pd.add_grad", + "pd.tanh_grad", + "pd.tanh_grad", + "builtin.combine", + "pd.add_n", + ] + for i, op in enumerate(newir_program.block().ops): + self.assertEqual(op.name(), ops_name[i]) + + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) if __name__ == "__main__":
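
Usage sketch (not part of the patch): a minimal Python walkthrough of the new-IR concat forward/backward flow this diff enables, mirroring the added test_concat case in test/ir/new_ir/test_ir_backward.py. The get_ir_program_1 helper and the import path of grad are assumptions taken from that test module; everything else (flag name, program_guard, the op names checked) appears in the diff above.

    import paddle
    from paddle.autograd.backward import grad  # assumed import path, as used by the existing new-IR tests

    # get_ir_program_1 is defined in the test module (assumption): it builds a small
    # new-IR program whose last op is pd.add.
    newir_program = get_ir_program_1()
    input_x = newir_program.block().ops[-3].operand(0).source()
    add_out = newir_program.block().ops[-1].result(0)

    paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
    with paddle.ir.core.program_guard(newir_program):
        # Forward: under the new IR, paddle.concat lowers to builtin.combine + pd.concat.
        out = paddle.concat([add_out, add_out])
        # Backward: ConcatOp::Vjp calls concat_vjp, whose static backend builds
        # builtin.combine -> pd.concat_grad -> builtin.split (see pd_manual_api.cc).
        input_grad = grad(out, input_x)
    paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

Because the gradient decomposition introduces its own builtin.combine/builtin.split ops, append_backward_ops in backward.py no longer tracks [combine, op, split] pattern lists; it simply skips combine/split ops and navigates through value.first_use().owner() and value.get_defining_op(), which is why this diff also exposes first_use and OpOperand::owner to Python.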