未验证 提交 f4ec7d54 编写于 作者: C chengduo 提交者: GitHub

fix bug of scatter op (#18640)

test=develop
上级 112cf850
......@@ -58,11 +58,15 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("Updates"))) {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......
......@@ -47,13 +47,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO
dX->ShareDataWith(*dOut);
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
};
} // namespace operators
......
......@@ -74,12 +74,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册