From a4689c9086e749185f85908824ea4c8719572c48 Mon Sep 17 00:00:00 2001
From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com>
Date: Thu, 2 Mar 2023 12:34:15 +0800
Subject: [PATCH] Add concat grad cinn (#50972)

* [cinn] concat_grad
* [cinn] concat_grad
* [cinn] concat_grad build success
* [Add PGLBOX] fix unittest
* [Add PGLBOX] fix unittest
* [Add PGLBOX] fix codestyle
* [cinn] update by comments
* [cinn] update by comment
* [cinn] add axis check
---
 paddle/fluid/operators/concat_op.cc           | 39 +++++++++++
 .../composite_backward_api.h                  | 23 +++++++
 .../prim/api/manual_prim/eager_prim_api.cc    |  8 +++
 .../prim/api/manual_prim/prim_manual_api.h    |  8 +++
 .../prim/api/manual_prim/static_prim_api.cc   | 27 ++++++++
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 .../fluid/tests/unittests/CMakeLists.txt      |  1 +
 .../fluid/tests/unittests/test_concat_op.py   | 66 +++++++++++++++++--
 python/paddle/tensor/manipulation.py          |  1 +
 9 files changed, 168 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index 21e4bfcf709..869eae74a18 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -21,6 +21,9 @@ limitations under the License. */
 #include

 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/infermeta/multiary.h"
 #include "paddle/phi/kernels/funcs/concat_funcs.h"
@@ -153,6 +156,41 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };

+class ConcatCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    std::vector<paddle::experimental::Tensor> input =
+        this->GetMultiForwardInput("X");
+    paddle::optional<paddle::experimental::Tensor> tensor_axis =
+        this->GetOptionalSingleForwardInput("AxisTensor");
+    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    std::vector<paddle::experimental::Tensor> input_grad =
+        this->GetMultiInputGrad("X");
+
+    std::vector<paddle::experimental::Tensor*> input_grad_ptr(
+        input_grad.size());
+    for (size_t i = 0; i < input_grad.size(); ++i) {
+      input_grad_ptr[i] = &input_grad[i];
+    }
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    std::vector<paddle::experimental::Tensor*> dx_ptr =
+        this->GetOutputPtr(input_grad_ptr);
+    std::vector<std::string> dx_name = this->GetOutputName(input_grad);
+
+    VLOG(6) << "Running concat_grad composite func";
+    if (tensor_axis.is_initialized()) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "We don't support dynamic index from tensor for concat composite "
+          "grad for now. "));
+    } else {
+      prim::concat_grad<prim::DescTensor>(input, out_grad, axis, dx_ptr);
+    }
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 template <typename T>
 class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -184,6 +222,7 @@ REGISTER_OPERATOR(concat,
                   ConcatInferShapeFunctor);
 REGISTER_OPERATOR(concat_grad,
                   ops::ConcatOpGrad,
+                  ops::ConcatCompositeGradOpMaker,
                   ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
                   ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
                   ops::ConcatOpGradNoNeedBufferVarInferer);
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index aa92cd17c34..a7fc0a4a930 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -284,6 +284,29 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }

+template <typename T>
+void concat_grad(const std::vector<Tensor>& x,
+                 const Tensor& out_grad,
+                 const Scalar& axis,
+                 std::vector<Tensor*> x_grad) {
+  int axis_value = axis.to<int>();
+  int rank = x[0].dims().size();
+  if (axis_value < 0) {
+    axis_value = axis_value + rank;
+  }
+  axis_value = axis_value > 0 ? axis_value : 0;
+  std::vector<int> sections;
+  int x_num = x.size();
+  for (int i = 0; i < x_num; ++i) {
+    sections.push_back(x[i].dims()[axis_value]);
+  }
+  std::vector<Tensor> x_grad_tmp =
+      split<T>(out_grad, phi::IntArray(sections), axis);
+  for (int i = 0; i < x_num; ++i) {
+    set_output<T>(x_grad_tmp.at(i), x_grad.at(i));
+  }
+}
+
 template <typename T>
 void multiply_grad(const Tensor& x,
                    const Tensor& y,
diff --git a/paddle/fluid/prim/api/manual_prim/eager_prim_api.cc b/paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
index 6e35f67ad51..196bfc20bf8 100644
--- a/paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
+++ b/paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
@@ -33,6 +33,14 @@ Tensor full<Tensor>(const IntArray& shape,
   VLOG(4) << "Eager Prim API full_ad_func call";
   return ::full_ad_func(shape, value, dtype, place);
 }
+
+template <>
+std::vector<Tensor> split<Tensor>(const Tensor& x,
+                                  const IntArray& sections,
+                                  const Scalar& axis) {
+  VLOG(4) << "Eager Prim API split_ad_func call";
+  return ::split_ad_func(x, sections, axis);
+}
 template <>
 Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
   return ::cast_ad_func(x, dtype);
diff --git a/paddle/fluid/prim/api/manual_prim/prim_manual_api.h b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h
index 91c51ef74c2..0e60955a69e 100644
--- a/paddle/fluid/prim/api/manual_prim/prim_manual_api.h
+++ b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h
@@ -14,6 +14,7 @@

 #pragma once

+#include <vector>
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
@@ -36,7 +37,14 @@ Tensor full(const IntArray& shape,
             const Scalar& value,
             DataType dtype = DataType::FLOAT32,
             const Place& place = CPUPlace());
+
+template <typename T>
+std::vector<Tensor> split(const Tensor& x,
+                          const IntArray& sections,
+                          const Scalar& axis);
+
 template <typename T>
 Tensor cast(const Tensor& x, DataType dtype);
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/prim/api/manual_prim/static_prim_api.cc b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc
index dcd4ab62232..d137183db81 100644
--- a/paddle/fluid/prim/api/manual_prim/static_prim_api.cc
+++ b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc
@@ -120,6 +120,33 @@ Tensor full<DescTensor>(const IntArray& shape,
   op->InferShape(*block);
   return out;
 }
+
+template <>
+std::vector<Tensor> split<DescTensor>(const Tensor& x,
+                                      const IntArray& sections,
+                                      const Scalar& axis) {
+  int elem_num = sections.size();
+  std::vector<std::string> outs_name;
+  std::vector<Tensor> outs;
+  for (int i = 0; i < elem_num; ++i) {
+    Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
+    std::string out_name =
+        std::static_pointer_cast<prim::DescTensor>(out.impl())->Name();
+    outs_name.push_back(std::move(out_name));
+    outs.push_back(out);
+  }
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("split");
+  op->SetAttr("sections", sections.GetData());
+  op->SetAttr("axis", axis.to<int>());
+  op->SetOutput("Out", outs_name);
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return outs;
+}
+
 template <>
 Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
   Tensor out = empty<DescTensor>({}, DataType::FLOAT32, paddle::Place());
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index d2cf387df76..22106521f29 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -207,6 +207,7 @@
     param : [x]
   kernel :
     func : concat_grad
+  composite : concat_grad(x, out_grad, axis, x_grad)
   no_need_buffer : x
   backward : concat_double_grad

diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 35c8fecaa74..e754729e3ce 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1214,6 +1214,7 @@ set(TEST_CINN_OPS
     test_activation_op
     test_full_like_op
     test_fill_any_like_op
+    test_concat_op
     test_top_k_v2_op)

 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py
index abc565b8375..401ee260d7f 100644
--- a/python/paddle/fluid/tests/unittests/test_concat_op.py
+++ b/python/paddle/fluid/tests/unittests/test_concat_op.py
@@ -32,6 +32,8 @@ class TestConcatOp(OpTest):
     def setUp(self):
         self.op_type = "concat"
         self.python_api = paddle.concat
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self.dtype = self.get_dtype()
         self.init_test_data()
         self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
@@ -61,13 +63,13 @@ class TestConcatOp(OpTest):
     def test_check_grad(self):
         if self.dtype == np.uint16:
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['x0'], 'Out')
-            self.check_grad_with_place(place, ['x1'], 'Out')
-            self.check_grad_with_place(place, ['x2'], 'Out')
+            self.check_grad_with_place(place, ['x0'], 'Out', check_prim=True)
+            self.check_grad_with_place(place, ['x1'], 'Out', check_prim=True)
+            self.check_grad_with_place(place, ['x2'], 'Out', check_prim=True)
         else:
-            self.check_grad(['x0'], 'Out', check_eager=True)
-            self.check_grad(['x1'], 'Out', check_eager=True)
-            self.check_grad(['x2'], 'Out', check_eager=True)
+            self.check_grad(['x0'], 'Out', check_eager=True, check_prim=True)
+            self.check_grad(['x1'], 'Out', check_eager=True, check_prim=True)
+            self.check_grad(['x2'], 'Out', check_eager=True, check_prim=True)

     def init_test_data(self):
         if self.dtype == np.uint16:
@@ -133,6 +135,8 @@ class TestConcatOp6(TestConcatOp):
         self.op_type = "concat"
         self.dtype = self.get_dtype()
         self.python_api = paddle.concat
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self.init_test_data()
         self.lod = [[20, 80]]
         self.out_lod = [[20, 80, 20, 80, 20, 80]]
@@ -167,6 +171,54 @@ class TestConcatOp6(TestConcatOp):
         self.axis = 0

+
+class TestConcatOp7(TestConcatOp):
+    def setUp(self):
+        self.op_type = "concat"
+        self.python_api = paddle.concat
+        self.prim_op_type = "prim"
+        self.enable_cinn = True
+        self.dtype = self.get_dtype()
+        self.init_test_data()
+        self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
+        self.attrs = {'axis': self.axis}
+        if self.axis < 0:
+            self.actual_axis = self.axis + len(self.x0.shape)
+            self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
+        else:
+            self.actual_axis = self.axis
+
+        self.outputs = {
+            'Out': np.concatenate(
+                (self.x0, self.x1, self.x2), axis=self.actual_axis
+            )
+        }
+
+    def get_dtype(self):
+        return "float64"
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_eager=True, check_prim=True)
+        self.check_grad(['x1'], 'Out', check_eager=True, check_prim=True)
+        self.check_grad(['x2'], 'Out', check_eager=True, check_prim=True)
+
+    def init_test_data(self):
+        if self.dtype == np.uint16:
+            x0 = np.random.random((5, 1, 4, 5)).astype(np.float32)
+            self.x0 = convert_float_to_uint16(x0)
+            x1 = np.random.random((5, 2, 4, 5)).astype(np.float32)
+            self.x1 = convert_float_to_uint16(x1)
+            x2 = np.random.random((5, 3, 4, 5)).astype(np.float32)
+            self.x2 = convert_float_to_uint16(x2)
+        else:
+            self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype)
+            self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype)
+            self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype)
+        self.axis = 1
+
+
 def create_test_AxisTensor(parent):
     class TestConcatAxisTensor(parent):
         def setUp(self):
@@ -175,6 +227,8 @@ def create_test_AxisTensor(parent):

             self.dtype = self.get_dtype()
             self.init_test_data()
+            self.prim_op_type = "prim"
+            self.enable_cinn = False
             self.inputs = {
                 'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)],
                 'AxisTensor': np.array([self.axis]).astype("int32"),
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index cd45c33d74e..fed360474c8 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1136,6 +1136,7 @@ def concat(x, axis=0, name=None):
                         'int64',
                         'int8',
                         'unit8',
+                        'uint16',
                     ],
                     'concat',
                 )
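
For reference (not part of the patch): the composite rule added to composite_backward_api.h rewrites concat's backward as a split of out_grad along the concat axis, with section sizes taken from each input's extent on that axis. Below is a minimal eager-mode sketch of that identity, assuming a recent Paddle 2.x environment and using only public paddle APIs; the shapes mirror the new TestConcatOp7 case.

    import paddle

    # Inputs differ only along axis 1, as in TestConcatOp7.
    x0 = paddle.rand([5, 1, 4, 5])
    x1 = paddle.rand([5, 2, 4, 5])
    x2 = paddle.rand([5, 3, 4, 5])
    for t in (x0, x1, x2):
        t.stop_gradient = False

    axis = 1
    out = paddle.concat([x0, x1, x2], axis=axis)
    out_grad = paddle.rand(out.shape)

    # Gradients computed by autodiff.
    dx = paddle.grad([out], [x0, x1, x2], grad_outputs=[out_grad])

    # Gradients predicted by the composite rule: split out_grad by each
    # input's size along the concat axis.
    sections = [x0.shape[axis], x1.shape[axis], x2.shape[axis]]
    dx_ref = paddle.split(out_grad, sections, axis=axis)

    for a, b in zip(dx, dx_ref):
        assert paddle.allclose(a, b).item()

This is the identity that the new `composite : concat_grad(x, out_grad, axis, x_grad)` entry in legacy_backward.yaml exposes to the prim/CINN pipeline, so the backward can be expressed in terms of the primitive split op rather than only the dedicated concat_grad kernel.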