From bad0c27e6ef9506058ca5a6ba41c34bc652c8b9d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 13 Nov 2018 16:33:24 +0800 Subject: [PATCH] add test_lookup_sparse_table_op --- .../unittests/test_lookup_sparse_table_op.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py index 11e5d8b53..c7f4f3e91 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py @@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest): assert (result_array2[3] == w_array[6]).all() assert (result_array2[4] == w_array[7]).all() + # create and run lookup_table operator + test_lookup_table = Operator( + "lookup_sparse_table", + W='W', + Ids='Ids', + Out='Out', + min=-5.0, + max=10.0, + seed=10, + is_test=True) + + ids = scope.var("Ids").get_tensor() + unknown_id = [44, 22, 33] + ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64") + ids.set(ids_array2, place) + test_lookup_table.run(scope, place) + + result_array2 = np.array(out_tensor) + assert (result_array2[0] == w_array[5]).all() + assert (result_array2[1] == w_array[1]).all() + assert (result_array2[2] == w_array[2]).all() + assert (result_array2[3] == w_array[6]).all() + assert (result_array2[4] == w_array[7]).all() + + for i in [5, 6, 7]: + assert np.all(result_array2[i] == 0) + def test_w_is_selected_rows(self): places = [core.CPUPlace()] # currently only support CPU -- GitLab