未验证 提交 7306d1fb 编写于 作者: W WangXi 提交者: GitHub

fix fused_attention_op cacheKV InferShape (#42900)

上级 f36a9464
......@@ -163,11 +163,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1], c_dim[2])); // num_head
PADDLE_ENFORCE_GE(
c_dim[3], 0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
// At compile time the input seq_len can be -1 (not yet known); in that
// case c_dim[3] may be negative inside a while block, so the check below
// is only enforced at runtime.
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
c_dim[3], 0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
}
PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册