From 1a93253f1665e7d85b11048c04df44c4753a46b3 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Fri, 27 Apr 2018 13:54:38 +0800 Subject: [PATCH] fix unittest --- paddle/fluid/operators/lookup_sparse_table_op.cc | 8 ++++---- .../fluid/tests/unittests/test_lookup_sparse_table_op.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index ff3734b8f0a..f1839e456d6 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase { PADDLE_ENFORCE(w_var->IsType(), "The type of W var should be SelectedRows."); PADDLE_ENFORCE(ids_var->IsType(), - "The type of Ids var should be SelectedRows."); + "The type of Ids var should be LoDTensor."); auto &ids_t = ids_var->Get(); auto out_t = out_var->GetMutable(); auto w_t = w_var->GetMutable(); @@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { "(SelectedRows) The input represents embedding table, " "which is a learnable parameter."); AddInput("Ids", - "(SelectedRows) Ids's type should be SelectedRows " - "the rows of Ids contains the Ids to be looked up in W."); + "(LoDTensor) Ids's type should be LoDTensor" + "THe ids to be looked up in W."); AddOutput("Out", - "(SelectedRows) The lookup results, which have the " + "(LoDTensor) The lookup results, which have the " "same type as W."); AddAttr("padding_idx", "(int64, default -1) " diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py index 6c339cba83c..aa9eae1e882 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py @@ -32,9 +32,9 @@ class TestLookupSpraseTable(OpTest): scope = core.Scope() # create and initialize Id Variable - ids = scope.var("Ids").get_selected_rows() - ids_array = [0, 2, 3, 5, 100] - ids.set_rows(ids_array) + ids = scope.var("Ids").get_tensor() + ids_array = np.array([0, 2, 3, 5, 100]).astype("int64") + ids.set(ids_array, place) # create and initialize W Variable rows = [0, 1, 2, 3, 4, 5, 6] -- GitLab