From a43eee40f71352867868714a55dd9fa1135e368f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 13 Mar 2018 23:08:26 +0800 Subject: [PATCH] follow comments --- .../tests/unittests/test_lookup_table_op.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 518ef6a1bf7..ed920ad388f 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope() - # create and initialize Grad Variable + # create and initialize Variable height = 10 rows = [0, 4, 4, 7] row_numel = 12 - ids_selected_rows = scope.var('Ids').get_selected_rows() - ids_selected_rows.set_height(height) - ids_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") - ids_tensor = ids_selected_rows.get_tensor() - ids_tensor.set(np_array, place) - # create and initialize W Variable W = scope.var('W').get_tensor() W_array = np.full((height, row_numel), 1.0).astype("float32") @@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest): W_array[i] *= i W.set(W_array, place) + # create and initialize Ids Variable + ids_selected_rows = scope.var('Ids').get_selected_rows() + ids_selected_rows.set_height(len(rows)) + ids_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + ids_tensor = ids_selected_rows.get_tensor() + ids_tensor.set(np_array, place) + + # create Out Variable Out = scope.var('Out').get_selected_rows() - Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") - Out.set_height(height) - Out.set_rows(rows) - Out_tensor = Out.get_tensor() - Out_tensor.set(Out_array, place) - # create and run concat_rows_op operator + # create and run lookup_table operator concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') concat_rows_op.run(scope, place) - # get and compare result + # get result from Out + Out_tensor = Out.get_tensor() result_array = np.array(Out_tensor) + # all(): return True if all elements of the iterable are true (or if the iterable is empty) for idx, row in enumerate(rows): assert (row == result_array[idx]).all() -- GitLab