提交 a43eee40 编写于 作者: C chengduoZH

follow comments

上级 92e2207e
...@@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest): ...@@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place): def check_with_place(self, place):
scope = core.Scope() scope = core.Scope()
# create and initialize Grad Variable # create and initialize Variable
height = 10 height = 10
rows = [0, 4, 4, 7] rows = [0, 4, 4, 7]
row_numel = 12 row_numel = 12
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(height)
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create and initialize W Variable # create and initialize W Variable
W = scope.var('W').get_tensor() W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32") W_array = np.full((height, row_numel), 1.0).astype("float32")
...@@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest): ...@@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
W_array[i] *= i W_array[i] *= i
W.set(W_array, place) W.set(W_array, place)
# create and initialize Ids Variable
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(len(rows))
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create Out Variable
Out = scope.var('Out').get_selected_rows() Out = scope.var('Out').get_selected_rows()
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
Out.set_height(height)
Out.set_rows(rows)
Out_tensor = Out.get_tensor()
Out_tensor.set(Out_array, place)
# create and run concat_rows_op operator # create and run lookup_table operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place) concat_rows_op.run(scope, place)
# get and compare result # get result from Out
Out_tensor = Out.get_tensor()
result_array = np.array(Out_tensor) result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(rows): for idx, row in enumerate(rows):
assert (row == result_array[idx]).all() assert (row == result_array[idx]).all()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册