diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 18620f55367f63629d7746daac77aec57e0aec30..d200b77eea83fcdaa9bee238a638994a271b5f21 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -919,7 +919,7 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
 
         # load_inference_model
         paddle.enable_static()
-        exe = paddle.static.Executor(paddle.CPUPlace())
+        exe = paddle.static.Executor()
         [inference_program, feed_target_names, fetch_targets] = (
             paddle.static.load_inference_model(path, exe))
         tensor_img = x
@@ -927,8 +927,8 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
                           feed={feed_target_names[0]: tensor_img},
                           fetch_list=fetch_targets)
         print("pred.numpy()", pred.numpy())
-        print("results", results)
-        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
+        print("result", results[0])
+        self.assertTrue(np.array_equal(pred.numpy(), results[0]))
         paddle.disable_static()
 
     def test_inference_save_load(self):
@@ -1254,18 +1254,17 @@ class TestBf16(unittest.TestCase):
 
     def test_bf16(self):
         def func_isinstance():
-            if fluid.core.is_compiled_with_cuda():
-                cudnn_version = paddle.device.get_cudnn_version()
-                if cudnn_version is not None and cudnn_version >= 8100:
-                    out_fp32 = self.train(enable_amp=False)
-                    out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
-                    out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
-                    self.assertTrue(
-                        np.allclose(
-                            out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
-                    self.assertTrue(
-                        np.allclose(
-                            out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
+            if fluid.core.is_compiled_with_cuda(
+            ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
+                out_fp32 = self.train(enable_amp=False)
+                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
         with _test_eager_guard():
             func_isinstance()