From 8bae8590ac2df254a0ab5b5857ab7b1d324ffcf1 Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Mon, 13 May 2019 13:05:14 +0800
Subject: [PATCH] add double grad for elementwise_mul op (#17255)

* add double grad for elementwise_mul. test=develop

* remove comment. test=develop

* fix grad sum. test=develop

* fix for axis expand. test=develop

* add test for axis expand. test=develop
---
 .../elementwise/elementwise_mul_op.cc      | 39 +++++++++++++-
 .../elementwise/elementwise_mul_op.cu      |  6 +++
 .../elementwise/elementwise_mul_op.h       | 51 ++++++++++++++++++
 .../operators/elementwise/elementwise_op.h | 37 +++++++++++++
 .../elementwise/elementwise_op_function.h  | 15 ++++++
 .../fluid/tests/unittests/test_nn_grad.py  | 54 +++++++++++++++++++
 6 files changed, 201 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
index d5e3300ac9..0f6af96ff3 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
+#include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 
@@ -43,6 +44,30 @@ class ElementwiseMulOpMaker : public ElementwiseOpMaker {
   virtual std::string GetEquation() const { return "Out = X \\\\odot Y"; }
 };
 
+class ElementwiseMulDoubleGradDescMaker
+    : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("elementwise_mul_grad_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput("DOut", Input(framework::GradVarName("Out")));
+    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
+    op->SetInput("DDY", OutputGrad(framework::GradVarName("Y")));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -50,7 +75,9 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
                   ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType,
                   ops::ElementwiseMulOpGradDescMaker);
-REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
+REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad,
+                  ops::ElementwiseMulDoubleGradDescMaker);
+REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul,
@@ -64,3 +91,13 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_mul_grad_grad,
+    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        float>,
+    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        double>,
+    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        int>,
+    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        int64_t>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index e36cc8f9f2..303070bd19 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -88,3 +88,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_mul_grad_grad,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index 7a7a3989c0..f67c55f310 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -123,5 +123,56 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
         ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
   }
 };
+
+template <typename DeviceContext, typename T>
+class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>("DOut");
+    auto* ddx = ctx.Input<Tensor>("DDX");
+    auto* ddy = ctx.Input<Tensor>("DDY");
+
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* ddout = ctx.Output<Tensor>("DDOut");
+
+    if (ddout) ddout->mutable_data<T>(ctx.GetPlace());
+
+    // dx = dout * ddy
+    // dy = dout * ddx
+    Tensor ddx_safe, ddy_safe;
+    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
+    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
+    int axis = ctx.Attr<int>("axis");
+    ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
+        ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
+        MulGradDY<T>());
+
+    // ddout = ddx * y + x * ddy
+    if (ddout) {
+      if (ddx && ddy) {
+        Tensor ddout_tmp;
+        ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
+
+        default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
+        default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, &ddout_tmp);
+
+        auto& place =
+            *ctx.template device_context<DeviceContext>().eigen_device();
+        auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
+        auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
+        ddout_t.device(place) = ddout_t + ddout_tmp_t;
+      } else {
+        if (ddx) default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
+        if (ddy) default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, ddout);
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 5ec335972a..c6615b635e 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -212,6 +212,43 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  using Tensor = framework::Tensor;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->ShareDim("X", x_grad_name);
+      ctx->ShareLoD("X", x_grad_name);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->ShareDim("Y", y_grad_name);
+      ctx->ShareLoD("Y", y_grad_name);
+    }
+    if (ctx->HasOutput("DDOut")) {
+      ctx->ShareDim("DOut", "DDOut");
+      ctx->ShareLoD("DOut", "DDOut");
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type = ctx.Input<Tensor>("DDX")->type();
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (platform::CanMKLDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 // For Add, Sub op, the X, Out is not needed.
 class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
  public:
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 2e91ec8484..5467858508 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -1636,5 +1636,20 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
     }
   }
 }
+
+template <typename DeviceContext, typename T>
+static inline void GetDoubleGradSafeTensor(
+    const framework::ExecutionContext &ctx, const framework::Tensor *x,
+    const framework::Tensor *ddx, framework::Tensor *ddx_safe) {
+  if (ddx) {
+    *ddx_safe = *ddx;
+  } else {
+    ddx_safe->mutable_data<T>(x->dims(), ctx.GetPlace());
+    math::SetConstant<DeviceContext, T> set_zero;
+    set_zero(ctx.template device_context<DeviceContext>(), ddx_safe,
+             static_cast<T>(0));
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py
index 4b6b43b716..dbfdca0e44 100644
--- a/python/paddle/fluid/tests/unittests/test_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -139,5 +139,59 @@ class TestSquareDoubleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestElementwiseMulDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        # the shape of the input variable should be clearly specified, not include -1.
+        shape = [7, 9]
+        eps = 0.005
+        dtype = np.float64
+
+        x = layers.data('x', shape, False, dtype)
+        y = layers.data('y', shape, False, dtype)
+        x.persistable = True
+        y.persistable = True
+        out = layers.elementwise_mul(x, y)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        gradient_checker.double_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        # the shape of the input variable should be clearly specified, not include -1.
+        shape = [7, 9]
+        eps = 0.005
+        dtype = np.float64
+
+        x = layers.data('x', shape, False, dtype)
+        y = layers.data('y', shape[:-1], False, dtype)
+        x.persistable = True
+        y.persistable = True
+        out = layers.elementwise_mul(x, y, axis=0)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
+
+        gradient_checker.double_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
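
Editor's note, not part of the patch: for Out = X * Y, the new ElementwiseMulDoubleGradKernel computes dx = dout * ddy, dy = dout * ddx and ddout = ddx * y + x * ddy, substituting a zero tensor for a missing DDX/DDY via GetDoubleGradSafeTensor. The short NumPy sketch below only illustrates these relations for same-shape inputs; all names in it are hypothetical and it is not PaddlePaddle API.

# Illustrative sketch of the algebra behind elementwise_mul_grad_grad,
# assuming same-shape inputs (no broadcasting along `axis`).
import numpy as np

rng = np.random.RandomState(0)
shape = (7, 9)
x, y, dout, ddx, ddy = (rng.uniform(-1, 1, shape) for _ in range(5))

dx = dout * ddy            # new grad w.r.t. X produced by the double-grad op
dy = dout * ddx            # new grad w.r.t. Y produced by the double-grad op
ddout = ddx * y + x * ddy  # grad w.r.t. DOut, written to output "DDOut"

# When DDX or DDY is absent, the kernel substitutes zeros
# (GetDoubleGradSafeTensor), so the corresponding term drops out of ddout.
ddout_without_ddy = ddx * y + x * np.zeros(shape)
assert np.allclose(ddout_without_ddy, ddx * y)
print(dx.shape, dy.shape, ddout.shape)

The new TestElementwiseMul*DoubleGradCheck cases rely on gradient_checker.double_grad_check to verify these second-order gradients numerically against the registered elementwise_mul_grad_grad kernels, including the axis=0 broadcast case that the sketch above does not cover.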