diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 3414a3c26339038fc128e16ad784513c67280c79..5a4f889f3d82eb532dbd5cb28fe6c54893dfed58 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel { "Inputs of GatherOp must all be set"); int batch_size = ctx.Input(1)->dims()[0]; PADDLE_ENFORCE(batch_size > 0); + paddle::framework::DDim output_dims(ctx.Input(0)->dims()); + output_dims[0] = batch_size; + ctx.Output(0)->Resize(output_dims); } }; @@ -48,8 +51,8 @@ Y = X[Index] class GatherGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - // ctx.Output("X" + framework::kGradVarSuffix) - // ->Resize(ctx.Input("X")->dims()); + ctx.Output("X" + framework::kGradVarSuffix) + ->Resize(ctx.Input("X")->dims()); } };