未验证 提交 11d3a38f 编写于 作者: K Kaipeng Deng 提交者: GitHub

add double grad for square op (#17173)

* add double grad for square. test=develop

* formax code. test=develop

* fix for grad sum. test=develop

* refine shape. test=develop

* refine extract. test=develop
上级 31536016
...@@ -597,40 +597,31 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); ...@@ -597,40 +597,31 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel { class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput("DOut")) { if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
ctx->ShareDim("Out", "DOut"); if (ctx->HasOutput("DX")) {
ctx->ShareLoD("Out", "DOut"); ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
} }
if (ctx->HasOutput("DDOut")) { if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
ctx->ShareDim("Out", "DDOut"); if (ctx->HasOutput("DOut")) {
ctx->ShareLoD("Out", "DDOut"); ctx->ShareDim("Out", "DOut");
} ctx->ShareLoD("Out", "DOut");
} }
if (ctx->HasOutput("DDOut")) {
protected: ctx->ShareDim("Out", "DDOut");
framework::OpKernelType GetExpectedKernelType( ctx->ShareLoD("Out", "DDOut");
const framework::ExecutionContext& ctx) const override { }
return GetKernelType(ctx, *this, "Out");
}
};
class LeakyReluDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput("DX")) {
ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
} }
} }
...@@ -690,6 +681,33 @@ class LeakyReluDoubleGradMaker ...@@ -690,6 +681,33 @@ class LeakyReluDoubleGradMaker
} }
}; };
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
: public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("square_grad_grad");
op->SetInput("X", Input("X"));
// Out@GRAD: dy
op->SetInput("DOut", Input(framework::GradVarName("Out")));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// X@GRAD: dx
op->SetOutput("DX", InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -727,6 +745,7 @@ namespace plat = paddle::platform; ...@@ -727,6 +745,7 @@ namespace plat = paddle::platform;
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
/* ========================== relu register ============================= */
REGISTER_OPERATOR( REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType, relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
...@@ -734,7 +753,9 @@ REGISTER_OPERATOR( ...@@ -734,7 +753,9 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, paddle::framework::SingleOpInplaceInToOut,
ops::ReluDoubleGradMaker); ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(relu_grad_grad, ops::ActivationOpDoubleGrad); REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
...@@ -746,7 +767,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -746,7 +767,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ReluGradGradFunctor<double>>, ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext, ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<plat::float16>>); ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ======================== leaky relu register ============================ */
REGISTER_OPERATOR( REGISTER_OPERATOR(
leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
ops::ActivationOpInferVarType, ops::ActivationOpInferVarType,
...@@ -755,7 +778,10 @@ REGISTER_OPERATOR( ...@@ -755,7 +778,10 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, paddle::framework::SingleOpInplaceInToOut,
ops::LeakyReluDoubleGradMaker); ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(leaky_relu_grad_grad, ops::LeakyReluDoubleGrad); REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor, REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor); LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
...@@ -766,3 +792,30 @@ REGISTER_OP_CPU_KERNEL( ...@@ -766,3 +792,30 @@ REGISTER_OP_CPU_KERNEL(
ops::LeakyReluGradGradFunctor<double>>, ops::LeakyReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel< ops::ActivationDoubleGradKernel<
plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>); plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== square register ============================ */
REGISTER_OPERATOR(
square, ops::ActivationOp, ops::SquareOpMaker,
ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
SquareGradFunctor);
REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
...@@ -33,6 +33,7 @@ namespace plat = paddle::platform; ...@@ -33,6 +33,7 @@ namespace plat = paddle::platform;
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
/* ======================== leaky relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor, REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor); LeakyReluGradFunctor);
...@@ -44,7 +45,9 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -44,7 +45,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::LeakyReluGradGradFunctor<double>>, ops::LeakyReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel< ops::ActivationDoubleGradKernel<
plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>); plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
...@@ -55,3 +58,18 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -55,3 +58,18 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReluGradGradFunctor<double>>, ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext, ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>); ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(square, Square, SquareFunctor,
SquareGradFunctor);
REGISTER_OP_CUDA_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -1358,6 +1359,90 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1358,6 +1359,90 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
}; };
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
// others. Impliment extraction kernel seperately here.
inline void ExtractDoubleGradTensorWithInputDOut(
const framework::ExecutionContext& ctx, const framework::Tensor** X,
const framework::Tensor** ddX, framework::Tensor** dX,
const framework::Tensor** dOut, framework::Tensor** ddOut) {
// extract ddX(output), ddOut(input)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE(ddx_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
ctx.op().Input("DDX"));
*ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
*ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE(*ddX != nullptr,
"Cannot get output tensor DDX, variable name = %s",
ctx.op().Output("DDX"));
// extract x(input), dx(output)
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
ctx.op().Input("X"));
auto dx_var = ctx.OutputVar("DX");
*X = ctx.Input<framework::Tensor>("X");
if (dx_var) {
*dX = ctx.Output<framework::Tensor>("DX");
}
// extract dOut(input)
auto dout_var = ctx.InputVar("DOut");
if (dout_var) {
*dOut = ctx.Input<framework::Tensor>("DOut");
}
}
template <typename DeviceContext, typename Functor>
class SquareDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
dX->mutable_data<T>(X->dims(), ctx.GetPlace());
ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, X, ddX, ddOut, dOut, dX);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -1381,7 +1466,6 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1381,7 +1466,6 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, Log, LogFunctor, LogGradFunctor); \ __macro(log, Log, LogFunctor, LogGradFunctor); \
__macro(square, Square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \ __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, Pow, PowFunctor, PowGradFunctor); \ __macro(pow, Pow, PowFunctor, PowGradFunctor); \
......
...@@ -115,5 +115,29 @@ class TestConvDoubleGradCheck(unittest.TestCase): ...@@ -115,5 +115,29 @@ class TestConvDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestSquareDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable shoule be clearly specified, not inlcude -1.
shape = [17, 23]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.square(x)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册