提交 1a93253f 编写于 作者: Y Yancey1989

fix unittest

上级 fb1167c3
......@@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
"The type of W var should be SelectedRows.");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"The type of Ids var should be SelectedRows.");
"The type of Ids var should be LoDTensor.");
auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>();
......@@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter.");
AddInput("Ids",
"(SelectedRows) Ids's type should be SelectedRows "
"the rows of Ids contains the Ids to be looked up in W.");
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.");
AddOutput("Out",
"(SelectedRows) The lookup results, which have the "
"(LoDTensor) The lookup results, which have the "
"same type as W.");
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
......
......@@ -32,9 +32,9 @@ class TestLookupSpraseTable(OpTest):
scope = core.Scope()
# create and initialize Id Variable
ids = scope.var("Ids").get_selected_rows()
ids_array = [0, 2, 3, 5, 100]
ids.set_rows(ids_array)
ids = scope.var("Ids").get_tensor()
ids_array = np.array([0, 2, 3, 5, 100]).astype("int64")
ids.set(ids_array, place)
# create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册