diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 518ef6a1bf7237d9e415485213317975555bcf01..ed920ad388ff0e01887404e70fe82565b4cd28fa 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope() - # create and initialize Grad Variable + # create and initialize Variable height = 10 rows = [0, 4, 4, 7] row_numel = 12 - ids_selected_rows = scope.var('Ids').get_selected_rows() - ids_selected_rows.set_height(height) - ids_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") - ids_tensor = ids_selected_rows.get_tensor() - ids_tensor.set(np_array, place) - # create and initialize W Variable W = scope.var('W').get_tensor() W_array = np.full((height, row_numel), 1.0).astype("float32") @@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest): W_array[i] *= i W.set(W_array, place) + # create and initialize Ids Variable + ids_selected_rows = scope.var('Ids').get_selected_rows() + ids_selected_rows.set_height(len(rows)) + ids_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + ids_tensor = ids_selected_rows.get_tensor() + ids_tensor.set(np_array, place) + + # create Out Variable Out = scope.var('Out').get_selected_rows() - Out_array = np.full((len(rows), row_numel), -1.0).astype("float32") - Out.set_height(height) - Out.set_rows(rows) - Out_tensor = Out.get_tensor() - Out_tensor.set(Out_array, place) - # create and run concat_rows_op operator + # create and run lookup_table operator concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') concat_rows_op.run(scope, place) - # get and compare result + # get result from Out + Out_tensor = Out.get_tensor() result_array = np.array(Out_tensor) + # all(): return True if all elements of the iterable are true (or if the iterable is empty) for idx, row in enumerate(rows): assert (row == result_array[idx]).all()