未验证 提交 35ed11f3 编写于 作者: L Leo Chen 提交者: GitHub

[cherry-pick] fix wrong place in ut (#42488)

* fix wrong place

* skip bf16 test if not supported (#42503)
上级 58f40144
...@@ -919,7 +919,7 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase): ...@@ -919,7 +919,7 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
# load_inference_model # load_inference_model
paddle.enable_static() paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor()
[inference_program, feed_target_names, fetch_targets] = ( [inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(path, exe)) paddle.static.load_inference_model(path, exe))
tensor_img = x tensor_img = x
...@@ -927,8 +927,8 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase): ...@@ -927,8 +927,8 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
feed={feed_target_names[0]: tensor_img}, feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets) fetch_list=fetch_targets)
print("pred.numpy()", pred.numpy()) print("pred.numpy()", pred.numpy())
print("results", results) print("result", results[0])
self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5)) self.assertTrue(np.array_equal(pred.numpy(), results[0]))
paddle.disable_static() paddle.disable_static()
def test_inference_save_load(self): def test_inference_save_load(self):
...@@ -1254,18 +1254,17 @@ class TestBf16(unittest.TestCase): ...@@ -1254,18 +1254,17 @@ class TestBf16(unittest.TestCase):
def test_bf16(self): def test_bf16(self):
def func_isinstance(): def func_isinstance():
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda(
cudnn_version = paddle.device.get_cudnn_version() ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
if cudnn_version is not None and cudnn_version >= 8100: out_fp32 = self.train(enable_amp=False)
out_fp32 = self.train(enable_amp=False) out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') self.assertTrue(
self.assertTrue( np.allclose(
np.allclose( out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) self.assertTrue(
self.assertTrue( np.allclose(
np.allclose( out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
with _test_eager_guard(): with _test_eager_guard():
func_isinstance() func_isinstance()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册