From 71ab8ae9870a43658129c89d188ca233586154c5 Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 15 Jan 2021 12:50:50 +0800 Subject: [PATCH] Support double backward rsqrt (#29589) (#30431) --- paddle/fluid/operators/activation_op.cc | 48 ++++++++++ paddle/fluid/operators/activation_op.cu | 14 +++ paddle/fluid/operators/activation_op.h | 91 ++++++++++++++++++- .../unittests/test_activation_nn_grad.py | 24 +++++ 4 files changed, 176 insertions(+), 1 deletion(-) mode change 100755 => 100644 paddle/fluid/operators/activation_op.h diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 26b4ed71e00..81146854285 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -886,6 +886,25 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { } }; +// rsqrt Grad: dx = -0.5 * dy * y * y * y +// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * ddx +template +class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { + public: + using ::paddle::framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("rsqrt_grad_grad"); + op->SetInput("Out", this->Input("Out")); + op->SetInput("DX", this->Output(framework::GradVarName("X"))); + op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); + op->SetAttrMap(this->Attrs()); + op->SetOutput("DOut", this->InputGrad("Out")); + op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); + } +}; + // square Grad: dx=2x*dy // square GradGrad: ddy=2x*ddx, dx=2dy*ddx template @@ -1157,6 +1176,35 @@ REGISTER_OP_CPU_KERNEL( ops::SqrtGradGradFunctor>); /* ========================================================================== */ +/* =========================== rsqrt register ============================= + */ +REGISTER_OPERATOR( + rsqrt, ops::ActivationOp, ops::RsqrtOpMaker, ops::ActivationOpInferVarType, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::framework::OpDesc>, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::imperative::OpBase>, + ops::ActFwdInplaceInferer); +REGISTER_OPERATOR(rsqrt_grad, ops::ActivationOpGrad, + ops::ActivationGradOpInplaceInferer, + ops::RsqrtDoubleGradMaker, + ops::RsqrtDoubleGradMaker); +REGISTER_OPERATOR( + rsqrt_grad_grad, + ops::ActivationOpDoubleGrad::FwdDeps()>, + ops::ActivationDoubleGradOpInplaceInferer); + +REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); +REGISTER_OP_CPU_KERNEL( + rsqrt_grad_grad, + ops::RsqrtDoubleGradKernel>, + ops::RsqrtDoubleGradKernel>, + ops::RsqrtDoubleGradKernel>); +/* ========================================================================== */ + /* ========================== square register ============================ */ REGISTER_OPERATOR( square, ops::ActivationOp, ops::SquareOpMaker, diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 839776ad58d..1a6d5de18ec 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -85,6 +85,20 @@ REGISTER_OP_CUDA_KERNEL( ops::SqrtGradGradFunctor>); /* ========================================================================== */ +/* =========================== rsqrt register ============================= + */ +REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); + +REGISTER_OP_CUDA_KERNEL( + rsqrt_grad_grad, + ops::RsqrtDoubleGradKernel>, + ops::RsqrtDoubleGradKernel>, + ops::RsqrtDoubleGradKernel>); +/* ========================================================================== */ + /* =========================== square register ============================ */ REGISTER_OP_CUDA_KERNEL( square, diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h old mode 100755 new mode 100644 index f220fe878bf..065c070be7e --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1610,6 +1610,35 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +template +struct RsqrtGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, const framework::Tensor* Out, + const framework::Tensor* ddX, framework::Tensor* ddOut, + framework::Tensor* dOut, const framework::Tensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad")); + auto out = framework::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad")); + + // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx + if (dOut) { + auto dx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad")); + auto dout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad")); + dout.device(*d) = (static_cast(3.0) / out) * dx * ddx; + } + if (ddOut) { + auto ddout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad")); + ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; + } + } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + template struct SquareGradGradFunctor : public BaseActivationFunctor { template @@ -1795,6 +1824,67 @@ class SqrtDoubleGradKernel } }; +// rsqrt Grad: dx = -0.5 * dy * y * y * y +// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx +template +class RsqrtDoubleGradKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor *Out, *dX, *ddX; + Out = dX = ddX = nullptr; + framework::Tensor *ddOut, *dOut; + ddOut = dOut = nullptr; + + // extract ddx(input), ddout(output) + auto ddx_var = ctx.InputVar("DDX"); + auto ddo_var = ctx.OutputVar("DDOut"); + PADDLE_ENFORCE_NOT_NULL( + ddx_var, platform::errors::NotFound( + "Cannot get input Variable DDX, variable name = %s", + ctx.InputName("DDX"))); + ddX = ctx.Input("DDX"); + if (ddo_var) { + ddOut = ctx.Output("DDOut"); + } + PADDLE_ENFORCE_NOT_NULL( + ddX, platform::errors::NotFound( + "Cannot get input Variable DDX, variable name = %s", + ctx.InputName("DDX"))); + + // extract out(input), dout(output) + auto out_var = ctx.InputVar("Out"); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "Cannot get input Variable Out, variable name = %s", + ctx.InputName("Out"))); + auto dout_var = ctx.OutputVar("DOut"); + Out = ctx.Input("Out"); + if (dout_var) { + dOut = ctx.Output("DOut"); + } + + // extract dx(input) + auto dx_var = ctx.InputVar("DX"); + PADDLE_ENFORCE_NOT_NULL( + dx_var, platform::errors::NotFound( + "Cannot get input Variable DX, variable name = %s", + ctx.InputName("DX"))); + if (dx_var) { + dX = ctx.Input("DX"); + } + + if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); + if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); + + auto& place = ctx.template device_context(); + + Functor functor; + functor(place, Out, ddX, ddOut, dOut, dX); + } +}; + template class PowKernel : public framework::OpKernel { public: @@ -1938,7 +2028,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor { __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \ __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ - __macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(cos, Cos, CosFunctor, CosGradFunctor); \ diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py index b663f0ffc2d..2aad90b3dda 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py @@ -125,6 +125,30 @@ class TestSqrtDoubleGradCheck(unittest.TestCase): self.func(p) +class TestRsqrtDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 3, 7, 9] + eps = 0.0001 + dtype = np.float64 + + x = layers.data('x', shape, False, dtype) + x.persistable = True + + y = layers.rsqrt(x) + x_arr = np.random.uniform(0.1, 1, shape).astype(dtype) + + gradient_checker.double_grad_check( + [x], y, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places = [fluid.CUDAPlace(0)] + for p in places: + self.func(p) + + class TestSquareDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): -- GitLab