diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
index 19077d29266b639eb7ff728c9cec1fdcad190e97..19898e0c562f49afc14b7b14f7ebcd02b877ef04 100644
--- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -180,5 +180,15 @@ void divide_grad(const Tensor& x,
     }
   }  // indicate we will compute dx
 }
+
+template <typename T>
+void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
+    auto tmp = divide<T>(div_x, out);
+    auto x_grad_tmp = multiply<T>(out_grad, tmp);
+    x_grad->set_impl(x_grad_tmp.impl());
+  }
+}
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 0858901f1e5e7506f60aca4766f122940716d48b..277ef04c6888b91a83690b4f2cc2b2e9d390d2b9 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1254,6 +1254,7 @@
     param : [out]
   kernel :
     func : sqrt_grad
+  composite : sqrt_grad(out, out_grad, x_grad)
   backward : sqrt_double_grad
   inplace : (out_grad -> x_grad)
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
index c126c13a3901fe4d66bbac1c3cc1dde6be46d13c..7d5fc1006d1e8ef175256a2a196482c199b98c48 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
@@ -14,3 +14,4 @@ set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_sum_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_eager_sqrt_grad PROPERTIES TIMEOUT 60)
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..7abb91e912ac4fa74cd5b8d4ca1e9d59d4ca219b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import autograd
+import autograd.numpy
+import numpy as np
+import parameterized as param
+
+import paddle
+from paddle.fluid import core
+
+core.set_prim_enabled(True)
+
+
+@param.parameterized_class(
+    ('primal', 'cotangent', 'dtype'),
+    [
+        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
+    ],
+)
+class TestSqrtGradComp(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.primal = cls.primal.astype(cls.dtype)
+        cls.cotangent = cls.cotangent.astype(cls.dtype)
+
+    def setUp(self):
+        paddle.enable_static()
+
+    def tearDown(self):
+        paddle.disable_static()
+
+    def test_sqrt_grad_comp(self):
+        def actual(primal, cotangent):
+            paddle.disable_static()
+            x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
+            x.stop_gradient = False
+            v = paddle.to_tensor(
+                cotangent, dtype='float32', stop_gradient=False
+            )
+            y = paddle.sqrt(x)
+            return paddle.grad(y, x, v, create_graph=True, retain_graph=True)[0]
+
+        def desired(primal, cotangent):
+            return autograd.make_vjp(autograd.numpy.sqrt)(primal)[0](cotangent)
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent),
+            desired=desired(self.primal, self.cotangent),
+            rtol=1e-6,
+            atol=0,
+        )
+        core.set_prim_enabled(False)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
index d267bd627a96a449f78034322f55a5e9de3992bf..96f0a86291a8b0027e39e7153f23b9b8b255902e 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
@@ -15,3 +15,4 @@ set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_sqrt_grad PROPERTIES TIMEOUT 60)
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eae9c86e25fba10f97d5e64ba3e4098abb2a671
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.fluid import core
+
+core.set_prim_enabled(True)
+
+import autograd
+import autograd.numpy
+import numpy as np
+import parameterized as param
+
+import paddle
+
+
+@param.parameterized_class(
+    ('primal', 'cotangent', 'dtype'),
+    [
+        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
+    ],
+)
+class TestSqrtGradComp(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.primal = cls.primal.astype(cls.dtype)
+        cls.cotangent = cls.cotangent.astype(cls.dtype)
+
+    def setUp(self):
+        paddle.enable_static()
+
+    def tearDown(self):
+        paddle.disable_static()
+
+    def test_sqrt_grad_comp(self):
+        def actual(primal, cotangent):
+            mp, sp = paddle.static.Program(), paddle.static.Program()
+            with paddle.static.program_guard(mp, sp):
+                x = paddle.static.data('primal', primal.shape, primal.dtype)
+                x.stop_gradient = False
+                v = paddle.static.data(
+                    'cotangent', cotangent.shape, cotangent.dtype
+                )
+                y = paddle.sqrt(x)
+                x_cotangent = paddle.static.gradients(y, x, v)
+            exe = paddle.static.Executor()
+            exe.run(sp)
+            return exe.run(
+                program=mp,
+                feed={'primal': primal, 'cotangent': cotangent},
+                fetch_list=[x_cotangent[0].name],
+            )[0]
+
+        def desired(primal, cotangent):
+            return autograd.make_vjp(autograd.numpy.sqrt)(primal)[0](cotangent)
+
+        np.testing.assert_allclose(
+            actual=actual(self.primal, self.cotangent),
+            desired=desired(self.primal, self.cotangent),
+            rtol=1e-6,
+            atol=0,
+        )
+        core.set_prim_enabled(False)
+
+
+if __name__ == '__main__':
+    unittest.main()
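Note (not part of the patch): the composite rule added above rewrites sqrt's VJP as out_grad * (0.5 / out), which follows from d(sqrt(x))/dx = 1 / (2 * sqrt(x)) together with out = sqrt(x). A minimal NumPy-only sketch of the same check the new unit tests perform against autograd; variable names here are illustrative and not Paddle or autograd APIs:

    import numpy as np

    x = np.random.rand(10, 10).astype(np.float32) + 0.1   # keep inputs away from 0
    out = np.sqrt(x)                                       # forward: out = sqrt(x)
    out_grad = np.random.rand(10, 10).astype(np.float32)   # incoming cotangent

    composite = out_grad * (0.5 / out)          # what the composite sqrt_grad computes
    analytic = out_grad / (2.0 * np.sqrt(x))    # cotangent times d(sqrt(x))/dx

    np.testing.assert_allclose(composite, analytic, rtol=1e-6, atol=0)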