Unverified commit 71ab8ae9, authored by W whs, committed by GitHub

Support double backward rsqrt (#29589) (#30431)

Parent ae75affd
@@ -886,6 +886,25 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
template <typename T>
class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("rsqrt_grad_grad");
op->SetInput("Out", this->Input("Out"));
op->SetInput("DX", this->Output(framework::GradVarName("X")));
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DOut", this->InputGrad("Out"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};
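The formulas in the comments above follow from differentiating the first-order gradient dx = -0.5 * dout * y^3 with respect to dout and Out. A minimal standalone SymPy sketch, illustrative rather than part of the patch (out and dout stand in for the op's Out and its incoming gradient), re-derives both double-grad terms:

import sympy as sp

# Re-derive the rsqrt grad / double-grad formulas symbolically (sketch only).
x, out, dout = sp.symbols('x out dout', positive=True)

# forward: out = x**(-1/2); first-order derivative is -0.5*x**(-3/2) = -0.5*out**3
assert sp.simplify(sp.diff(x ** sp.Rational(-1, 2), x)
                   + sp.Rational(1, 2) * (x ** sp.Rational(-1, 2)) ** 3) == 0

dx = -sp.Rational(1, 2) * dout * out ** 3              # rsqrt Grad: dx = -0.5*dy*y^3
# DDOut scales DDX by d(dx)/d(dout) = -0.5*y^3
assert sp.simplify(sp.diff(dx, dout) + sp.Rational(1, 2) * out ** 3) == 0
# DOut scales DDX by d(dx)/d(out) = (3/y)*dx
assert sp.simplify(sp.diff(dx, out) - 3 / out * dx) == 0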
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
template <typename T>
@@ -1157,6 +1176,35 @@ REGISTER_OP_CPU_KERNEL(
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
*/
REGISTER_OPERATOR(
rsqrt, ops::ActivationOp, ops::RsqrtOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(rsqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::RsqrtDoubleGradMaker<paddle::framework::OpDesc>,
ops::RsqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
rsqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<float>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<double>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
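With the grad maker, operators, and CPU kernels above registered (plus the matching CUDA kernels below), rsqrt supports a second backward pass. A minimal usage sketch, assuming Paddle 2.x dynamic-graph mode with paddle.rsqrt and paddle.grad; the second paddle.grad call is the one expected to exercise rsqrt_grad_grad:

import numpy as np
import paddle

# Double-backward usage sketch (assumes Paddle 2.x dygraph; illustrative only).
x = paddle.uniform([4], dtype='float64', min=0.1, max=1.0)
x.stop_gradient = False
y = paddle.rsqrt(x)                               # y = x**(-1/2)
(dx,) = paddle.grad(y, x, create_graph=True)      # dx = -0.5 * x**(-3/2)
(d2x,) = paddle.grad(dx, x)                       # second-order grad
np.testing.assert_allclose(d2x.numpy(), 0.75 * x.numpy() ** -2.5, rtol=1e-6)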
/* ========================== square register ============================ */
REGISTER_OPERATOR(
square, ops::ActivationOp, ops::SquareOpMaker,
......
@@ -85,6 +85,20 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
*/
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<float>>,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<double>>,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_OP_CUDA_KERNEL(
square,
......
@@ -1610,6 +1610,35 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
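A small NumPy finite-difference sketch (standalone, not part of the patch; array names mirror the functor's tensors) that checks the two expressions computed above by perturbing the first-order grad dx(Out, dOut) along DDX:

import numpy as np

# Check the rsqrt double-grad formulas against finite differences (sketch only).
rng = np.random.default_rng(0)
x = rng.uniform(0.1, 1.0, 8)
out = x ** -0.5                         # forward output y
dout = rng.standard_normal(8)           # incoming grad of y
ddx = rng.standard_normal(8)            # double-backward seed (DDX)
dx = -0.5 * dout * out ** 3             # first-order grad (DX)

ddout = -0.5 * ddx * out ** 3           # DDOut, as in the functor
dnew = (3.0 / out) * dx * ddx           # DOut, as in the functor

def dx_fn(o, g):                        # dx as a function of (Out, dOut)
    return -0.5 * g * o ** 3

eps = 1e-6
np.testing.assert_allclose(
    ddout, (dx_fn(out, dout + eps * ddx) - dx_fn(out, dout - eps * ddx)) / (2 * eps),
    rtol=1e-4, atol=1e-8)
np.testing.assert_allclose(
    dnew, (dx_fn(out + eps * ddx, dout) - dx_fn(out - eps * ddx, dout)) / (2 * eps),
    rtol=1e-4, atol=1e-8)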
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -1795,6 +1824,67 @@ class SqrtDoubleGradKernel
}
};
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template <typename DeviceContext, typename Functor>
class RsqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE_NOT_NULL(
dx_var, platform::errors::NotFound(
"Cannot get input Variable DX, variable name = %s",
ctx.InputName("DX")));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
template <typename DeviceContext, typename Functor>
class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
@@ -1938,7 +2028,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(cos, Cos, CosFunctor, CosGradFunctor); \
......
@@ -125,6 +125,30 @@ class TestSqrtDoubleGradCheck(unittest.TestCase):
            self.func(p)


class TestRsqrtDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 7, 9]
        eps = 0.0001
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        x.persistable = True
        y = layers.rsqrt(x)
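        # sample x away from 0: rsqrt(x) = x**-0.5 and its gradients blow up as x -> 0+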
        x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places = [fluid.CUDAPlace(0)]
        for p in places:
            self.func(p)


class TestSquareDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
......