You need to sign in or sign up before continuing.
提交 bad0c27e 编写于 作者: Q Qiao Longfei

add test_lookup_sparse_table_op

上级 8d205c85
...@@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest): ...@@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest):
assert (result_array2[3] == w_array[6]).all() assert (result_array2[3] == w_array[6]).all()
assert (result_array2[4] == w_array[7]).all() assert (result_array2[4] == w_array[7]).all()
# create and run lookup_table operator
test_lookup_table = Operator(
"lookup_sparse_table",
W='W',
Ids='Ids',
Out='Out',
min=-5.0,
max=10.0,
seed=10,
is_test=True)
ids = scope.var("Ids").get_tensor()
unknown_id = [44, 22, 33]
ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64")
ids.set(ids_array2, place)
test_lookup_table.run(scope, place)
result_array2 = np.array(out_tensor)
assert (result_array2[0] == w_array[5]).all()
assert (result_array2[1] == w_array[1]).all()
assert (result_array2[2] == w_array[2]).all()
assert (result_array2[3] == w_array[6]).all()
assert (result_array2[4] == w_array[7]).all()
for i in [5, 6, 7]:
assert np.all(result_array2[i] == 0)
def test_w_is_selected_rows(self): def test_w_is_selected_rows(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
# currently only support CPU # currently only support CPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册