diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py
index 224f44d74864b90f0107587953698017d4c0fb7c..e97e089252ae31cf6163429c216a9ebc28590aa8 100644
--- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py
@@ -17,4 +17,5 @@ import contextlib
 import unittest
 import paddle
 from paddle.fluid import core
+from paddle.fluid.dygraph.amp.auto_cast import _is_gpu_bfloat16_supported
 import os
@@ -529,16 +530,15 @@ class TestBF16(unittest.TestCase):
     """
 
     def test_shape(self):
-        cuda_major = paddle.version.cuda().split('.')[0].strip()
-        if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
-            """ MatmulKernel support bfloat16 only if cuda_major > 11.0.
+        if paddle.is_compiled_with_cuda() and _is_gpu_bfloat16_supported():
+            """ MatmulKernel support bfloat16 only if cuda_major >= 11.0 and Compute Capability >= 8.0
             """
             A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
             A = A.cuda()
             B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
             B = B.cuda()
             C = paddle.einsum('i,i->', A, B)
-            self.assertEqual(C.item(), 8.0)
+            self.assertEqual(C.astype(paddle.float32).item(), 8.0)
 
 
 class TestComplex(unittest.TestCase):