Commit 6fc3e8ec authored by danleifeng, committed by gongweibao

edit elementwise_mul doublegrad inplace (#21245)

Parent 508b898d
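Background for the diff below: for elementwise multiplication out = x ⊙ y, the kernel's own comments state the second-order gradients that both code paths compute. In formula form,

\begin{aligned}
ddout &= ddx \odot y + x \odot ddy,\\
dx &= dout \odot ddy,\\
dy &= dout \odot ddx,
\end{aligned}

where ddx and ddy are the incoming gradients of dx and dy. The commit reworks which of these products are allowed to share buffers.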
@@ -128,7 +128,7 @@ REGISTER_OPERATOR(
ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
-                  ops::ElementwiseDivDoubleGradOpInplace);
+                  ops::ElementwiseDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
@@ -246,7 +246,5 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
}
};
-DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"});
} // namespace operators
} // namespace paddle
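Both grad_grad registrations in this commit now point at one shared inferer instead of the per-op ones deleted here and below. A minimal sketch of the shared declaration, assuming it lives in a common header such as elementwise_op.h (the removed per-op variants show the exact macro form):

    DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});

The pair {"DDX", "DDOut"} tells the framework that the DDOut output may reuse DDX's buffer when the kernel decides in-place execution is safe.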
@@ -127,7 +127,7 @@ REGISTER_OPERATOR(
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
-                  ops::ElementwiseMulDoubleGradOpInplace);
+                  ops::ElementwiseDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
@@ -172,13 +172,32 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
-    // (5) dx = dout * ddy
+    // (5) dx = dout * ddy
if (ddout) {
+      int axis = ctx.Attr<int>("axis");
+      auto& place =
+          *ctx.template device_context<DeviceContext>().eigen_device();
+      // size(ddout) > size(ddx): ddout cannot reuse ddx's memory in-place
+      if (ddout->numel() > ddx->numel()) {
+        ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
+            ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy,
+            MulGradDX<T>(), MulGradDY<T>());
+        Tensor ddout_tmp;
+        ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
+        default_elementwise_mul<DeviceContext, T>(ctx, y, &ddx_safe, ddout);
+        default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, x,
+                                                  &ddout_tmp);
+        auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
+        auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
+        ddout_t.device(place) = ddout_t + ddout_tmp_t;
+      } else {
        // use dx to save memory, rather than allocating a tmp tensor
Tensor* ddout_tmp = dx;
default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
-        int axis = ctx.Attr<int>("axis");
        // NOTE: in the following ElemwiseGradCompute, since the
        // first output tensor is nullptr, the branch that calculates the first
        // output tensor will not be activated, so MulGradDX will not
@@ -189,17 +208,14 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
MulGradDX<T>(), MulGradDY<T>());
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
-        auto& place =
-            *ctx.template device_context<DeviceContext>().eigen_device();
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
+      }
}
}
};
-DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"});
} // namespace operators
} // namespace paddle
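The new first branch guards the case where ddout (shaped like the broadcast output) has more elements than ddx (shaped like X), so DDOut cannot reuse DDX's buffer and a temporary is allocated instead. To make the reordered sequence concrete, here is a standalone sketch in plain C++ (not Paddle code; the tensor values are invented for illustration). It checks that the memory-saving steps (1)-(5) from the kernel's comments reproduce the naive double-grad formulas while dx doubles as the temporary buffer for x * ddy:

    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
      const int n = 4;
      std::vector<double> x{1, 2, 3, 4}, y{5, 6, 7, 8};
      std::vector<double> dout{0.1, 0.2, 0.3, 0.4};
      std::vector<double> ddx{1.5, 2.5, 3.5, 4.5}, ddy{9, 8, 7, 6};

      // Naive double-grad formulas for out = x * y.
      std::vector<double> dx_ref(n), dy_ref(n), ddout_ref(n);
      for (int i = 0; i < n; ++i) {
        dx_ref[i] = dout[i] * ddy[i];
        dy_ref[i] = dout[i] * ddx[i];
        ddout_ref[i] = ddx[i] * y[i] + x[i] * ddy[i];
      }

      // Reordered sequence: dx serves as the temporary for x * ddy
      // until step (5) overwrites it with its final value.
      std::vector<double> dx(n), dy(n), ddout(n);
      for (int i = 0; i < n; ++i) dx[i] = x[i] * ddy[i];     // (1)
      for (int i = 0; i < n; ++i) dy[i] = dout[i] * ddx[i];  // (2)
      for (int i = 0; i < n; ++i) ddout[i] = ddx[i] * y[i];  // (3)
      for (int i = 0; i < n; ++i) ddout[i] += dx[i];         // (4)
      for (int i = 0; i < n; ++i) dx[i] = dout[i] * ddy[i];  // (5)

      // Both orderings must agree.
      for (int i = 0; i < n; ++i) {
        assert(std::fabs(dx[i] - dx_ref[i]) < 1e-12);
        assert(std::fabs(dy[i] - dy_ref[i]) < 1e-12);
        assert(std::fabs(ddout[i] - ddout_ref[i]) < 1e-12);
      }
      return 0;
    }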