未验证 提交 f4ec7d54 编写于 作者: C chengduo 提交者: GitHub

fix bug of scatter op (#18640)

test=develop
上级 112cf850
......@@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
if (ctx->HasOutput(framework::GradVarName("Updates"))) {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
}
protected:
......
......@@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dX = dO
dX->ShareDataWith(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
if (dX) {
// In place gradient: dX = dO
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
};
......
......@@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
if (dX) {
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册