提交 40e7caf6 编写于 作者: Q QI JUN 提交者: Yu Yang

ensure ids in lookup table op must be a column vector (#4987)

* ensure ids in lookup table op must be a column vector

* follow comments
上级 7d653c41
......@@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
}
......@@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64"
"contains the ids to be looked up in W.");
"contains the ids to be looked up in W."
"Ids must be a column vector with rank = 2."
"The 2nd dimension size must be 1");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment(R"DOC(
This operator is used to perform lookups on the parameter W,
......
......@@ -8,7 +8,8 @@ class TestLookupTableOp(OpTest):
self.op_type = "lookup_table"
table = np.random.random((17, 31)).astype("float32")
ids = np.random.randint(0, 17, 4).astype("int32")
self.inputs = {'W': table, 'Ids': ids}
ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册