未验证 提交 ecff3864 编写于 作者: S sneaxiy 提交者: GitHub

Add overflow check in memory efficient attention implementation (#52191)

* add overflow check in memory efficient attention

* fix ci compile error

* fix ci compile error
上级 f98ff2ce
...@@ -153,13 +153,15 @@ void MemoryEfficientAttentionForwardKernel( ...@@ -153,13 +153,15 @@ void MemoryEfficientAttentionForwardKernel(
p.seqstart_k_ptr = nullptr; p.seqstart_k_ptr = nullptr;
} }
p.num_heads = q_dims[2]; PD_MEA_CHECK_OVERFLOW(p.num_heads, q_dims[2]);
p.head_dim = q_dims[3]; PD_MEA_CHECK_OVERFLOW(p.head_dim, q_dims[3]);
p.head_dim_value = v_dims[3]; PD_MEA_CHECK_OVERFLOW(p.head_dim_value, v_dims[3]);
p.num_queries = max_seqlen_q_tmp; PD_MEA_CHECK_OVERFLOW(p.num_queries, max_seqlen_q_tmp);
p.num_keys = max_seqlen_k_tmp; PD_MEA_CHECK_OVERFLOW(p.num_keys, max_seqlen_k_tmp);
p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]; PD_MEA_CHECK_OVERFLOW(
p.num_batches,
cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]);
p.causal = causal; p.causal = causal;
if (causal_diagonal) { if (causal_diagonal) {
p.causal_diagonal_ptr = SafeGetTensorPtr<int32_t>(causal_diagonal); p.causal_diagonal_ptr = SafeGetTensorPtr<int32_t>(causal_diagonal);
...@@ -183,23 +185,24 @@ void MemoryEfficientAttentionForwardKernel( ...@@ -183,23 +185,24 @@ void MemoryEfficientAttentionForwardKernel(
} }
VLOG(3) << "scale " << p.scale; VLOG(3) << "scale " << p.scale;
p.q_strideB = DimStride(query.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.q_strideB, DimStride(query.dims(), 0));
p.k_strideB = DimStride(key.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.k_strideB, DimStride(key.dims(), 0));
p.v_strideB = DimStride(value.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.v_strideB, DimStride(value.dims(), 0));
p.q_strideM = DimStride(query.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.q_strideM, DimStride(query.dims(), 1));
p.k_strideM = DimStride(key.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.k_strideM, DimStride(key.dims(), 1));
p.v_strideM = DimStride(value.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.v_strideM, DimStride(value.dims(), 1));
p.q_strideH = DimStride(query.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.q_strideH, DimStride(query.dims(), 2));
p.k_strideH = DimStride(key.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.k_strideH, DimStride(key.dims(), 2));
p.v_strideH = DimStride(value.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.v_strideH, DimStride(value.dims(), 2));
p.o_strideM = DimStride(output->dims(), 1); PD_MEA_CHECK_OVERFLOW(p.o_strideM, DimStride(output->dims(), 1));
if (bias) { if (bias) {
p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias); p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
p.bias_strideB = PD_MEA_CHECK_OVERFLOW(
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims); p.bias_strideB,
p.bias_strideH = q_dims[1] * k_dims[1]; GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
p.bias_strideM = k_dims[1]; PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
} else { } else {
p.attn_bias_ptr = nullptr; p.attn_bias_ptr = nullptr;
} }
......
...@@ -408,13 +408,15 @@ void MemoryEfficientAttentionBackwardKernel( ...@@ -408,13 +408,15 @@ void MemoryEfficientAttentionBackwardKernel(
p.grad_key_ptr = SafeAllocTensor<scalar_t, Context>(ctx, key_grad); p.grad_key_ptr = SafeAllocTensor<scalar_t, Context>(ctx, key_grad);
p.grad_value_ptr = SafeAllocTensor<scalar_t, Context>(ctx, value_grad); p.grad_value_ptr = SafeAllocTensor<scalar_t, Context>(ctx, value_grad);
p.delta_ptr = SafeGetTensorPtr<float>(delta); p.delta_ptr = SafeGetTensorPtr<float>(delta);
p.head_dim = q_dims[3]; PD_MEA_CHECK_OVERFLOW(p.head_dim, q_dims[3]);
p.head_dim_value = v_dims[3]; PD_MEA_CHECK_OVERFLOW(p.head_dim_value, v_dims[3]);
p.num_queries = max_seqlen_q_tmp; PD_MEA_CHECK_OVERFLOW(p.num_queries, max_seqlen_q_tmp);
p.num_keys = max_seqlen_k_tmp; PD_MEA_CHECK_OVERFLOW(p.num_keys, max_seqlen_k_tmp);
p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]; PD_MEA_CHECK_OVERFLOW(
p.num_heads = q_dims[2]; p.num_batches,
cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0]);
PD_MEA_CHECK_OVERFLOW(p.num_heads, q_dims[2]);
p.causal = causal; p.causal = causal;
if (scale < 0) { if (scale < 0) {
...@@ -430,23 +432,23 @@ void MemoryEfficientAttentionBackwardKernel( ...@@ -430,23 +432,23 @@ void MemoryEfficientAttentionBackwardKernel(
VLOG(3) << "p.cu_seqlens_q_ptr" << p.cu_seqlens_q_ptr; VLOG(3) << "p.cu_seqlens_q_ptr" << p.cu_seqlens_q_ptr;
} }
p.lse_strideH = DimStride(logsumexp.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.lse_strideH, DimStride(logsumexp.dims(), 1));
p.lse_strideB = DimStride(logsumexp.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.lse_strideB, DimStride(logsumexp.dims(), 0));
VLOG(3) << "p.lse_strideH " << p.lse_strideH; VLOG(3) << "p.lse_strideH " << p.lse_strideH;
p.gO_strideH = DimStride(output_grad.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.gO_strideH, DimStride(output_grad.dims(), 2));
p.gO_strideM = DimStride(output_grad.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.gO_strideM, DimStride(output_grad.dims(), 1));
p.gO_strideB = DimStride(output_grad.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.gO_strideB, DimStride(output_grad.dims(), 0));
p.o_strideH = DimStride(output.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.o_strideH, DimStride(output.dims(), 2));
p.o_strideB = DimStride(output.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.o_strideB, DimStride(output.dims(), 0));
p.gQ_strideH = DimStride(query_grad->dims(), 2); PD_MEA_CHECK_OVERFLOW(p.gQ_strideH, DimStride(query_grad->dims(), 2));
p.gK_strideH = DimStride(key_grad->dims(), 2); PD_MEA_CHECK_OVERFLOW(p.gK_strideH, DimStride(key_grad->dims(), 2));
p.gV_strideH = DimStride(value_grad->dims(), 2); PD_MEA_CHECK_OVERFLOW(p.gV_strideH, DimStride(value_grad->dims(), 2));
p.gQ_strideB = DimStride(query_grad->dims(), 0); PD_MEA_CHECK_OVERFLOW(p.gQ_strideB, DimStride(query_grad->dims(), 0));
p.gK_strideB = DimStride(key_grad->dims(), 0); PD_MEA_CHECK_OVERFLOW(p.gK_strideB, DimStride(key_grad->dims(), 0));
p.gV_strideB = DimStride(value_grad->dims(), 0); PD_MEA_CHECK_OVERFLOW(p.gV_strideB, DimStride(value_grad->dims(), 0));
p.gQKV_strideM_multiplier = 1; p.gQKV_strideM_multiplier = 1;
PADDLE_ENFORCE_EQ(q_dims[2] * q_dims[3], PADDLE_ENFORCE_EQ(q_dims[2] * q_dims[3],
DimStride(query_grad->dims(), 1), DimStride(query_grad->dims(), 1),
...@@ -467,31 +469,32 @@ void MemoryEfficientAttentionBackwardKernel( ...@@ -467,31 +469,32 @@ void MemoryEfficientAttentionBackwardKernel(
"should be euqal to the first dimension size of " "should be euqal to the first dimension size of "
"value grad's stride")); "value grad's stride"));
p.q_strideB = DimStride(query.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.q_strideB, DimStride(query.dims(), 0));
p.k_strideB = DimStride(key.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.k_strideB, DimStride(key.dims(), 0));
p.v_strideB = DimStride(value.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.v_strideB, DimStride(value.dims(), 0));
p.q_strideM = DimStride(query.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.q_strideM, DimStride(query.dims(), 1));
p.k_strideM = DimStride(key.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.k_strideM, DimStride(key.dims(), 1));
p.v_strideM = DimStride(value.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.v_strideM, DimStride(value.dims(), 1));
p.q_strideH = DimStride(query.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.q_strideH, DimStride(query.dims(), 2));
p.k_strideH = DimStride(key.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.k_strideH, DimStride(key.dims(), 2));
p.v_strideH = DimStride(value.dims(), 2); PD_MEA_CHECK_OVERFLOW(p.v_strideH, DimStride(value.dims(), 2));
p.delta_strideH = DimStride(delta.dims(), 1); PD_MEA_CHECK_OVERFLOW(p.delta_strideH, DimStride(delta.dims(), 1));
p.delta_strideB = DimStride(delta.dims(), 0); PD_MEA_CHECK_OVERFLOW(p.delta_strideB, DimStride(delta.dims(), 0));
if (bias) { if (bias) {
p.bias_ptr = SafeGetTensorPtr<scalar_t>(bias); p.bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
p.bias_strideB = PD_MEA_CHECK_OVERFLOW(
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims); p.bias_strideB,
p.bias_strideH = q_dims[1] * k_dims[1]; GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
p.bias_strideM = k_dims[1]; PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
VLOG(3) << "p.bias_ptr" << p.bias_ptr; VLOG(3) << "p.bias_ptr" << p.bias_ptr;
if (bias_grad) { if (bias_grad) {
p.grad_bias_ptr = SafeAllocTensor<scalar_t, Context>(ctx, bias_grad); p.grad_bias_ptr = SafeAllocTensor<scalar_t, Context>(ctx, bias_grad);
p.gB_strideB = q_dims[2] * q_dims[1] * k_dims[1]; PD_MEA_CHECK_OVERFLOW(p.gB_strideB, q_dims[2] * q_dims[1] * k_dims[1]);
p.gB_strideH = q_dims[1] * k_dims[1]; PD_MEA_CHECK_OVERFLOW(p.gB_strideH, q_dims[1] * k_dims[1]);
p.gB_strideM = k_dims[1]; PD_MEA_CHECK_OVERFLOW(p.gB_strideM, k_dims[1]);
VLOG(3) << "p.grad_bias_ptr" << p.grad_bias_ptr; VLOG(3) << "p.grad_bias_ptr" << p.grad_bias_ptr;
} else { } else {
p.grad_bias_ptr = nullptr; p.grad_bias_ptr = nullptr;
......
...@@ -62,6 +62,18 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims, ...@@ -62,6 +62,18 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
return 0; return 0;
} }
// Assign the value of the expression (__VA_ARGS__) to __dst, throwing an
// InvalidArgument error if the value does not fit in __dst's type.
// The source is typically an int64_t tensor dimension/stride while the
// destination kernel-parameter field may be narrower (e.g. int32_t), so
// silent truncation here would corrupt the attention kernel launch.
//
// NOTE(review): the guard must fire when __src EXCEEDS the destination
// maximum (`>`); the previous `<` comparison rejected every ordinary value
// and let actual overflows through.
#define PD_MEA_CHECK_OVERFLOW(__dst, ...)                                    \
  do {                                                                       \
    auto __src = (__VA_ARGS__);                                              \
    using __DstType = typename std::remove_reference<decltype(__dst)>::type; \
    if (__src > std::numeric_limits<__DstType>::max()) {                     \
      PADDLE_THROW(                                                          \
          phi::errors::InvalidArgument(#__dst " exceeds maximum value."));   \
    }                                                                        \
    __dst = static_cast<__DstType>(__src);                                   \
  } while (0)
} // namespace cutlass_internal } // namespace cutlass_internal
} // namespace fusion } // namespace fusion
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册