From 850b737112ec82bcdc98f54264a449ca21eec176 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 27 Mar 2019 14:24:32 +0800 Subject: [PATCH] Fix nparray.all() bug. (#16472) --- .../fluid/tests/unittests/test_dist_save_load.py | 12 +++++------- .../tests/unittests/test_nce_remote_table_op.py | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_save_load.py b/python/paddle/fluid/tests/unittests/test_dist_save_load.py index e795bc410..8c2d6d9b4 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_dist_save_load.py @@ -68,9 +68,9 @@ class TestDistSaveLoadDense2x2(TestDistBase): train0_np = np.array(tr0_var) train1_np = np.array(tr1_var) - self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta) - self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta) - self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta) + np.testing.assert_almost_equal(local_np, train0_np, decimal=2) + np.testing.assert_almost_equal(local_np, train1_np, decimal=2) + np.testing.assert_almost_equal(train0_np, train1_np, decimal=2) def test_dist(self): need_envs = { @@ -134,10 +134,8 @@ class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase): train0_2_np = np.array(tr0_var_2) train1_2_np = np.array(tr1_var_2) - self.assertAlmostEqual( - train0_1_np.all(), train0_2_np.all(), delta=delta) - self.assertAlmostEqual( - train1_1_np.all(), train1_2_np.all(), delta=delta) + np.testing.assert_almost_equal(train0_1_np, train0_2_np, decimal=2) + np.testing.assert_almost_equal(train1_1_np, train1_2_np, decimal=2) def test_dist(self): need_envs = { diff --git a/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py index cc6f40de8..d24532b95 100644 --- a/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py @@ -205,9 +205,9 @@ class TestListenAndServOp(unittest.TestCase): out = nce(x_array, param_array, bias_array, sample_weight, label_array, 5, 2) - self.assertAlmostEqual(o_cost.all(), out[0].all(), delta=1e-6) - self.assertAlmostEqual(o_logits.all(), out[1].all(), delta=1e-6) - self.assertAlmostEqual(o_labels.all(), out[2].all(), delta=1e-6) + np.testing.assert_almost_equal(o_cost, out[0], decimal=6) + np.testing.assert_almost_equal(o_logits, out[1], decimal=6) + np.testing.assert_almost_equal(o_labels, out[2], decimal=6) def test_nce_op_remote(self): os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1" -- GitLab