From abb7deee39f023f16d2afdad9e369105e7be0744 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 3 Apr 2018 15:03:34 +0800 Subject: [PATCH] optimize test_lookup_table_op.py --- .../paddle/fluid/tests/unittests/test_lookup_table_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 3f739afd25..f8d5785fbf 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest): w_array = np.ones((len(rows), row_numel)).astype("float32") for i in range(len(rows)): w_array[i] *= i - ids_tensor = w_selected_rows.get_tensor() - ids_tensor.set(w_array, place) + w_tensor = w_selected_rows.get_tensor() + w_tensor.set(w_array, place) # create Out Variable - Out_tensor = scope.var('Out').get_tensor() + out_tensor = scope.var('Out').get_tensor() # create and run lookup_table operator lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') lookup_table.run(scope, place) # get result from Out - result_array = np.array(Out_tensor) + result_array = np.array(out_tensor) # all(): return True if all elements of the iterable are true (or if the iterable is empty) for idx, row in enumerate(ids_array): assert (row[0] == result_array[idx]).all() -- GitLab