未验证 提交 fd5b8eea 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case2-3 (#49986)

* add shape check for fused_multi_head_attention

* use raise for coverage test

* add unittest

* remove unnecessary pass

* add unittest
上级 3cf50f91
...@@ -192,5 +192,17 @@ class TestFusedAttentionNormalizeBefore(TestFusedAttention): ...@@ -192,5 +192,17 @@ class TestFusedAttentionNormalizeBefore(TestFusedAttention):
        self.normalize_before = True
class TestFusedAttentionAPIError(unittest.TestCase):
    """Error-path checks for the FusedMultiHeadAttention layer."""

    def test_invalid_x_rank(self):
        # Feeding a rank-1 tensor must be rejected with ValueError
        # (the layer requires a rank-3 input).
        def run_with_rank_one_input():
            with paddle.fluid.dygraph.guard():
                attn = FusedMultiHeadAttention(embed_dim=1, num_heads=1)
                values = np.array([1.9], dtype=np.float32)
                rank1 = paddle.to_tensor(
                    np.reshape(values, [1]), dtype='float32'
                )
                attn(rank1)

        self.assertRaises(ValueError, run_with_rank_one_input)
if __name__ == "__main__":
    unittest.main()
...@@ -1051,5 +1051,31 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp): ...@@ -1051,5 +1051,31 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
)
class TestFusedMultiAttentionAPIError(unittest.TestCase):
    """Verify FusedMultiHeadAttention validates the rank of its input."""

    def test_errors(self):
        def build_and_call_with_rank1():
            # Rank-1 input: the layer should raise ValueError instead of
            # failing later with an obscure error.
            values = np.array([1.9], dtype=np.float32)
            rank1 = paddle.to_tensor(np.reshape(values, [1]), dtype='float32')
            attn = paddle.incubate.nn.FusedMultiHeadAttention(
                embed_dim=1, num_heads=1
            )
            attn(rank1)

        self.assertRaises(ValueError, build_and_call_with_rank1)
class TestFusedMultiTransformerAPIError(unittest.TestCase):
    """Verify FusedTransformerEncoderLayer rejects malformed input."""

    def test_errors(self):
        def build_and_call_with_empty_input():
            # An empty rank-1 tensor is invalid input for the encoder layer
            # and must raise ValueError.
            values = np.array([], dtype=np.float32)
            empty = paddle.to_tensor(np.reshape(values, [0]), dtype='int32')
            encoder = paddle.incubate.nn.FusedTransformerEncoderLayer(
                108, 108, 108, 0.0, 'relu'
            )
            encoder(empty)

        self.assertRaises(ValueError, build_and_call_with_empty_input)
if __name__ == "__main__":
    unittest.main()
...@@ -615,6 +615,11 @@ def fused_multi_head_attention( ...@@ -615,6 +615,11 @@ def fused_multi_head_attention(
        'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
    )  # semantic transfer
    # Guard added by this commit: fail fast with a clear ValueError when the
    # input is not rank-3. NOTE(review): per the commit title this replaces a
    # later IndexError presumably raised when downstream code indexes x's
    # dims — confirm against the full function body (not visible here).
    if x.ndim != 3:
        raise ValueError(
            f"The rank of the x should be 3, but received {x.ndim}."
        )
    if _non_static_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册