提交 4ef63101 编写于 作者: L lvmengsi 提交者: Kaipeng Deng

Double backward sqrt (#17387)

* double backward sqrt

* refine unittest. test=develop

* refine test. test=develop

* remove alpha in unittest. test=develop
上级 829fcc98
......@@ -681,6 +681,26 @@ class LeakyReluDoubleGradMaker
}
};
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("sqrt_grad_grad");
op->SetInput("Out", Input("Out"));
op->SetInput("DX", Output(framework::GradVarName("X")));
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
op->SetOutput("DOut", InputGrad("Out"));
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
......@@ -794,6 +814,27 @@ REGISTER_OP_CPU_KERNEL(
plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== square register ============================ */
REGISTER_OPERATOR(
square, ops::ActivationOp, ops::SquareOpMaker,
......
......@@ -60,6 +60,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
sqrt_grad_grad,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(square, Square, SquareFunctor,
SquareGradFunctor);
......
......@@ -1359,6 +1359,28 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......@@ -1433,8 +1455,8 @@ class SquareDoubleGradKernel
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
dX->mutable_data<T>(X->dims(), ctx.GetPlace());
ddOut->mutable_data<T>(ctx.GetPlace());
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
......@@ -1443,6 +1465,61 @@ class SquareDoubleGradKernel
}
};
template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE(ddx_var != nullptr,
"Cannot get input Variable DDX, variable name = %s",
ctx.op().Input("DDX"));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE(ddX != nullptr,
"Cannot get input Variable DDX, variable name = %s",
ctx.op().Input("DDX"));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
ctx.op().Input("Out"));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE(dx_var != nullptr,
"Cannot get input Variable DX, variable name = %s",
ctx.op().Input("DX"));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
} // namespace operators
} // namespace paddle
......@@ -1454,7 +1531,6 @@ class SquareDoubleGradKernel
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
__macro(abs, Abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
......
......@@ -88,7 +88,33 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
places = [fluid.CUDAPlace(0)]
for p in places:
self.func(p)
class TestSqrtDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [7, 9]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.sqrt(x)
x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps, atol=1e-3)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]
for p in places:
self.func(p)
class TestConvDoubleGradCheck(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册