Unverified commit f4ec7d54, authored by chengduo, committed by GitHub

fix bug of scatter op (#18640)

Guard the scatter gradient kernels so that X@GRAD and Updates@GRAD are only
computed when the corresponding gradient output actually exists, and make the
CUDA kernel copy dOut into dX (TensorCopy) instead of sharing dOut's buffer
(ShareDataWith).

test=develop
Parent commit: 112cf850
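The bug surfaces when only one of scatter's inputs needs a gradient: the grad
kernels used to write both X@GRAD and Updates@GRAD unconditionally, so a
missing gradient output meant dereferencing a null tensor. Below is a minimal
repro sketch, assuming the Paddle 1.x fluid API of the time; the names and
shapes are illustrative, not taken from the PR.

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[3, 2], dtype='float32',
                          append_batch_size=False)
    index = fluid.layers.data(name='index', shape=[2], dtype='int32',
                              append_batch_size=False)
    updates = fluid.layers.data(name='updates', shape=[2, 2], dtype='float32',
                                append_batch_size=False)
    x.stop_gradient = False       # X@GRAD is required ...
    updates.stop_gradient = True  # ... but Updates@GRAD is never created

    out = fluid.layers.scatter(x, index, updates)
    loss = fluid.layers.reduce_sum(out)
    fluid.backward.append_backward(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    # Before this fix, scatter_grad tried to fill the absent Updates@GRAD
    # here; after it, the kernel simply skips the missing output.
    x_grad, = exe.run(
        feed={'x': np.ones([3, 2], dtype='float32'),
              'index': np.array([1, 2], dtype='int32'),
              'updates': np.ones([2, 2], dtype='float32')},
        fetch_list=['x@GRAD'])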
paddle/fluid/operators/scatter_op.cc
@@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("Updates"),
-                      ctx->GetInputDim("Updates"));
-    ctx->SetOutputDim(framework::GradVarName("X"),
-                      ctx->GetInputDim(framework::GradVarName("Out")));
+    if (ctx->HasOutput(framework::GradVarName("Updates"))) {
+      ctx->SetOutputDim(framework::GradVarName("Updates"),
+                        ctx->GetInputDim("Updates"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"),
+                        ctx->GetInputDim(framework::GradVarName("Out")));
+    }
   }

  protected:
...
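With the guards in place, InferShape only sets the dims of gradient outputs
that exist; the shape contract itself is unchanged. In NumPy terms (my
paraphrase of the code above, with made-up example shapes):

    import numpy as np

    x = np.zeros((5, 3), np.float32)        # X
    ids = np.array([0, 2], np.int64)        # Ids
    updates = np.zeros((2, 3), np.float32)  # Updates
    d_out = np.ones((5, 3), np.float32)     # Out@GRAD; Out has X's shape

    # Updates@GRAD inherits Updates' dims; X@GRAD inherits Out@GRAD's dims.
    assert updates.shape == d_out[ids].shape  # Updates@GRAD: (2, 3)
    assert x.shape == d_out.shape             # X@GRAD: (5, 3)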
paddle/fluid/operators/scatter_op.cu
@@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    // In place gradient: dX = dO
-    dX->ShareDataWith(*dOut);
-    dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates = dO[Ids]
-    GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+    if (dX) {
+      // In place gradient: dX = dO
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+    }
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
+      GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+    }
   }
 };
...
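Besides the null checks, the CUDA kernel now materializes dX with
framework::TensorCopy rather than aliasing dOut's buffer via ShareDataWith,
so dX owns its own memory. The gradient rule the kernel implements (per its
own comments) is dX = dOut and dUpdates = dOut[Ids]; here is a NumPy sketch
of that rule, not a reference implementation:

    import numpy as np

    d_out = np.arange(12, dtype=np.float32).reshape(3, 4)  # Out@GRAD
    ids = np.array([2, 0], dtype=np.int64)                 # Ids

    d_x = d_out.copy()      # TensorCopy: a real copy, not a shared buffer
    d_updates = d_out[ids]  # GPUGather: rows of dOut selected by Ids

    print(d_x.shape, d_updates.shape)  # (3, 4) (2, 4)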
paddle/fluid/operators/scatter_op.h
@@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    // In place gradient: dX = dO
-    framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
-    dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates = dO[Ids]
-    CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+    if (dX) {
+      // In place gradient: dX = dO
+      framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+    }
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
+      CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+    }
   }
 };
...
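The CPU kernel gains the same guards; it already used TensorCopySync (a
blocking copy, which is fine on CPU) rather than buffer sharing. Continuing
the repro sketch at the top: with loss = reduce_sum(out), dOut is all ones,
so the fetched X@GRAD should come back as ones. This expectation follows from
the "dX = dO" comment in the kernels above, not from anything stated in the
PR itself:

    np.testing.assert_allclose(x_grad, np.ones([3, 2], dtype=np.float32))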