提交 6fc3e8ec 编写于 作者: D danleifeng 提交者: gongweibao

edit elementwise_mul doublegrad inplace (#21245)

上级 508b898d
...@@ -128,7 +128,7 @@ REGISTER_OPERATOR( ...@@ -128,7 +128,7 @@ REGISTER_OPERATOR(
ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad, REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
ops::ElementwiseDivDoubleGradOpInplace); ops::ElementwiseDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div, elementwise_div,
......
...@@ -246,7 +246,5 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> { ...@@ -246,7 +246,5 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
} }
}; };
DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -127,7 +127,7 @@ REGISTER_OPERATOR( ...@@ -127,7 +127,7 @@ REGISTER_OPERATOR(
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad, REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseMulDoubleGradOpInplace); ops::ElementwiseDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
......
...@@ -172,13 +172,32 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -172,13 +172,32 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
// (2) dy = dout * ddx // (2) dy = dout * ddx
// (3) ddout = ddx * y // (3) ddout = ddx * y
// (4) ddout = ddout + dx // (4) ddout = ddout + dx
// (5) dx = dout *ddy // (5) dx = dout * ddy
if (ddout) { if (ddout) {
int axis = ctx.Attr<int>("axis");
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
// size(ddout) > size(ddx), ddout can't use memory of ddx using inplace
if (ddout->numel() > ddx->numel()) {
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
MulGradDY<T>());
Tensor ddout_tmp;
ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, y, &ddx_safe, ddout);
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, x,
&ddout_tmp);
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
// use dx to save memory, other than alloc tmp tensor // use dx to save memory, other than alloc tmp tensor
Tensor* ddout_tmp = dx; Tensor* ddout_tmp = dx;
default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp); default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
int axis = ctx.Attr<int>("axis");
// NOTE: in the following ElemwiseGradCompute, for the // NOTE: in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first // first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not // output tensor will not be activated, DivGradDx function will not
...@@ -189,17 +208,14 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -189,17 +208,14 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
MulGradDX<T>(), MulGradDY<T>()); MulGradDX<T>(), MulGradDY<T>());
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout); default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout); auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp); auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t; ddout_t.device(place) = ddout_t + ddout_tmp_t;
default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx); default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
} }
} }
}
}; };
DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册