diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index f5a1b32e5c240933d79a524937b5a8222118fdd9..4eb5b7ad9d1fe128ade904cf61e0178d59b374b8 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("Updates"), - ctx->GetInputDim("Updates")); - ctx->SetOutputDim(framework::GradVarName("X"), - ctx->GetInputDim(framework::GradVarName("Out"))); + if (ctx->HasOutput(framework::GradVarName("Updates"))) { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } } protected: diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu index e9ad347538157342adb24813546e927040b4f9d2..e17617b40da356d74bdffcf53a6c9189d13c64f1 100644 --- a/paddle/fluid/operators/scatter_op.cu +++ b/paddle/fluid/operators/scatter_op.cu @@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel { auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - - // In place gradient: dX = dO - dX->ShareDataWith(*dOut); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopy(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } }; diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index 9c237dc0f1f115ce76a3b982a8c6ca1dfccb0b87..3b6184de77f4fc05aa2f2900ebc656ed06a8edfc 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - // In place gradient: dX = dO - framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } };