From 7306d1fba1efefe48b9bc151800ec3a42f5336ee Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 20 May 2022 20:25:31 +0800 Subject: [PATCH] fix fused_attention_op cacheKV InferShape (#42900) --- paddle/fluid/operators/fused/fused_attention_op.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index e473f8ff06..1f377810a2 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -163,11 +163,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "The third dim of CacheKV must be equal with num " "head %d, but got %d", y_dim[1], c_dim[2])); // num_head - PADDLE_ENFORCE_GE( - c_dim[3], 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len + // In compile stage, input seq_len can be -1, in that case + // c_dim[3] may < 0 in while + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_GE( + c_dim[3], 0, + paddle::platform::errors::InvalidArgument( + "The forth dim of CacheKV must be greater than 0, but got %d", + c_dim[3])); // cache_seq_len + } PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2], paddle::platform::errors::InvalidArgument( "The fifth dim of CacheKV must be equal with head " -- GitLab