From 6075928d5531b5eecff0d3183c1d47ab3b0962d4 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 16 Aug 2017 19:02:29 +0000 Subject: [PATCH] gather op added --- paddle/operators/gather.h | 2 ++ paddle/operators/gather_op.cc | 8 ++------ python/paddle/v2/framework/tests/test_gather_op.py | 7 ++++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index d6e6990394..3f299ea1a6 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -17,6 +17,8 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 2e08ba8dcc..499def05a7 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - // PADDLE_ENFORCE(ctx.InputSize() == 2, ""); - // PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); - // PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), - // "Inputs of GatherOp must all be set"); int batch_size = ctx.Input("Index")->dims()[0]; - PADDLE_ENFORCE(batch_size > 0); - paddle::framework::DDim output_dims(ctx.Input(0)->dims()); + PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); + paddle::framework::DDim output_dims(ctx.Input("X")->dims()); output_dims[0] = batch_size; ctx.Output("Y")->Resize(output_dims); } diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py index 2ffbf17236..049054d07b 100644 --- a/python/paddle/v2/framework/tests/test_gather_op.py +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -12,11 +12,12 @@ class TestGatherOp(unittest.TestCase): def setUp(self): self.type = "gather" + xnp = numpy.random.random((10, 20)).astype("float32") self.inputs = { - 'X': numpy.random.random((10, 20)).astype("float32"), - 'Index': numpy.array([1, 3, 5]).astype("int") + 'X': xnp, + 'Index': numpy.array([1, 3, 5]).astype("int32") } - self.outputs = {'Y': self.input['X'][self.input['Index']]} + self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]} if __name__ == "__main__": -- GitLab