未验证 提交 9b317b2d 编写于 作者: N niuliling123 提交者: GitHub

Add assert for static and other plateform (#56044)

上级 caa0f377
......@@ -17,7 +17,7 @@ from paddle import _C_ops
from paddle.framework import in_dynamic_mode
def fused_rotary_position_embedding(q, k, v):
def fused_rotary_position_embedding(q, k=None, v=None):
r"""
Fused rotary position embedding.
......@@ -47,3 +47,7 @@ def fused_rotary_position_embedding(q, k, v):
"""
if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(q, k, v)
raise RuntimeError(
"This feature is currently supported only in dynamic mode and with CUDAPlace."
)
......@@ -139,6 +139,15 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
def test_error(self):
paddle.enable_static()
with self.assertRaises(RuntimeError):
static_q = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
)
fused_rotary_position_embedding(static_q, static_q, static_q)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册