Unverified · Commit acfdb8b3 authored by xiongkun, committed by GitHub

fix test-einsum-v2 unittest in cuda 11.7 (#44772)

Parent 65a3530c
@@ -17,6 +17,7 @@ import contextlib
 import unittest
 import paddle
 from paddle.fluid import core
+from paddle.fluid.dygraph.amp.auto_cast import _is_gpu_bfloat16_supported
 import os
@@ -529,16 +530,15 @@ class TestBF16(unittest.TestCase):
     """
     def test_shape(self):
-        cuda_major = paddle.version.cuda().split('.')[0].strip()
-        if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
-            """ MatmulKernel support bfloat16 only if cuda_major > 11.0.
+        if paddle.is_compiled_with_cuda() and _is_gpu_bfloat16_supported():
+            """ MatmulKernel support bfloat16 only if cuda_major >= 11.0 and Compute Capability >= 8.0
             """
             A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
             A = A.cuda()
             B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
             B = B.cuda()
             C = paddle.einsum('i,i->', A, B)
-            self.assertEqual(C.item(), 8.0)
+            self.assertEqual(C.astype(paddle.float32).item(), 8.0)
 class TestComplex(unittest.TestCase):
...
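For readers who want to try the change outside the test suite, below is a minimal sketch of the fixed logic as a standalone script. It is illustrative only and not part of the commit; it assumes a CUDA build of Paddle that exposes _is_gpu_bfloat16_supported, which (per the new docstring) should return True only when cuda_major >= 11.0 and the GPU's compute capability is >= 8.0.

# Standalone sketch of the fixed bfloat16 einsum check (assumption: CUDA build
# of Paddle; _is_gpu_bfloat16_supported() gates on the toolkit version and the
# GPU compute capability as described in the commit's docstring).
import numpy as np
import paddle
from paddle.fluid.dygraph.amp.auto_cast import _is_gpu_bfloat16_supported

if paddle.is_compiled_with_cuda() and _is_gpu_bfloat16_supported():
    A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16).cuda()
    B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16).cuda()
    C = paddle.einsum('i,i->', A, B)  # dot product: 1*2 + 2*3 = 8
    # Cast to float32 before calling .item(), as the commit does, instead of
    # reading the bfloat16 scalar directly.
    print(C.astype(paddle.float32).item())  # expected: 8.0
else:
    print("bfloat16 einsum check skipped: unsupported GPU or non-CUDA build")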