diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py index 38e4d5ad5480beb195bcc0c3cc21f033df8fbd5d..8f11a585884636f4dea711d0ea07e197b3856a19 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py @@ -462,10 +462,17 @@ class TestSeResnet(unittest.TestCase): self.assertTrue( np.allclose(dy_jit_pre, st_pre), msg="dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre, st_pre)) - self.assertTrue( - np.allclose(predictor_pre, st_pre), - msg="predictor_pre:\n {}\n, st_pre: \n{}.".format(predictor_pre, - st_pre)) + + flat_st_pre = st_pre.flatten() + flat_predictor_pre = np.array(predictor_pre).flatten() + for i in range(len(flat_predictor_pre)): + # modify precision to 1e-6, avoid unittest failed + self.assertAlmostEqual( + flat_predictor_pre[i], + flat_st_pre[i], + delta=1e-6, + msg="predictor_pre:\n {}\n, st_pre: \n{}.".format( + flat_predictor_pre[i], flat_st_pre[i])) def test_check_result(self): pred_1, loss_1, acc1_1, acc5_1 = train(