From 10b23a72c1ab1a54942bceb910bf9a4a7dfaeb33 Mon Sep 17 00:00:00 2001 From: lvmengsi Date: Mon, 20 May 2019 10:57:50 +0800 Subject: [PATCH] Double backward elementwise div (#17416) * double backward, elementwise_div * fix dx empty. test=develop * bug fix (#17392) fix secure bug * Eanble stack operator for a Ngraph, test=develop (#17406) * fix sqrt_grad_grad unittest. test=develop (#17410) * fix sqrt_grad_grad unittest. test=develop * disable sqrt_grad_grad unittest. test=develop * test=develop, fix unittest * test=develop, fix unittest * test=develop, fix unittest * test=develop, fix bug * fix unittest. test=develop * fix unittest dx. test=develop * tmp fix! for test... test=develop * reduce tmp, test=develop * test=develop, reduce tmp * fix broadcast unittest. test=develop * fix format. test=develop * refine code. test=develop * refine code. test=develop * refine GetDoubleGradSafeTensor. test=develop * fix format. test=develop --- .../elementwise/elementwise_div_op.cc | 40 +++++- .../elementwise/elementwise_div_op.cu | 10 ++ .../elementwise/elementwise_div_op.h | 116 ++++++++++++++++++ .../elementwise/elementwise_op_function.h | 3 +- .../fluid/tests/unittests/test_nn_grad.py | 56 +++++++++ 5 files changed, 223 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 530a54b7ca1..6689823d4a2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -44,6 +44,31 @@ class ElementwiseDivGradOpDescMaker : public framework::SingleGradOpDescMaker { } }; +class ElementwiseDivDoubleGradDescMaker + : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("elementwise_div_grad_grad"); + op->SetInput("Y", Input("Y")); + op->SetInput("Out", Input("Out")); + op->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); + op->SetInput("DDY", OutputGrad(framework::GradVarName("Y"))); + op->SetInput("DX", Output(framework::GradVarName("X"))); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("Y"), InputGrad("Y")); + op->SetOutput("DOut", InputGrad("Out")); + op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + + return op; + } +}; + } // namespace operators } // namespace paddle @@ -53,7 +78,9 @@ REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseDivGradOpDescMaker); -REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad); +REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad, + ops::ElementwiseDivDoubleGradDescMaker); +REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad); REGISTER_OP_CPU_KERNEL( elementwise_div, @@ -67,3 +94,14 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); + +REGISTER_OP_CPU_KERNEL( + elementwise_div_grad_grad, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index ae669f55254..b38f84845b7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -33,3 +33,13 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_div_grad_grad, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel, + ops::ElementwiseDivDoubleGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 0f0ad863730..c604c9017ec 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -14,8 +14,13 @@ limitations under the License. */ #pragma once +#include +#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + namespace paddle { namespace operators { @@ -51,6 +56,13 @@ struct DivGradDY { } }; +template +struct DivDoubleDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return y * out * dout - x * dout; + } +}; + template class ElementwiseDivGradKernel : public ElemwiseGradKernel { public: @@ -72,5 +84,109 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel { } }; +class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + using Tensor = framework::Tensor; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput("DOut")) { + ctx->ShareDim("DX", "DOut"); + ctx->ShareLoD("DX", "DOut"); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->ShareDim("Y", y_grad_name); + ctx->ShareLoD("Y", y_grad_name); + } + if (ctx->HasOutput("DDOut")) { + ctx->ShareDim("DX", "DDOut"); + ctx->ShareLoD("DX", "DDOut"); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = ctx.Input("DDX")->type(); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +template +class ElementwiseDivDoubleGradKernel : public framework::OpKernel { + using Tensor = framework::Tensor; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* Y = ctx.Input("Y"); + auto* Out = ctx.Input("Out"); + auto* ddX = ctx.Input("DDX"); + auto* ddY = ctx.Input("DDY"); + auto* dX = ctx.Input("DX"); + + auto* dY = ctx.Output(framework::GradVarName("Y")); + auto* dOut = ctx.Output("DOut"); + auto* ddOut = ctx.Output("DDOut"); + + int axis = ctx.Attr("axis"); + + if (dY) dY->mutable_data(Y->dims(), ctx.GetPlace()); + if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); + if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); + + // ddX_safe == null ? 0 : ddX + // ddY_safe == null ? 0 : ddY + Tensor ddX_safe, ddY_safe; + GetDoubleGradSafeTensor(ctx, Out, ddX, &ddX_safe); + GetDoubleGradSafeTensor(ctx, Y, ddY, &ddY_safe); + + if (dOut) { + // dOut = - dX * ddY + default_elementwise_mul(ctx, dX, &ddY_safe, dOut); + auto& place = + *ctx.template device_context().eigen_device(); + auto dout = framework::EigenVector::Flatten(*dOut); + dout.device(place) = static_cast(-1) * dout; + } + + if (dY) { + // dX_div_Y = dX / Y; + auto& dev_ctx = ctx.template device_context(); + Tensor dX_div_Y = + ctx.AllocateTmpTensor(Out->dims(), dev_ctx); + ElementwiseComputeEx, DeviceContext, T>( + ctx, dX, Y, axis, DivFunctor(), &dX_div_Y); + + // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + + // dY = Out * dX * ddY / Y - dX * ddX / Y + ElemwiseGradCompute, DivDoubleDY>( + ctx, ddX_safe, ddY_safe, *Out, dX_div_Y, axis, nullptr, dY, + DivGradDX(), DivDoubleDY()); + } + + if (ddOut) { + // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y + default_elementwise_mul(ctx, Out, &ddY_safe, ddOut); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &ddX_safe, ddOut, 0, SubFunctor(), ddOut); + ElementwiseComputeEx, DeviceContext, T>( + ctx, ddOut, Y, axis, DivFunctor(), ddOut); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 54678585086..ad9d0b2a0d2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -1644,7 +1644,8 @@ static inline void GetDoubleGradSafeTensor( if (ddx) { *ddx_safe = *ddx; } else { - ddx_safe->mutable_data(x->dims(), ctx.GetPlace()); + auto &dev_ctx = ctx.template device_context(); + *ddx_safe = ctx.AllocateTmpTensor(x->dims(), dev_ctx); math::SetConstant set_zero; set_zero(ctx.template device_context(), ddx_safe, static_cast(0)); diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 8dd1d294230..558a21c1f02 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -378,5 +378,61 @@ class TestMulDoubleGradCheck(unittest.TestCase): self.func(p) +class TestElementwiseDivDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + # the shape of input variable shoule be clearly specified, not inlcude -1. + shape = [2, 3, 7, 9] + eps = 0.0001 + dtype = np.float64 + + x = layers.data('x', shape, False, dtype) + y = layers.data('y', shape, False, dtype) + x.persistable = True + y.persistable = True + out = layers.elementwise_div(x, y, axis=0) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, shape).astype(dtype) + y_arr[np.abs(y_arr) < 0.005] = 0.02 + + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestElementwiseDivBroadcastDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + # the shape of input variable shoule be clearly specified, not inlcude -1. + shape = [2, 3, 7, 9] + eps = 0.0001 + dtype = np.float64 + + x = layers.data('x', shape, False, dtype) + y = layers.data('y', shape[1:-1], False, dtype) + x.persistable = True + y.persistable = True + out = layers.elementwise_div(x, y, axis=1) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, shape[1:-1]).astype(dtype) + y_arr[np.abs(y_arr) < 0.005] = 0.02 + + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + if __name__ == "__main__": unittest.main() -- GitLab