From cc262c5591f3ad704fb7fc91a26a926b3881a0b7 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 19 Jul 2023 00:04:47 +0800 Subject: [PATCH] Fix mea segmentation fault error (#55408) * fix mea seg fault develop * fix bias_grad seg fault --- paddle/phi/infermeta/backward.cc | 2 +- .../cutlass/memory_efficient_attention.cu | 7 ++++-- .../memory_efficient_attention_backward.cu | 7 ++++-- .../memory_efficient_attention_utils.h | 24 +++++++++++++++++++ 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 2e33d9fbf23..b028fd15b1b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -728,7 +728,7 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query, value_grad->set_dtype(value.dtype()); value_grad->set_layout(value.layout()); - if (bias) { + if (bias && bias_grad) { const int64_t bias_batch_size = bias.dims()[0]; const int64_t bias_seq_length = bias.dims()[1]; const int64_t bias_num_head = bias.dims()[2]; diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu index 8e564ff7dfc..cc4fd467dfc 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu @@ -202,10 +202,13 @@ void MemoryEfficientAttentionForwardKernel( if (bias) { p.attn_bias_ptr = phi::SafeGetTensorPtr(bias); + const auto& bias_dims = bias.get().dims(); PD_MEA_CHECK_OVERFLOW( p.bias_strideB, - GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims)); - PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]); + GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims)); + PD_MEA_CHECK_OVERFLOW( + p.bias_strideH, + GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims)); PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]); } else { p.attn_bias_ptr = nullptr; diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu index 2e16f9db347..23cd06c13c7 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu @@ -486,10 +486,13 @@ void MemoryEfficientAttentionBackwardKernel( if (bias) { p.bias_ptr = phi::SafeGetTensorPtr(bias); + const auto& bias_dims = bias.get().dims(); PD_MEA_CHECK_OVERFLOW( p.bias_strideB, - GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims)); - PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]); + GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims)); + PD_MEA_CHECK_OVERFLOW( + p.bias_strideH, + GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims)); PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]); VLOG(3) << "p.bias_ptr" << p.bias_ptr; if (bias_grad) { diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h index beb0bdd770a..65dfb1bc8ec 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h @@ -62,6 +62,30 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims, return 0; } +inline int64_t GetMemoryEfficientBiasStrideH(const phi::DDim &bias_dims, + const phi::DDim &q_dims, + const phi::DDim &k_dims) { + int bias_dims_rank = bias_dims.size(); + if (bias_dims_rank == 2) { + return 0; + } else { + PADDLE_ENFORCE_EQ(bias_dims_rank, + 4, + phi::errors::InvalidArgument( + "The rank of attn_bias should be 2 or 4.")); + if (bias_dims[1] != q_dims[2]) { + PADDLE_ENFORCE_EQ( + bias_dims[1], + 1, + phi::errors::InvalidArgument( + "The second dim of attn_bias should be 1 or num_heads.")); + return 0; + } else { + return q_dims[1] * k_dims[1]; + } + } +} + #define PD_MEA_CHECK_OVERFLOW(__dst, ...) \ do { \ auto __src = (__VA_ARGS__); \ -- GitLab