Unverified commit 317e18ab, authored by Qingsheng Li and committed by GitHub

Remove Data Sharing between input and output in scatter_op (#12672)

* Remove Data Sharing between input and output in scatter_op

* Removed data sharing in backward op
Parent c44fb003
@@ -35,7 +35,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     auto *Out = ctx.Output<Tensor>("Out");
     // In place output: Out = X, Out[Ids] += Updates
-    Out->ShareDataWith(*X);
+    framework::TensorCopySync(*X, ctx.GetPlace(), Out);
     // Apply ScatterUpdate: Out[index] += Updates[:]
     ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
   }
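For context on the forward-pass change: with ShareDataWith, Out aliased X's buffer, so the ScatterAssign that follows mutated the op's input in place; TensorCopySync gives Out its own storage first, at the cost of one extra copy of X per forward call. Below is a minimal standalone sketch of that difference, using plain std::vector in place of the framework Tensor; the names X, Ids, Updates, Out mirror the kernel, but the code is illustrative only, not the Paddle API.

#include <cstdio>
#include <vector>

int main() {
  // Toy data standing in for the op's tensors: X is the input, Ids picks the
  // positions to update, Updates holds the values that get added there.
  std::vector<float> X = {1.f, 1.f, 1.f, 1.f};
  std::vector<int> Ids = {0, 2};
  std::vector<float> Updates = {10.f, 20.f};

  // Old behaviour (ShareDataWith): Out aliases X's storage, so scattering
  // into Out silently rewrites the input as well.
  std::vector<float> &OutAlias = X;
  for (size_t i = 0; i < Ids.size(); ++i) OutAlias[Ids[i]] += Updates[i];
  std::printf("X after aliased scatter: %g %g %g %g\n", X[0], X[1], X[2], X[3]);
  // prints 11 1 21 1 -- the input was clobbered

  // New behaviour (TensorCopySync): Out starts as an independent copy of X,
  // so the input stays intact after the scatter.
  std::vector<float> X2 = {1.f, 1.f, 1.f, 1.f};
  std::vector<float> OutCopy = X2;  // deep copy
  for (size_t i = 0; i < Ids.size(); ++i) OutCopy[Ids[i]] += Updates[i];
  std::printf("X2 after copied scatter: %g %g %g %g\n", X2[0], X2[1], X2[2], X2[3]);
  // prints 1 1 1 1 -- the input is preserved
  return 0;
}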
@@ -53,7 +53,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     // In place gradient: dX = dO
-    dX->ShareDataWith(*dOut);
+    framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Ids]
     CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
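The backward kernel gets the same treatment: dX should carry the value of dOut, and it now receives its own copy via TensorCopySync instead of aliasing the upstream gradient buffer, while dUpdates is gathered from dOut at the scattered indices. A small standalone sketch of that gradient flow follows, again with plain std::vector and the gather written as a fresh assignment for simplicity; the names mirror the kernel, but the code is illustrative only, not the Paddle API.

#include <cstdio>
#include <vector>

int main() {
  // Upstream gradient dOut and the same Ids used in the forward pass.
  std::vector<float> dOut = {0.5f, 1.5f, 2.5f, 3.5f};
  std::vector<int> Ids = {0, 2};

  // dX = dOut, but as an independent copy (the TensorCopySync in the diff),
  // so later writes to dX cannot corrupt the upstream gradient buffer.
  std::vector<float> dX = dOut;

  // dUpdates is a gather: the gradient reaching Updates[i] is dOut[Ids[i]]
  // (roughly what the CPUGather call computes for the CPU kernel).
  std::vector<float> dUpdates(Ids.size());
  for (size_t i = 0; i < Ids.size(); ++i) dUpdates[i] = dOut[Ids[i]];

  std::printf("dX: %g %g %g %g\n", dX[0], dX[1], dX[2], dX[3]);  // 0.5 1.5 2.5 3.5
  std::printf("dUpdates: %g %g\n", dUpdates[0], dUpdates[1]);    // 0.5 2.5
  return 0;
}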