提交 abb7deee 编写于 作者: Q qiaolongfei

optimize test_lookup_table_op.py

上级 e64dda7e
...@@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest): ...@@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest):
w_array = np.ones((len(rows), row_numel)).astype("float32") w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)): for i in range(len(rows)):
w_array[i] *= i w_array[i] *= i
ids_tensor = w_selected_rows.get_tensor() w_tensor = w_selected_rows.get_tensor()
ids_tensor.set(w_array, place) w_tensor.set(w_array, place)
# create Out Variable # create Out Variable
Out_tensor = scope.var('Out').get_tensor() out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator # create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place) lookup_table.run(scope, place)
# get result from Out # get result from Out
result_array = np.array(Out_tensor) result_array = np.array(out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty) # all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array): for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all() assert (row[0] == result_array[idx]).all()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册