diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index b88cd14d78f616b0e57386ab891dad1a872bfe65..ad86a2e5bc23b2b0ea853971cf79dec745e9706a 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); + ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); } @@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { " which is a learnable parameter."); AddInput("Ids", "An input with type int32 or int64" - "contains the ids to be looked up in W."); + "contains the ids to be looked up in W." + "Ids must be a column vector with rank = 2." + "The 2nd dimension size must be 1"); AddOutput("Out", "The lookup results, which have the same type with W."); AddComment(R"DOC( This operator is used to perform lookups on the parameter W, diff --git a/python/paddle/v2/framework/tests/test_lookup_table_op.py b/python/paddle/v2/framework/tests/test_lookup_table_op.py index b259bb67e832adcb31b0ab4e992738be2b85f884..2c48f9bf93b939aa631cd54e8fb14b5cba22f2e0 100644 --- a/python/paddle/v2/framework/tests/test_lookup_table_op.py +++ b/python/paddle/v2/framework/tests/test_lookup_table_op.py @@ -8,7 +8,8 @@ class TestLookupTableOp(OpTest): self.op_type = "lookup_table" table = np.random.random((17, 31)).astype("float32") ids = np.random.randint(0, 17, 4).astype("int32") - self.inputs = {'W': table, 'Ids': ids} + ids_expand = np.expand_dims(ids, axis=1) + self.inputs = {'W': table, 'Ids': ids_expand} self.outputs = {'Out': table[ids]} def test_check_output(self):