Unverified commit 70378584, authored by zqw_1997, committed by GitHub

add sqrt_comp_grad composite rule (#49769)

Parent 292f3f77
@@ -180,5 +180,15 @@ void divide_grad(const Tensor& x,
}
} // indicate we will compute dx
}
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    // sqrt's backward written in terms of the forward output:
    // d(sqrt(x))/dx = 1 / (2 * sqrt(x)) = 0.5 / out, so dx = dout * 0.5 / out.
    auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
    auto tmp = divide<T>(div_x, out);
    auto x_grad_tmp = multiply<T>(out_grad, tmp);
    x_grad->set_impl(x_grad_tmp.impl());
  }
}
} // namespace prim
} // namespace paddle
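The new rule decomposes sqrt's backward into the full, divide, and multiply primitives, using the identity d(sqrt(x))/dx = 0.5 / out where out = sqrt(x). A minimal NumPy sketch of the same decomposition (illustrative only; the function name is hypothetical):

import numpy as np

def sqrt_grad_composite(out, out_grad):
    """Mirror of the C++ composite rule: dx = out_grad * (0.5 / out)."""
    div_x = np.full(out.shape, 0.5)  # full<T>(out.dims(), 0.5)
    tmp = div_x / out                # divide<T>(div_x, out)
    return out_grad * tmp            # multiply<T>(out_grad, tmp)

x = np.random.rand(3, 3) + 0.1       # stay away from 0, where sqrt is not differentiable
out = np.sqrt(x)
v = np.ones_like(x)
np.testing.assert_allclose(sqrt_grad_composite(out, v), 0.5 / np.sqrt(x))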
@@ -1254,6 +1254,7 @@
param : [out]
kernel :
func : sqrt_grad
composite : sqrt_grad(out, out_grad, x_grad)
backward : sqrt_double_grad
inplace : (out_grad -> x_grad)
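The composite : sqrt_grad(out, out_grad, x_grad) entry registers the rule above as the composite backward for sqrt, so with prim mode enabled the backward pass is expanded into primitive ops instead of calling the fused sqrt_grad kernel. A usage sketch following the pattern of this commit's eager test (assuming the core.set_prim_enabled switch used throughout the commit):

import numpy as np
import paddle
from paddle.fluid import core

core.set_prim_enabled(True)  # route sqrt's backward through the composite rule

x = paddle.to_tensor(np.random.rand(4, 4), dtype='float32', stop_gradient=False)
y = paddle.sqrt(x)
(dx,) = paddle.grad(y, x, paddle.ones_like(y))  # expected: 0.5 / sqrt(x)

core.set_prim_enabled(False)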
@@ -14,3 +14,4 @@ set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sum_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sqrt_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import autograd
import autograd.numpy
import numpy as np
import parameterized as param

import paddle
from paddle.fluid import core

core.set_prim_enabled(True)


@param.parameterized_class(
    ('primal', 'cotangent', 'dtype'),
    [
        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
    ],
)
class TestSqrtGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.primal = cls.primal.astype(cls.dtype)
        cls.cotangent = cls.cotangent.astype(cls.dtype)

    def setUp(self):
        paddle.enable_static()

    def tearDown(self):
        paddle.disable_static()

    def test_sqrt_grad_comp(self):
        def actual(primal, cotangent):
            paddle.disable_static()
            # stop_gradient=False already marks x as differentiable.
            x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
            v = paddle.to_tensor(
                cotangent, dtype='float32', stop_gradient=False
            )
            y = paddle.sqrt(x)
            return paddle.grad(y, x, v, create_graph=True, retain_graph=True)[0]

        def desired(primal, cotangent):
            # Reference VJP from autograd: cotangent * 0.5 / sqrt(primal).
            return autograd.make_vjp(autograd.numpy.sqrt)(primal)[0](cotangent)

        np.testing.assert_allclose(
            actual=actual(self.primal, self.cotangent),
            desired=desired(self.primal, self.cotangent),
            rtol=1e-6,
            atol=0,
        )
        core.set_prim_enabled(False)


if __name__ == '__main__':
    unittest.main()
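Both tests use the HIPS autograd package as the reference: autograd.make_vjp(f)(x) returns a pair (vjp_fn, f(x)), so indexing [0] selects the VJP function, and applying it to the cotangent yields cotangent * 0.5 / sqrt(primal). A standalone check of that convention, independent of Paddle:

import autograd
import autograd.numpy as anp
import numpy as np

primal = np.random.rand(2, 2)
cotangent = np.random.rand(2, 2)

# make_vjp(f)(x) returns (vjp_fn, f(x)); [0] selects the VJP function.
vjp_fn, _ = autograd.make_vjp(anp.sqrt)(primal)
np.testing.assert_allclose(vjp_fn(cotangent), cotangent * 0.5 / np.sqrt(primal))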
@@ -15,3 +15,4 @@ set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sqrt_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import autograd
import autograd.numpy
import numpy as np
import parameterized as param

import paddle
from paddle.fluid import core

core.set_prim_enabled(True)


@param.parameterized_class(
    ('primal', 'cotangent', 'dtype'),
    [
        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
    ],
)
class TestSqrtGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.primal = cls.primal.astype(cls.dtype)
        cls.cotangent = cls.cotangent.astype(cls.dtype)

    def setUp(self):
        paddle.enable_static()

    def tearDown(self):
        paddle.disable_static()

    def test_sqrt_grad_comp(self):
        def actual(primal, cotangent):
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal', primal.shape, primal.dtype)
                x.stop_gradient = False
                v = paddle.static.data(
                    'cotangent', cotangent.shape, cotangent.dtype
                )
                y = paddle.sqrt(x)
                x_cotangent = paddle.static.gradients(y, x, v)
            exe = paddle.static.Executor()
            exe.run(sp)
            return exe.run(
                program=mp,
                feed={'primal': primal, 'cotangent': cotangent},
                fetch_list=[x_cotangent[0].name],
            )[0]

        def desired(primal, cotangent):
            # Reference VJP from autograd: cotangent * 0.5 / sqrt(primal).
            return autograd.make_vjp(autograd.numpy.sqrt)(primal)[0](cotangent)

        np.testing.assert_allclose(
            actual=actual(self.primal, self.cotangent),
            desired=desired(self.primal, self.cotangent),
            rtol=1e-6,
            atol=0,
        )
        core.set_prim_enabled(False)


if __name__ == '__main__':
    unittest.main()
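One way to observe the composite rule taking effect in static mode is to inspect the ops of the built program: with prim enabled, the backward section should contain the decomposition's primitives rather than a single fused sqrt_grad op. An illustrative inspection sketch (the exact op names printed depend on the Paddle version):

import paddle
from paddle.fluid import core

core.set_prim_enabled(True)
paddle.enable_static()

mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    x = paddle.static.data('x', [4, 4], 'float32')
    x.stop_gradient = False
    y = paddle.sqrt(x)
    paddle.static.gradients(y, x)

# With the composite rule active, expect fill/divide/multiply primitives
# in the backward section instead of a fused sqrt_grad op.
print([op.type for op in mp.block(0).ops])

paddle.disable_static()
core.set_prim_enabled(False)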