diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index d6e6990394e46ba06c4bacfe33ca522f3ff1413a..3f299ea1a6d63e3dd68bd9e5b637af5ac12bd8f0 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -17,6 +17,8 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 2e08ba8dcc72863ba0f4556e654e3e3d457f9f88..499def05a7f62d1c5f236e1f7dc5ab8f9c7b5bc3 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - // PADDLE_ENFORCE(ctx.InputSize() == 2, ""); - // PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); - // PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), - // "Inputs of GatherOp must all be set"); int batch_size = ctx.Input("Index")->dims()[0]; - PADDLE_ENFORCE(batch_size > 0); - paddle::framework::DDim output_dims(ctx.Input(0)->dims()); + PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); + paddle::framework::DDim output_dims(ctx.Input("X")->dims()); output_dims[0] = batch_size; ctx.Output("Y")->Resize(output_dims); } diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py index 2ffbf172365562ebf88c701c5f81bb457a6fb5bf..049054d07b63c06686d65682e3e6373ac43b8518 100644 --- a/python/paddle/v2/framework/tests/test_gather_op.py +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -12,11 +12,12 @@ class TestGatherOp(unittest.TestCase): def setUp(self): self.type = "gather" + xnp = numpy.random.random((10, 20)).astype("float32") self.inputs = { - 'X': numpy.random.random((10, 20)).astype("float32"), - 'Index': numpy.array([1, 3, 5]).astype("int") + 'X': xnp, + 'Index': numpy.array([1, 3, 5]).astype("int32") } - self.outputs = {'Y': self.input['X'][self.input['Index']]} + self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]} if __name__ == "__main__":