diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 2e33d9fbf23c2bdc13d1e170e5fa461fa497563f..b028fd15b1b9396940f0660f88ddd25986345095 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -728,7 +728,7 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
   value_grad->set_dtype(value.dtype());
   value_grad->set_layout(value.layout());

-  if (bias) {
+  if (bias && bias_grad) {
     const int64_t bias_batch_size = bias.dims()[0];
     const int64_t bias_seq_length = bias.dims()[1];
     const int64_t bias_num_head = bias.dims()[2];
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
index 8e564ff7dfc98ecc7bf4b1ac443d861761b4cc47..cc4fd467dfc20bb30fea8519bb5e90dfffad828d 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
@@ -202,10 +202,13 @@ void MemoryEfficientAttentionForwardKernel(

   if (bias) {
     p.attn_bias_ptr = phi::SafeGetTensorPtr(bias);
+    const auto& bias_dims = bias.get().dims();
     PD_MEA_CHECK_OVERFLOW(
         p.bias_strideB,
-        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
-    PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
+        GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
+    PD_MEA_CHECK_OVERFLOW(
+        p.bias_strideH,
+        GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
     PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
   } else {
     p.attn_bias_ptr = nullptr;
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
index 2e16f9db34750da7709ae1b0f2579a56efbeaa27..23cd06c13c71606f62394a5263a039cf971f302d 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
@@ -486,10 +486,13 @@ void MemoryEfficientAttentionBackwardKernel(

   if (bias) {
     p.bias_ptr = phi::SafeGetTensorPtr(bias);
+    const auto& bias_dims = bias.get().dims();
     PD_MEA_CHECK_OVERFLOW(
         p.bias_strideB,
-        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
-    PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
+        GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
+    PD_MEA_CHECK_OVERFLOW(
+        p.bias_strideH,
+        GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
     PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
     VLOG(3) << "p.bias_ptr" << p.bias_ptr;
     if (bias_grad) {
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h
index beb0bdd770aa404d7867840def3eacb0150e5e8e..65dfb1bc8eced49412ca3cdb633b6fef451f0db6 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h
@@ -62,6 +62,30 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
   return 0;
 }

+inline int64_t GetMemoryEfficientBiasStrideH(const phi::DDim &bias_dims,
+                                             const phi::DDim &q_dims,
+                                             const phi::DDim &k_dims) {
+  int bias_dims_rank = bias_dims.size();
+  if (bias_dims_rank == 2) {
+    return 0;
+  } else {
+    PADDLE_ENFORCE_EQ(bias_dims_rank,
+                      4,
+                      phi::errors::InvalidArgument(
+                          "The rank of attn_bias should be 2 or 4."));
+    if (bias_dims[1] != q_dims[2]) {
+      PADDLE_ENFORCE_EQ(
+          bias_dims[1],
+          1,
+          phi::errors::InvalidArgument(
+              "The second dim of attn_bias should be 1 or num_heads."));
+      return 0;
+    } else {
+      return q_dims[1] * k_dims[1];
+    }
+  }
+}
+
 #define PD_MEA_CHECK_OVERFLOW(__dst, ...)                        \
   do {                                                           \
     auto __src = (__VA_ARGS__);                                  \