未验证 提交 850b7371 编写于 作者: G gongweibao 提交者: GitHub

Fix nparray.all() bug. (#16472)

上级 1b4e4e7e
......@@ -68,9 +68,9 @@ class TestDistSaveLoadDense2x2(TestDistBase):
train0_np = np.array(tr0_var)
train1_np = np.array(tr1_var)
self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta)
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
np.testing.assert_almost_equal(local_np, train0_np, decimal=2)
np.testing.assert_almost_equal(local_np, train1_np, decimal=2)
np.testing.assert_almost_equal(train0_np, train1_np, decimal=2)
def test_dist(self):
need_envs = {
......@@ -134,10 +134,8 @@ class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
train0_2_np = np.array(tr0_var_2)
train1_2_np = np.array(tr1_var_2)
self.assertAlmostEqual(
train0_1_np.all(), train0_2_np.all(), delta=delta)
self.assertAlmostEqual(
train1_1_np.all(), train1_2_np.all(), delta=delta)
np.testing.assert_almost_equal(train0_1_np, train0_2_np, decimal=2)
np.testing.assert_almost_equal(train1_1_np, train1_2_np, decimal=2)
def test_dist(self):
need_envs = {
......
......@@ -205,9 +205,9 @@ class TestListenAndServOp(unittest.TestCase):
out = nce(x_array, param_array, bias_array, sample_weight,
label_array, 5, 2)
self.assertAlmostEqual(o_cost.all(), out[0].all(), delta=1e-6)
self.assertAlmostEqual(o_logits.all(), out[1].all(), delta=1e-6)
self.assertAlmostEqual(o_labels.all(), out[2].all(), delta=1e-6)
np.testing.assert_almost_equal(o_cost, out[0], decimal=6)
np.testing.assert_almost_equal(o_logits, out[1], decimal=6)
np.testing.assert_almost_equal(o_labels, out[2], decimal=6)
def test_nce_op_remote(self):
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册