diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 5a4f889f3d82eb532dbd5cb28fe6c54893dfed58..05ba52ce064185d25fdb6136e8345b9dba6d44c9 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -51,8 +51,10 @@ Y = X[Index] class GatherGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output("X" + framework::kGradVarSuffix) - ->Resize(ctx.Input("X")->dims()); + auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X = ctx.Input("X"); + + X_grad->Resize(X->dims()); } };