diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 0343276b755bce29999e90b129709c7b7365fb22..115b5b9ef205d6a043f13337059e70c68cb47a94 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -128,7 +128,7 @@ REGISTER_OPERATOR( ops::ElementwiseDivDoubleGradMaker); REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad, - ops::ElementwiseDivDoubleGradOpInplace); + ops::ElementwiseDoubleGradOpInplace); REGISTER_OP_CPU_KERNEL( elementwise_div, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 8bb1dc4708f2650fede8ffa7fc49476f132cd83a..53793571db67a33071e4096084b0d72534ca7a78 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -246,7 +246,5 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel { } }; -DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"}); - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 87dc15fadcc7aed7e6b84c88badbe5671faaf194..aaa6cfe034616a63e574186a1573f32075f52d7d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -127,7 +127,7 @@ REGISTER_OPERATOR( ops::ElementwiseMulDoubleGradMaker); REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad, - ops::ElementwiseMulDoubleGradOpInplace); + ops::ElementwiseDoubleGradOpInplace); REGISTER_OP_CPU_KERNEL( elementwise_mul, diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index cc3226135e34d36abf3dce3d266b7d5468353a49..502da88cf04e6f92a8f6872b7f7f431efe4cf58a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -172,34 +172,50 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel { // (2) dy = dout * ddx // (3) ddout = ddx * y // (4) ddout = ddout + dx - // (5) dx = dout *ddy + // (5) dx = dout * ddy if (ddout) { - // use dx to save memory, other than alloc tmp tensor - Tensor* ddout_tmp = dx; - - default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); int axis = ctx.Attr("axis"); - // NOTE: in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, - MulGradDX(), MulGradDY()); - default_elementwise_mul(ctx, &ddx_safe, y, ddout); - auto& place = *ctx.template device_context().eigen_device(); - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - default_elementwise_mul(ctx, dout, &ddy_safe, dx); + // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace + if (ddout->numel() > ddx->numel()) { + ElemwiseGradCompute, MulGradDY>( + ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX(), + MulGradDY()); + + Tensor ddout_tmp; + ddout_tmp.mutable_data(ddout->dims(), ctx.GetPlace()); + + default_elementwise_mul(ctx, y, &ddx_safe, ddout); + default_elementwise_mul(ctx, &ddy_safe, x, + &ddout_tmp); + + auto ddout_t = framework::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = framework::EigenVector::Flatten(ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + } else { + // use dx to save memory, other than alloc tmp tensor + Tensor* ddout_tmp = dx; + + default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); + // NOTE: in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + ElemwiseGradCompute, MulGradDY>( + ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, + MulGradDX(), MulGradDY()); + default_elementwise_mul(ctx, &ddx_safe, y, ddout); + + auto ddout_t = framework::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + default_elementwise_mul(ctx, dout, &ddy_safe, dx); + } } } }; -DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"}); - } // namespace operators } // namespace paddle