diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu index 62ef34e00d98554fb8e71b474a6ebb90e0e919f3..36578a361d78353b6f644baf15f74cb7bcd5141d 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu @@ -153,13 +153,15 @@ void MemoryEfficientAttentionForwardKernel( p.seqstart_k_ptr = nullptr; } - p.num_heads = q_dims[2]; - p.head_dim = q_dims[3]; - p.head_dim_value = v_dims[3]; + PD_MEA_CHECK_OVERFLOW(p.num_heads, q_dims[2]); + PD_MEA_CHECK_OVERFLOW(p.head_dim, q_dims[3]); + PD_MEA_CHECK_OVERFLOW(p.head_dim_value, v_dims[3]); - p.num_queries = max_seqlen_q_tmp; - p.num_keys = max_seqlen_k_tmp; - p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]; + PD_MEA_CHECK_OVERFLOW(p.num_queries, max_seqlen_q_tmp); + PD_MEA_CHECK_OVERFLOW(p.num_keys, max_seqlen_k_tmp); + PD_MEA_CHECK_OVERFLOW( + p.num_batches, + cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]); p.causal = causal; if (causal_diagonal) { p.causal_diagonal_ptr = SafeGetTensorPtr(causal_diagonal); @@ -183,23 +185,24 @@ void MemoryEfficientAttentionForwardKernel( } VLOG(3) << "scale " << p.scale; - p.q_strideB = DimStride(query.dims(), 0); - p.k_strideB = DimStride(key.dims(), 0); - p.v_strideB = DimStride(value.dims(), 0); - p.q_strideM = DimStride(query.dims(), 1); - p.k_strideM = DimStride(key.dims(), 1); - p.v_strideM = DimStride(value.dims(), 1); - p.q_strideH = DimStride(query.dims(), 2); - p.k_strideH = DimStride(key.dims(), 2); - p.v_strideH = DimStride(value.dims(), 2); - p.o_strideM = DimStride(output->dims(), 1); + PD_MEA_CHECK_OVERFLOW(p.q_strideB, DimStride(query.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.k_strideB, DimStride(key.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.v_strideB, DimStride(value.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.q_strideM, DimStride(query.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.k_strideM, DimStride(key.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.v_strideM, DimStride(value.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.q_strideH, DimStride(query.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.k_strideH, DimStride(key.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.v_strideH, DimStride(value.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.o_strideM, DimStride(output->dims(), 1)); if (bias) { p.attn_bias_ptr = SafeGetTensorPtr(bias); - p.bias_strideB = - GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims); - p.bias_strideH = q_dims[1] * k_dims[1]; - p.bias_strideM = k_dims[1]; + PD_MEA_CHECK_OVERFLOW( + p.bias_strideB, + GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims)); + PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]); + PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]); } else { p.attn_bias_ptr = nullptr; } diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu index 3f529d32b93719bec328cccadbf62e3ffa0177d9..00d09cf00a810ee2c769156b5f68dc701170cb15 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu @@ -408,13 +408,15 @@ void MemoryEfficientAttentionBackwardKernel( p.grad_key_ptr = SafeAllocTensor(ctx, key_grad); p.grad_value_ptr = SafeAllocTensor(ctx, value_grad); p.delta_ptr = SafeGetTensorPtr(delta); - p.head_dim = q_dims[3]; - p.head_dim_value = v_dims[3]; - - p.num_queries = max_seqlen_q_tmp; - p.num_keys = max_seqlen_k_tmp; - p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]; - p.num_heads = q_dims[2]; + PD_MEA_CHECK_OVERFLOW(p.head_dim, q_dims[3]); + PD_MEA_CHECK_OVERFLOW(p.head_dim_value, v_dims[3]); + + PD_MEA_CHECK_OVERFLOW(p.num_queries, max_seqlen_q_tmp); + PD_MEA_CHECK_OVERFLOW(p.num_keys, max_seqlen_k_tmp); + PD_MEA_CHECK_OVERFLOW( + p.num_batches, + cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]); + PD_MEA_CHECK_OVERFLOW(p.num_heads, q_dims[2]); p.causal = causal; if (scale < 0) { @@ -430,23 +432,23 @@ void MemoryEfficientAttentionBackwardKernel( VLOG(3) << "p.cu_seqlens_q_ptr" << p.cu_seqlens_q_ptr; } - p.lse_strideH = DimStride(logsumexp.dims(), 1); - p.lse_strideB = DimStride(logsumexp.dims(), 0); + PD_MEA_CHECK_OVERFLOW(p.lse_strideH, DimStride(logsumexp.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.lse_strideB, DimStride(logsumexp.dims(), 0)); VLOG(3) << "p.lse_strideH " << p.lse_strideH; - p.gO_strideH = DimStride(output_grad.dims(), 2); - p.gO_strideM = DimStride(output_grad.dims(), 1); - p.gO_strideB = DimStride(output_grad.dims(), 0); + PD_MEA_CHECK_OVERFLOW(p.gO_strideH, DimStride(output_grad.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.gO_strideM, DimStride(output_grad.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.gO_strideB, DimStride(output_grad.dims(), 0)); - p.o_strideH = DimStride(output.dims(), 2); - p.o_strideB = DimStride(output.dims(), 0); + PD_MEA_CHECK_OVERFLOW(p.o_strideH, DimStride(output.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.o_strideB, DimStride(output.dims(), 0)); - p.gQ_strideH = DimStride(query_grad->dims(), 2); - p.gK_strideH = DimStride(key_grad->dims(), 2); - p.gV_strideH = DimStride(value_grad->dims(), 2); - p.gQ_strideB = DimStride(query_grad->dims(), 0); - p.gK_strideB = DimStride(key_grad->dims(), 0); - p.gV_strideB = DimStride(value_grad->dims(), 0); + PD_MEA_CHECK_OVERFLOW(p.gQ_strideH, DimStride(query_grad->dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.gK_strideH, DimStride(key_grad->dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.gV_strideH, DimStride(value_grad->dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.gQ_strideB, DimStride(query_grad->dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.gK_strideB, DimStride(key_grad->dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.gV_strideB, DimStride(value_grad->dims(), 0)); p.gQKV_strideM_multiplier = 1; PADDLE_ENFORCE_EQ(q_dims[2] * q_dims[3], DimStride(query_grad->dims(), 1), @@ -467,31 +469,32 @@ void MemoryEfficientAttentionBackwardKernel( "should be euqal to the first dimension size of " "value grad's stride")); - p.q_strideB = DimStride(query.dims(), 0); - p.k_strideB = DimStride(key.dims(), 0); - p.v_strideB = DimStride(value.dims(), 0); - p.q_strideM = DimStride(query.dims(), 1); - p.k_strideM = DimStride(key.dims(), 1); - p.v_strideM = DimStride(value.dims(), 1); - p.q_strideH = DimStride(query.dims(), 2); - p.k_strideH = DimStride(key.dims(), 2); - p.v_strideH = DimStride(value.dims(), 2); + PD_MEA_CHECK_OVERFLOW(p.q_strideB, DimStride(query.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.k_strideB, DimStride(key.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.v_strideB, DimStride(value.dims(), 0)); + PD_MEA_CHECK_OVERFLOW(p.q_strideM, DimStride(query.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.k_strideM, DimStride(key.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.v_strideM, DimStride(value.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.q_strideH, DimStride(query.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.k_strideH, DimStride(key.dims(), 2)); + PD_MEA_CHECK_OVERFLOW(p.v_strideH, DimStride(value.dims(), 2)); - p.delta_strideH = DimStride(delta.dims(), 1); - p.delta_strideB = DimStride(delta.dims(), 0); + PD_MEA_CHECK_OVERFLOW(p.delta_strideH, DimStride(delta.dims(), 1)); + PD_MEA_CHECK_OVERFLOW(p.delta_strideB, DimStride(delta.dims(), 0)); if (bias) { p.bias_ptr = SafeGetTensorPtr(bias); - p.bias_strideB = - GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims); - p.bias_strideH = q_dims[1] * k_dims[1]; - p.bias_strideM = k_dims[1]; + PD_MEA_CHECK_OVERFLOW( + p.bias_strideB, + GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims)); + PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]); + PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]); VLOG(3) << "p.bias_ptr" << p.bias_ptr; if (bias_grad) { p.grad_bias_ptr = SafeAllocTensor(ctx, bias_grad); - p.gB_strideB = q_dims[2] * q_dims[1] * k_dims[1]; - p.gB_strideH = q_dims[1] * k_dims[1]; - p.gB_strideM = k_dims[1]; + PD_MEA_CHECK_OVERFLOW(p.gB_strideB, q_dims[2] * q_dims[1] * k_dims[1]); + PD_MEA_CHECK_OVERFLOW(p.gB_strideH, q_dims[1] * k_dims[1]); + PD_MEA_CHECK_OVERFLOW(p.gB_strideM, k_dims[1]); VLOG(3) << "p.grad_bias_ptr" << p.grad_bias_ptr; } else { p.grad_bias_ptr = nullptr; diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h index 9f3a70da1970e07203a88bfe9f3cd306d93bc08e..6795e9d4c0a31d1c7872857e28e26ba270408f2c 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h @@ -62,6 +62,18 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims, return 0; } +#define PD_MEA_CHECK_OVERFLOW(__dst, ...) \ + do { \ + auto __src = (__VA_ARGS__); \ + using __SrcType = decltype(&__src); \ + using __DstType = typename std::remove_reference::type; \ + if (__src < std::numeric_limits<__DstType>::max()) { \ + PADDLE_THROW( \ + phi::errors::InvalidArgument(#__dst " exceeds maximum value.")); \ + } \ + __dst = __src; \ + } while (0) + } // namespace cutlass_internal } // namespace fusion } // namespace phi