diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index b20d9c46557cfeb157e53b57a89cbf5f073c6f2c..8e38d5787bdadf4abd224386b3a9b03544a262e7 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -681,6 +681,26 @@ class LeakyReluDoubleGradMaker
   }
 };
 
+// sqrt Grad: dx = 0.5 * dy / y
+// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
+class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
+ public:
+  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
+    auto* op = new ::paddle::framework::OpDesc();
+    op->SetType("sqrt_grad_grad");
+    op->SetInput("Out", Input("Out"));
+    op->SetInput("DX", Output(framework::GradVarName("X")));
+    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(Attrs());
+    op->SetOutput("DOut", InputGrad("Out"));
+    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+  }
+};
+
 // square Grad: dx=2x*dy
 // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
 class SquareDoubleGradMaker
@@ -794,6 +814,27 @@ REGISTER_OP_CPU_KERNEL(
                                     plat::CPUDeviceContext,
                                     ops::LeakyReluGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
+/* ===========================   sqrt register  ============================= */
+REGISTER_OPERATOR(
+    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
+    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
+    paddle::framework::SingleOpInplaceInToOut);
+REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
+                  paddle::framework::SingleOpInplaceInToOut,
+                  ops::SqrtDoubleGradMaker);
+REGISTER_OPERATOR(
+    sqrt_grad_grad,
+    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
+REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
+REGISTER_OP_CPU_KERNEL(
+    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
+                                              ops::SqrtGradGradFunctor<float>>,
+    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
+                              ops::SqrtGradGradFunctor<double>>,
+    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
+                              ops::SqrtGradGradFunctor<plat::float16>>);
+/* ========================================================================== */
+
 /* ==========================   square register  ============================ */
 REGISTER_OPERATOR(
     square, ops::ActivationOp, ops::SquareOpMaker,
diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 63c4f0a887e2e565f71eb667f93ab399fa5a630c..25514186de9e424a46131e5b238b215c22911b38 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -60,6 +60,19 @@ REGISTER_OP_CUDA_KERNEL(
                                     ops::ReluGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
 
+/* ===========================   sqrt register  ============================= */
+REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
+
+REGISTER_OP_CUDA_KERNEL(
+    sqrt_grad_grad,
+    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                              ops::SqrtGradGradFunctor<float>>,
+    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                              ops::SqrtGradGradFunctor<double>>,
+    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                              ops::SqrtGradGradFunctor<plat::float16>>);
+/* ========================================================================== */
+
 /* ===========================   square register  ============================ */
 REGISTER_ACTIVATION_CUDA_KERNEL(square, Square, SquareFunctor,
                                 SquareGradFunctor);
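Before the header changes below, a note on the math being implemented. The Grad/GradGrad formulas quoted in the SqrtDoubleGradMaker comment follow from differentiating the first-order gradient of sqrt. A short derivation in my own notation (not part of the patch), where ddx is the incoming gradient with respect to dx, DDOut carries the gradient propagated back to dy, and DOut carries the extra gradient propagated to the forward output y:

    y = \sqrt{x}, \qquad
    dx = \frac{\partial y}{\partial x}\, dy = \frac{dy}{2\sqrt{x}} = \frac{0.5\, dy}{y}

    DDOut = \frac{\partial\, dx}{\partial\, dy}\, ddx = \frac{0.5\, ddx}{y}, \qquad
    DOut = \frac{\partial\, dx}{\partial y}\, ddx = -\frac{0.5\, dy}{y^{2}}\, ddx = -\frac{dx\, ddx}{y}

Because everything is expressed in terms of Out, DX, and DDX, the double-grad op never needs the forward input X, which is why SqrtGradGradFunctor below reports FwdDeps() == kDepOut.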
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index f2eee754b4e7abe3e73acdb560a3c08e8f58cef2..5a4fb0828a7328d40490006f89a89581cf1fb3df 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1359,6 +1359,28 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+template <typename T>
+struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev, const framework::Tensor* Out,
+                  const framework::Tensor* ddX, framework::Tensor* ddOut,
+                  framework::Tensor* dOut, const framework::Tensor* dX) const {
+    auto* d = dev.eigen_device();
+    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
+    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    if (ddOut) {
+      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
+    }
+    if (dOut) {
+      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
+      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
 template <typename T>
 struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
@@ -1433,8 +1455,8 @@ class SquareDoubleGradKernel
 
     ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
 
-    dX->mutable_data<T>(X->dims(), ctx.GetPlace());
-    ddOut->mutable_data<T>(ctx.GetPlace());
+    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
+    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
 
     auto& place = ctx.template device_context<DeviceContext>();
 
@@ -1443,6 +1465,61 @@ class SquareDoubleGradKernel
   }
 };
 
+template <typename DeviceContext, typename Functor>
+class SqrtDoubleGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  using T = typename Functor::ELEMENT_TYPE;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor *Out, *dX, *ddX;
+    Out = dX = ddX = nullptr;
+    framework::Tensor *ddOut, *dOut;
+    ddOut = dOut = nullptr;
+
+    // extract ddx(input), ddout(output)
+    auto ddx_var = ctx.InputVar("DDX");
+    auto ddo_var = ctx.OutputVar("DDOut");
+    PADDLE_ENFORCE(ddx_var != nullptr,
+                   "Cannot get input Variable DDX, variable name = %s",
+                   ctx.op().Input("DDX"));
+    ddX = ctx.Input<framework::Tensor>("DDX");
+    if (ddo_var) {
+      ddOut = ctx.Output<framework::Tensor>("DDOut");
+    }
+    PADDLE_ENFORCE(ddX != nullptr,
+                   "Cannot get input Variable DDX, variable name = %s",
+                   ctx.op().Input("DDX"));
+
+    // extract out(input), dout(output)
+    auto out_var = ctx.InputVar("Out");
+    PADDLE_ENFORCE(out_var != nullptr,
+                   "Cannot get input Variable Out, variable name = %s",
+                   ctx.op().Input("Out"));
+    auto dout_var = ctx.OutputVar("DOut");
+    Out = ctx.Input<framework::Tensor>("Out");
+    if (dout_var) {
+      dOut = ctx.Output<framework::Tensor>("DOut");
+    }
+
+    // extract dx(input)
+    auto dx_var = ctx.InputVar("DX");
+    PADDLE_ENFORCE(dx_var != nullptr,
+                   "Cannot get input Variable DX, variable name = %s",
+                   ctx.op().Input("DX"));
+    if (dx_var) {
+      dX = ctx.Input<framework::Tensor>("DX");
+    }
+
+    if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
+    if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
+
+    auto& place = ctx.template device_context<DeviceContext>();
+
+    Functor functor;
+    functor(place, Out, ddX, ddOut, dOut, dX);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -1454,7 +1531,6 @@ class SquareDoubleGradKernel
   __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor);                         \
   __macro(atan, Atan, AtanFunctor, AtanGradFunctor);                         \
   __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
-  __macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);                         \
   __macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);                     \
   __macro(abs, Abs, AbsFunctor, AbsGradFunctor);                             \
   __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                         \
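As a quick numerical cross-check of the two expressions in SqrtGradGradFunctor above, the following NumPy sketch (illustrative only, not part of the patch; all variable names are mine) reproduces them and compares both terms against central finite differences:

import numpy as np

rng = np.random.RandomState(0)
x = rng.uniform(0.1, 1.0, size=(7, 9))      # forward input, kept away from 0
dy = rng.uniform(-1.0, 1.0, size=(7, 9))    # incoming grad w.r.t. y
ddx = rng.uniform(-1.0, 1.0, size=(7, 9))   # incoming grad w.r.t. dx

out = np.sqrt(x)
dx = 0.5 * dy / out                         # first-order grad (SqrtGradFunctor)

ddout = 0.5 * ddx / out                     # DDOut = d(dx)/d(dy) * ddx
dout = -1.0 * dx * ddx / out                # DOut  = d(dx)/d(out) * ddx

eps = 1e-6
# DDOut: dx is linear in dy, so a finite difference in dy recovers it exactly.
fd_ddout = (0.5 * (dy + eps) / out - 0.5 * (dy - eps) / out) / (2.0 * eps) * ddx
assert np.allclose(ddout, fd_ddout, atol=1e-4)

# DOut: central finite difference of dx with respect to out.
fd_dout = (0.5 * dy / (out + eps) - 0.5 * dy / (out - eps)) / (2.0 * eps) * ddx
assert np.allclose(dout, fd_dout, atol=1e-4)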
diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py
index be0f4b239b6855e0f95f7c399abcfbc8ab64962a..7036eb8f17035fa8e9ae3ad1ba482bd319e5195f 100644
--- a/python/paddle/fluid/tests/unittests/test_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -88,7 +88,33 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
     def test_grad(self):
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
-            places.append(fluid.CUDAPlace(0))
+            places = [fluid.CUDAPlace(0)]
+        for p in places:
+            self.func(p)
+
+
+class TestSqrtDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        shape = [7, 9]
+        eps = 0.005
+        dtype = np.float64
+
+        x = layers.data('x', shape, False, dtype)
+        x.persistable = True
+
+        y = layers.sqrt(x)
+        x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)
+
+        gradient_checker.double_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps, atol=1e-3)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places = [fluid.CUDAPlace(0)]
         for p in places:
             self.func(p)
 
 
 class TestConvDoubleGradCheck(unittest.TestCase):
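End to end, the new sqrt_grad_grad op is what lets the framework differentiate through sqrt_grad. A minimal sketch of that (not part of the patch, and assuming the fluid.gradients() interface available in this Paddle release):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

x = fluid.layers.data('x', shape=[7, 9], append_batch_size=False, dtype='float64')
x.persistable = True
x.stop_gradient = False
y = fluid.layers.sqrt(x)

dx = fluid.gradients(y, x)    # builds sqrt_grad
ddx = fluid.gradients(dx, x)  # builds sqrt_grad_grad, added by this patch

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x_arr = np.random.uniform(0.1, 1, [7, 9]).astype('float64')
ddx_val, = exe.run(feed={'x': x_arr}, fetch_list=ddx)

gradient_checker.double_grad_check() in the test above automates the same construction and compares the analytic second-order gradient against finite differences.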