From 6fc3e8ec84af9e9fa2f191cbcc402de080962a67 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Wed, 20 Nov 2019 19:45:20 +0800 Subject: [PATCH] edit elementwise_mul doublegrad inplace (#21245) --- .../elementwise/elementwise_div_op.cc | 2 +- .../elementwise/elementwise_div_op.h | 2 - .../elementwise/elementwise_mul_op.cc | 2 +- .../elementwise/elementwise_mul_op.h | 58 ++++++++++++------- 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 0343276b755..115b5b9ef20 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -128,7 +128,7 @@ REGISTER_OPERATOR( ops::ElementwiseDivDoubleGradMaker); REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad, - ops::ElementwiseDivDoubleGradOpInplace); + ops::ElementwiseDoubleGradOpInplace); REGISTER_OP_CPU_KERNEL( elementwise_div, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 8bb1dc4708f..53793571db6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -246,7 +246,5 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel { } }; -DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"}); - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 87dc15fadcc..aaa6cfe0346 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -127,7 +127,7 @@ REGISTER_OPERATOR( ops::ElementwiseMulDoubleGradMaker); REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad, - ops::ElementwiseMulDoubleGradOpInplace); + ops::ElementwiseDoubleGradOpInplace); REGISTER_OP_CPU_KERNEL( elementwise_mul, diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index cc3226135e3..502da88cf04 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -172,34 +172,50 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel { // (2) dy = dout * ddx // (3) ddout = ddx * y // (4) ddout = ddout + dx - // (5) dx = dout *ddy + // (5) dx = dout * ddy if (ddout) { - // use dx to save memory, other than alloc tmp tensor - Tensor* ddout_tmp = dx; - - default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); int axis = ctx.Attr("axis"); - // NOTE: in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, - MulGradDX(), MulGradDY()); - default_elementwise_mul(ctx, &ddx_safe, y, ddout); - auto& place = *ctx.template device_context().eigen_device(); - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - default_elementwise_mul(ctx, dout, &ddy_safe, dx); + // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace + if (ddout->numel() > ddx->numel()) { + ElemwiseGradCompute, MulGradDY>( + ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX(), + MulGradDY()); + + Tensor ddout_tmp; + ddout_tmp.mutable_data(ddout->dims(), ctx.GetPlace()); + + default_elementwise_mul(ctx, y, &ddx_safe, ddout); + default_elementwise_mul(ctx, &ddy_safe, x, + &ddout_tmp); + + auto ddout_t = framework::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = framework::EigenVector::Flatten(ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + } else { + // use dx to save memory, other than alloc tmp tensor + Tensor* ddout_tmp = dx; + + default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); + // NOTE: in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + ElemwiseGradCompute, MulGradDY>( + ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, + MulGradDX(), MulGradDY()); + default_elementwise_mul(ctx, &ddx_safe, y, ddout); + + auto ddout_t = framework::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + default_elementwise_mul(ctx, dout, &ddy_safe, dx); + } } } }; -DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"}); - } // namespace operators } // namespace paddle -- GitLab