未验证 提交 f4ec7d54 编写于 作者: C chengduo 提交者: GitHub

fix bug of scatter op (#18640)

test=develop
上级 112cf850
...@@ -58,11 +58,15 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -58,11 +58,15 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("Updates"))) {
ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates")); ctx->GetInputDim("Updates"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out"))); ctx->GetInputDim(framework::GradVarName("Out")));
} }
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
......
...@@ -47,13 +47,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -47,13 +47,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates")); auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO // In place gradient: dX = dO
dX->ShareDataWith(*dOut); framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates); GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
} }
}
}; };
} // namespace operators } // namespace operators
......
...@@ -74,12 +74,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> { ...@@ -74,12 +74,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO // In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates); CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
} }
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册