From f4ec7d54c8134c670c9fb3c1e23e7f0024500313 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 16 Jul 2019 09:47:16 +0800 Subject: [PATCH] fix bug of scatter op (#18640) test=develop --- paddle/fluid/operators/scatter_op.cc | 12 ++++++++---- paddle/fluid/operators/scatter_op.cu | 15 +++++++++------ paddle/fluid/operators/scatter_op.h | 14 +++++++++----- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index f5a1b32e5c2..4eb5b7ad9d1 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("Updates"), - ctx->GetInputDim("Updates")); - ctx->SetOutputDim(framework::GradVarName("X"), - ctx->GetInputDim(framework::GradVarName("Out"))); + if (ctx->HasOutput(framework::GradVarName("Updates"))) { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } } protected: diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu index e9ad3475381..e17617b40da 100644 --- a/paddle/fluid/operators/scatter_op.cu +++ b/paddle/fluid/operators/scatter_op.cu @@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel { auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - - // In place gradient: dX = dO - dX->ShareDataWith(*dOut); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopy(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } }; diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index 9c237dc0f1f..3b6184de77f 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - // In place gradient: dX = dO - framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } }; -- GitLab