From 2a42a73db1ff32f7d1b9dfc772070fc7827aa4e7 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 11 Aug 2017 15:10:59 -0700 Subject: [PATCH] modify gather_op with test --- paddle/operators/gather_op.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 3414a3c26..5a4f889f3 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel { "Inputs of GatherOp must all be set"); int batch_size = ctx.Input(1)->dims()[0]; PADDLE_ENFORCE(batch_size > 0); + paddle::framework::DDim output_dims(ctx.Input(0)->dims()); + output_dims[0] = batch_size; + ctx.Output(0)->Resize(output_dims); } }; @@ -48,8 +51,8 @@ Y = X[Index] class GatherGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - // ctx.Output("X" + framework::kGradVarSuffix) - // ->Resize(ctx.Input("X")->dims()); + ctx.Output("X" + framework::kGradVarSuffix) + ->Resize(ctx.Input("X")->dims()); } }; -- GitLab