未验证 提交 fd5b8eea 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case2-3 (#49986)

* add shape check for fused_multi_head_attention

* use raise for coverage test

* add unittest

* remove unnecessary pass

* add unittest
上级 3cf50f91
......@@ -192,5 +192,17 @@ class TestFusedAttentionNormalizeBefore(TestFusedAttention):
self.normalize_before = True
class TestFusedAttentionAPIError(unittest.TestCase):
def test_invalid_x_rank(self):
def test_x_rank_1():
with paddle.fluid.dygraph.guard():
layer = FusedMultiHeadAttention(embed_dim=1, num_heads=1)
array = np.array([1.9], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [1]), dtype='float32')
out = layer(x)
self.assertRaises(ValueError, test_x_rank_1)
if __name__ == "__main__":
unittest.main()
......@@ -1051,5 +1051,31 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
)
class TestFusedMultiAttentionAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([1.9], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [1]), dtype='float32')
layer = paddle.incubate.nn.FusedMultiHeadAttention(
embed_dim=1, num_heads=1
)
out = layer(x)
self.assertRaises(ValueError, test_invalid_input_dim)
class TestFusedMultiTransformerAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32')
layer = paddle.incubate.nn.FusedTransformerEncoderLayer(
108, 108, 108, 0.0, 'relu'
)
out = layer(x)
self.assertRaises(ValueError, test_invalid_input_dim)
if __name__ == "__main__":
unittest.main()
......@@ -615,6 +615,11 @@ def fused_multi_head_attention(
'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
) # semantic transfer
if x.ndim != 3:
raise ValueError(
f"The rank of the x should be 3, but received {x.ndim}."
)
if _non_static_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册