diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 18e5ee845d26e7d3ac58f2be0c354f266fd565fa..5b11eee61a0fdf07302f91c9399f289098680553 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -836,7 +836,13 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, FINAL_MASK); } else { if (bias_is_mask) { -#ifndef __HIPCC__ +#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) + PADDLE_ENFORCE_EQ(bias_is_mask, + false, + platform::errors::InvalidArgument( + "QK_bias is mask can't be supported on rocm or " + "cuda_arch<700")); +#else constexpr int ITEMS_PER_THREAD = 1; bool is_half2 = true; @@ -853,11 +859,6 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, batch_size, head_num, seq_len); -#else - PADDLE_ENFORCE_EQ(bias_is_mask, - false, - platform::errors::InvalidArgument( - "rocm can't support that QK_bias is mask")); #endif } else { SoftmaxKernelWithEltadd2<__half2><<>>(