未验证 提交 35ed11f3 编写于 作者: L Leo Chen 提交者: GitHub

[cherry-pick] fix wrong place in ut (#42488)

* fix wrong place

* skip bf16 test if not supported (#42503)
上级 58f40144
......@@ -919,7 +919,7 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
# load_inference_model
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
exe = paddle.static.Executor()
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(path, exe))
tensor_img = x
......@@ -927,8 +927,8 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
print("pred.numpy()", pred.numpy())
print("results", results)
self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
print("result", results[0])
self.assertTrue(np.array_equal(pred.numpy(), results[0]))
paddle.disable_static()
def test_inference_save_load(self):
......@@ -1254,18 +1254,17 @@ class TestBf16(unittest.TestCase):
def test_bf16(self):
def func_isinstance():
if fluid.core.is_compiled_with_cuda():
cudnn_version = paddle.device.get_cudnn_version()
if cudnn_version is not None and cudnn_version >= 8100:
out_fp32 = self.train(enable_amp=False)
out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
if fluid.core.is_compiled_with_cuda(
) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
out_fp32 = self.train(enable_amp=False)
out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
self.assertTrue(
np.allclose(
out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
with _test_eager_guard():
func_isinstance()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册