提交 1a93253f 编写于 作者: Y Yancey1989

fix unittest

上级 fb1167c3
...@@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(), PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
"The type of W var should be SelectedRows."); "The type of W var should be SelectedRows.");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"The type of Ids var should be SelectedRows."); "The type of Ids var should be LoDTensor.");
auto &ids_t = ids_var->Get<framework::LoDTensor>(); auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>(); auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>(); auto w_t = w_var->GetMutable<framework::SelectedRows>();
...@@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(SelectedRows) The input represents embedding table, " "(SelectedRows) The input represents embedding table, "
"which is a learnable parameter."); "which is a learnable parameter.");
AddInput("Ids", AddInput("Ids",
"(SelectedRows) Ids's type should be SelectedRows " "(LoDTensor) Ids's type should be LoDTensor"
"the rows of Ids contains the Ids to be looked up in W."); "THe ids to be looked up in W.");
AddOutput("Out", AddOutput("Out",
"(SelectedRows) The lookup results, which have the " "(LoDTensor) The lookup results, which have the "
"same type as W."); "same type as W.");
AddAttr<int64_t>("padding_idx", AddAttr<int64_t>("padding_idx",
"(int64, default -1) " "(int64, default -1) "
......
...@@ -32,9 +32,9 @@ class TestLookupSpraseTable(OpTest): ...@@ -32,9 +32,9 @@ class TestLookupSpraseTable(OpTest):
scope = core.Scope() scope = core.Scope()
# create and initialize Id Variable # create and initialize Id Variable
ids = scope.var("Ids").get_selected_rows() ids = scope.var("Ids").get_tensor()
ids_array = [0, 2, 3, 5, 100] ids_array = np.array([0, 2, 3, 5, 100]).astype("int64")
ids.set_rows(ids_array) ids.set(ids_array, place)
# create and initialize W Variable # create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6] rows = [0, 1, 2, 3, 4, 5, 6]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册