diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index ff3734b8f0ab7ef32b98a4eee0439ae5336f8c3a..f1839e456d66ab95fb4ccac933cf7b635c54f5a0 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -56,7 +56,7 @@ class LookupSparseTableOp : public framework::OperatorBase { PADDLE_ENFORCE(w_var->IsType(), "The type of W var should be SelectedRows."); PADDLE_ENFORCE(ids_var->IsType(), - "The type of Ids var should be SelectedRows."); + "The type of Ids var should be LoDTensor."); auto &ids_t = ids_var->Get(); auto out_t = out_var->GetMutable(); auto w_t = w_var->GetMutable(); @@ -111,10 +111,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { "(SelectedRows) The input represents embedding table, " "which is a learnable parameter."); AddInput("Ids", - "(SelectedRows) Ids's type should be SelectedRows " - "the rows of Ids contains the Ids to be looked up in W."); + "(LoDTensor) Ids's type should be LoDTensor" + "THe ids to be looked up in W."); AddOutput("Out", - "(SelectedRows) The lookup results, which have the " + "(LoDTensor) The lookup results, which have the " "same type as W."); AddAttr("padding_idx", "(int64, default -1) " diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py index 6c339cba83c821a1b26efd5e877d809afef08a6b..aa9eae1e882f55ef51f38e158317a1a9aeed641c 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py @@ -32,9 +32,9 @@ class TestLookupSpraseTable(OpTest): scope = core.Scope() # create and initialize Id Variable - ids = scope.var("Ids").get_selected_rows() - ids_array = [0, 2, 3, 5, 100] - ids.set_rows(ids_array) + ids = scope.var("Ids").get_tensor() + ids_array = np.array([0, 2, 3, 5, 100]).astype("int64") + ids.set(ids_array, place) # create and initialize W Variable rows = [0, 1, 2, 3, 4, 5, 6]