diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index cd695511d31f22994b344a96c149b012476e7a14..afef765c6ff71ce9f1a97e915d4f933558c138d2 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -17,6 +17,7 @@
 #include <string>
 
 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
 
@@ -63,6 +64,35 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class ReduceSumCompositeGradOpMaker : public prim::GradCompositeOpMakerBase {
+ public:
+  using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
+  void Apply() override {
+    // get inputs
+    paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
+    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
+
+    // get attr
+    std::vector<int> axis = this->Attr<std::vector<int>>("dim");
+    bool keep_dim = this->Attr<bool>("keep_dim");
+    bool reduce_all = this->Attr<bool>("reduce_all");
+    // get output
+    paddle::experimental::Tensor x_grad_t = this->GetSingleInputGrad("X");
+
+    // get output ptr
+    paddle::experimental::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
+
+    // get output original name
+    std::string x_grad_name = this->GetOutputName(x_grad_t);
+
+    // call composite backward func
+    prim::sum_grad<prim::DescTensor>(
+        x, out_grad, axis, keep_dim, reduce_all, x_grad);
+    // recover output name
+    this->RecoverOutputName(x_grad_t, x_grad_name);
+  }
+};
+
 template <typename T>
 class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -114,6 +144,7 @@ REGISTER_OPERATOR(reduce_sum,
                   ops::ReduceSumVarTypeInference,
                   ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
                   ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
+                  ops::ReduceSumCompositeGradOpMaker,
                   ReduceSumInferShapeFunctor);
 
 REGISTER_OPERATOR(reduce_sum_grad,
                   ops::ReduceGradOp,
diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
index fa0d3d640d13b6a684182af59072d9245a692ad4..19077d29266b639eb7ff728c9cec1fdcad190e97 100644
--- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -15,11 +15,17 @@
 #pragma once
 #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
 #include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/ddim.h"
+
 namespace paddle {
 namespace prim {
-
-// This function should have as same signature as phi, which defined in
-// paddle/phi/api/backward/backward_api.h
+using Tensor = paddle::experimental::Tensor;
+using IntArray =
+    paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
+// using IntArray = paddle::experimental::IntArray;
+// This function should have as same signature as phi, which defined in
+// paddle/phi/api/backward/backward_api.h
 template <typename T>
 void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   auto tmp = pow<T>(out, 2.0);
@@ -94,6 +100,44 @@ void add_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void sum_grad(const Tensor& x,
+              const Tensor& out_grad,
+              const IntArray& axis,
+              bool keepdim,
+              bool reduce_all,
+              Tensor* x_grad) {
+  if (!x_grad) {
+    return;
+  }
+  std::vector<int64_t> x_dim = phi::vectorize(x.dims());
+  int64_t axis_size = axis.size();
+  int64_t x_dim_size = x_dim.size();
+  reduce_all = false;
+  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+    reduce_all = true;
+  } else {
+    reduce_all = false;
+  }
+  auto x_grad_tmp = Tensor();
+  if (!keepdim) {
+    auto axis_ = std::vector<int64_t>();
+    if (reduce_all) {
+      for (int64_t i = 1; i < x_dim_size; i++) {
+        axis_.push_back(i);
+      }
+    } else {
+      axis_ = axis.GetData();
+    }
+    auto out_grad_ = unsqueeze<T>(out_grad, axis_);
+    x_grad_tmp = expand<T>(out_grad_, x_dim);
+  } else {
+    x_grad_tmp = expand<T>(out_grad, x_dim);
+  }
+
+  x_grad->set_impl(x_grad_tmp.impl());
+}
+
 template <typename T>
 void divide_grad(const Tensor& x,
                  const Tensor& y,
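Note (reviewer commentary, not part of the patch): the `sum_grad` composite rule above rewrites the VJP of `sum` as an `unsqueeze` (re-insert the reduced dimensions as size-1 axes when `keepdim` is false) followed by an `expand` (broadcast the output cotangent back to the input shape). A rough NumPy sketch of the same decomposition is shown below; `sum_grad_ref` is a hypothetical reference helper, not something added by this patch.

```python
import numpy as np


def sum_grad_ref(x, out_grad, axis=None, keepdim=False):
    """Reference VJP of sum: every input element receives the cotangent of
    the output element it was reduced into."""
    if axis is None or np.size(axis) == 0:
        axes = tuple(range(x.ndim))  # reduce over all dims (reduce_all)
    else:
        axes = tuple(int(a) % x.ndim for a in np.atleast_1d(axis))
    g = np.asarray(out_grad)
    if not keepdim:
        # the "unsqueeze" step: put the reduced dims back as size-1 axes
        g = g.reshape([1 if i in axes else d for i, d in enumerate(x.shape)])
    # the "expand" step: broadcast the cotangent to the input shape
    return np.broadcast_to(g, x.shape)


# e.g. x of shape (4, 3, 2) summed over axis 1 gives out_grad of shape (4, 2);
# the VJP tiles that cotangent back to shape (4, 3, 2).
x = np.random.rand(4, 3, 2)
g = np.random.rand(4, 2)
assert sum_grad_ref(x, g, axis=1).shape == (4, 3, 2)
```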
diff --git a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
index be123ecde7c5818363365df52ba8e693b46239de..7dac02ea5b203e45adf5166602d6b41d3752194f 100644
--- a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
+++ b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
@@ -36,6 +36,16 @@ Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
   return ::multiply_ad_func(x, y);
 }
 
+template <>
+Tensor expand<Tensor>(const Tensor& x, const IntArray& shape) {
+  return ::expand_ad_func(x, shape);
+}
+
+template <>
+Tensor unsqueeze<Tensor>(const Tensor& x, const IntArray& axis) {
+  return ::unsqueeze_ad_func(x, axis);
+}
+
 template <>
 Tensor divide<Tensor>(const Tensor& x, const Tensor& y) {
   return ::divide_ad_func(x, y);
diff --git a/paddle/fluid/prim/api/manual/prim_api/prim_api.h b/paddle/fluid/prim/api/manual/prim_api/prim_api.h
index 8de90919c5b4432f7289e482dd4dbff189d319e6..5465cdb601e9557be56ddd8efd5640ae95abbc19 100644
--- a/paddle/fluid/prim/api/manual/prim_api/prim_api.h
+++ b/paddle/fluid/prim/api/manual/prim_api/prim_api.h
@@ -21,18 +21,25 @@ namespace prim {
 using Tensor = paddle::experimental::Tensor;
 using IntArray = paddle::experimental::IntArray;
 using Scalar = paddle::experimental::Scalar;
+
 template <typename T>
-Tensor pow(const Tensor& x, const paddle::experimental::Scalar& y);
+Tensor pow(const Tensor& x, const Scalar& y);
 
 template <typename T>
 Tensor scale(const Tensor& X,
-             const paddle::experimental::Scalar& scale,
+             const Scalar& scale,
              float bias,
              bool bias_after_scale);
 
 template <typename T>
 Tensor multiply(const Tensor& x, const Tensor& y);
 
+template <typename T>
+Tensor expand(const Tensor& x, const IntArray& shape);
+
+template <typename T>
+Tensor unsqueeze(const Tensor& x, const IntArray& axis);
+
 template <typename T>
 Tensor divide(const Tensor& x, const Tensor& y);
diff --git a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
index 5b393a76df20a82546ed6135c31bae67bd45f01d..0bf14b5955ba5c028d40eb38d6387a9a233e592e 100644
--- a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
+++ b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
@@ -94,6 +94,23 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
   return out;
 }
 
+template <>
+Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("expand_v2");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
+  op->SetAttr("shape", new_shape);
+  op->CheckAttrs();
+  op->InferVarType(block);
+  return out;
+}
+
 template <>
 Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
   // Grad infershape
@@ -113,6 +130,23 @@ Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
   return out;
 }
 
+template <>
+Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("unsqueeze2");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  std::vector<int> new_shape(axis.GetData().begin(), axis.GetData().end());
+  op->SetAttr("axes", new_shape);
+  op->CheckAttrs();
+  op->InferVarType(block);
+  return out;
+}
+
 template <>
 Tensor full<DescTensor>(paddle::experimental::IntArray shape,
                         paddle::experimental::Scalar value,
@@ -141,6 +175,7 @@ Tensor full<DescTensor>(paddle::experimental::IntArray shape,
   op->InferShape(*block);
   return out;
 }
+
 template <>
 Tensor sum<DescTensor>(Tensor x,
                        paddle::experimental::IntArray axis,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 9810242037bcec02dee243f74a0af434b334440e..50640d313ef33445a4c4db93cdcafd6bfa81a418 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1356,6 +1356,7 @@
     param : [x]
   kernel :
     func : sum_grad
+  composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad)
   no_need_buffer : x
   backward : sum_double_grad
 
diff --git a/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
index 9ec3cf15454e5a1c455b842a45d58a0452b70bc8..ab3ee7ba1a3ce577786bed4ea3911d3ba0de94eb 100644
--- a/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
@@ -8,5 +8,4 @@ foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
 endforeach()
 
-add_subdirectory(comp)
 add_subdirectory(prim)
diff --git a/python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt
deleted file mode 100644
index 72c6bbd7d05e8fdf99fce350ad15c216dcac5c92..0000000000000000000000000000000000000000
--- a/python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-file(
-  GLOB TEST_OPS
-  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
-  "test_*.py")
-string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
-
-foreach(TEST_OP ${TEST_OPS})
-  py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
-endforeach()
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
index fc8d0234d5acd21f19322839c8a840e477b3c257..c126c13a3901fe4d66bbac1c3cc1dde6be46d13c 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
@@ -11,5 +11,6 @@ endforeach()
 
 set_tests_properties(test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_eager_sum_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60)
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..5586f7c0ccaf64bd924b30e6053ac9f2932bab30
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+
+
+def actual(primal, cotangent, axis, keep_dim):
+    core.set_prim_enabled(False)
+    x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
+    v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
+    y = paddle.sum(x, axis=axis, keepdim=keep_dim)
+    x_cotangent = paddle.grad(y, x, v, create_graph=True, retain_graph=True)
+    return x_cotangent[0]
+
+
+def desired(primal, cotangent, axis, keep_dim):
+    core.set_prim_enabled(True)
+    x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
+    v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
+    y = paddle.sum(x, axis=axis, keepdim=keep_dim)
+    x_cotangent = paddle.grad(y, x, v, create_graph=True, retain_graph=True)
+    return x_cotangent[0]
+
+
+class TestSumGradComp(unittest.TestCase):
+    def test_sum_grad_comp_1(self):
+        self.primal = np.random.rand(10, 10)
+        self.cotangent = np.random.rand(1)
+        paddle.disable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [], False),
+            desired=desired(self.primal, self.cotangent, [], False),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_2(self):
+        self.primal = np.random.rand(4, 3, 2)
+        self.cotangent = np.random.rand(4, 2)
+        paddle.disable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, 1, False),
+            desired=desired(self.primal, self.cotangent, 1, False),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_3(self):
+        self.primal = np.random.rand(4, 3, 2)
+        self.cotangent = np.random.rand(4, 1, 2)
+        paddle.disable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, 1, True),
+            desired=desired(self.primal, self.cotangent, 1, True),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_4(self):
+        self.primal = np.random.rand(4, 3, 2, 5)
+        self.cotangent = np.random.rand(4, 1, 2, 1)
+        paddle.disable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [1, 3], True),
+            desired=desired(self.primal, self.cotangent, [1, 3], True),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_5(self):
+        self.primal = np.random.rand(4, 3, 2, 5)
+        self.cotangent = np.random.rand(4, 2)
+        paddle.disable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [1, 3], False),
+            desired=desired(self.primal, self.cotangent, [1, 3], False),
+            rtol=1e-6,
+            atol=0,
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
index 58375c16966d710d0cb32256ae953d62654dc55c..d267bd627a96a449f78034322f55a5e9de3992bf 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
@@ -11,6 +11,7 @@ endforeach()
 
 set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b2ad03913cb7127c63c1a57c0d1af5944cff2f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+
+
+def actual(primal, cotangent, axis, keep_dim):
+    core.set_prim_enabled(False)
+    mp, sp = paddle.static.Program(), paddle.static.Program()
+    with paddle.static.program_guard(mp, sp):
+        x = paddle.static.data('primal', primal.shape, primal.dtype)
+        x.stop_gradient = False
+        v = paddle.static.data('cotangent', cotangent.shape, cotangent.dtype)
+        y = paddle.sum(x, axis=axis, keepdim=keep_dim)
+        x_cotangent = paddle.static.gradients(y, x, None)
+    exe = paddle.static.Executor()
+    exe.run(sp)
+    result = exe.run(
+        program=mp,
+        feed={'primal': primal, 'cotangent': cotangent},
+        fetch_list=[x_cotangent],
+    )[0]
+    return result
+
+
+def desired(primal, cotangent, axis, keep_dim):
+    core.set_prim_enabled(True)
+    mp, sp = paddle.static.Program(), paddle.static.Program()
+    with paddle.static.program_guard(mp, sp):
+        x = paddle.static.data('primal', primal.shape, primal.dtype)
+        x.stop_gradient = False
+        v = paddle.static.data('cotangent', cotangent.shape, cotangent.dtype)
+        y = paddle.sum(x, axis=axis, keepdim=keep_dim)
+        x_cotangent = paddle.static.gradients(y, x, None)
+    exe = paddle.static.Executor()
+    exe.run(sp)
+    result = exe.run(
+        program=mp,
+        feed={'primal': primal, 'cotangent': cotangent},
+        fetch_list=[x_cotangent],
+    )[0]
+    return result
+
+
+class TestSumGradComp(unittest.TestCase):
+    def test_sum_grad_comp_1(self):
+        self.primal = np.random.rand(10, 10)
+        self.cotangent = np.random.rand(1, 1)
+        paddle.enable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [], True),
+            desired=desired(self.primal, self.cotangent, [], True),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_2(self):
+        self.primal = np.random.rand(4, 3, 2)
+        self.cotangent = np.random.rand(4, 2)
+        paddle.enable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, 1, False),
+            desired=desired(self.primal, self.cotangent, 1, False),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_3(self):
+        self.primal = np.random.rand(4, 3, 2)
+        self.cotangent = np.random.rand(4, 1, 2)
+        paddle.enable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, 1, True),
+            desired=desired(self.primal, self.cotangent, 1, True),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_4(self):
+        self.primal = np.random.rand(4, 3, 2, 5)
+        self.cotangent = np.random.rand(4, 1, 2, 1)
+        paddle.enable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [1, 3], True),
+            desired=desired(self.primal, self.cotangent, [1, 3], True),
+            rtol=1e-6,
+            atol=0,
+        )
+
+    def test_sum_grad_comp_5(self):
+        self.primal = np.random.rand(4, 3, 2, 5)
+        self.cotangent = np.random.rand(4, 2)
+        paddle.enable_static()
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent, [1, 3], False),
+            desired=desired(self.primal, self.cotangent, [1, 3], False),
+            rtol=1e-6,
+            atol=0,
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
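Note (reviewer commentary, not part of the patch): with `core.set_prim_enabled(True)`, the static-graph backward of `paddle.sum` should be assembled from the primitive ops appended by the static prim API above (`unsqueeze2` and `expand_v2`) instead of a single `reduce_sum_grad` op. A rough way to eyeball this is sketched below; whether the composite rule fires and the exact op list may depend on the build and flag configuration, so treat it as an illustration rather than an assertion.

```python
import paddle
from paddle.fluid import core

core.set_prim_enabled(True)
paddle.enable_static()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [4, 3, 2], 'float32')
    x.stop_gradient = False
    y = paddle.sum(x, axis=1, keepdim=False)
    paddle.static.gradients(y, x, None)

# With the composite rule active, expect expand_v2/unsqueeze2 among the ops
# rather than reduce_sum_grad.
print([op.type for op in main.block(0).ops])
```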