未验证 提交 d12837d3 编写于 作者: Y Yichen Zhang 提交者: GitHub

fix the bug of test_fused_multi_transformer_op on cuda12 (#55296)

上级 9210b1af
......@@ -1412,7 +1412,9 @@ class TestFusedMultiTransformerOpPreCacheStatic1(TestFusedMultiTransformerOp):
)
class TestFusedMultiAttentionAPIError(unittest.TestCase):
# Starts the name of this test with 'Z' to make this test
# run after others. If not, it will make other tests fail.
class ZTestFusedMultiAttentionAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([1.9], dtype=np.float32)
......@@ -1425,7 +1427,7 @@ class TestFusedMultiAttentionAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_invalid_input_dim)
class TestFusedMultiTransformerAPIError(unittest.TestCase):
class ZTestFusedMultiTransformerAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([], dtype=np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册