未验证 提交 39a386c7 编写于 作者: S sneaxiy 提交者: GitHub

Fix MEA backward segmentation fault error (#55382)

* fix mea backward seg fault

* fix bias stride error
上级 79c922d0
......@@ -91,6 +91,8 @@ phi::DenseTensor get_pad_lse(const phi::GPUContext& dev_ctx,
ViewSliceHelper<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, stride, in_dim[2], out_second_dim);
return *lse;
} else {
return *lse;
}
}
} // namespace funcs
......
......@@ -198,10 +198,13 @@ void MemoryEfficientAttentionForwardKernel(
if (bias) {
p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
const auto& bias_dims = bias.get().dims();
PD_MEA_CHECK_OVERFLOW(
p.bias_strideB,
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(
p.bias_strideH,
GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
} else {
p.attn_bias_ptr = nullptr;
......
......@@ -484,10 +484,13 @@ void MemoryEfficientAttentionBackwardKernel(
if (bias) {
p.bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
const auto& bias_dims = bias.get().dims();
PD_MEA_CHECK_OVERFLOW(
p.bias_strideB,
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(
p.bias_strideH,
GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
VLOG(3) << "p.bias_ptr" << p.bias_ptr;
if (bias_grad) {
......
......@@ -62,6 +62,30 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
return 0;
}
// Computes the stride used for attn_bias along its second dimension.
//
// Result by case:
//   * rank-2 bias                          -> 0
//   * rank-4 bias, bias_dims[1] == q_dims[2] -> q_dims[1] * k_dims[1]
//   * rank-4 bias, bias_dims[1] == 1         -> 0
//
// Any other rank, or any other value of bias_dims[1], fails the
// corresponding PADDLE_ENFORCE_EQ check with an InvalidArgument error.
// NOTE(review): per the error message below, q_dims[2] is presumably
// num_heads -- confirm against the callers.
inline int64_t GetMemoryEfficientBiasStrideH(const phi::DDim &bias_dims,
                                             const phi::DDim &q_dims,
                                             const phi::DDim &k_dims) {
  const int rank = bias_dims.size();

  // A 2-D bias has no second-dimension stride to step over.
  if (rank == 2) {
    return 0;
  }

  PADDLE_ENFORCE_EQ(rank,
                    4,
                    phi::errors::InvalidArgument(
                        "The rank of attn_bias should be 2 or 4."));

  // Second dim matches q_dims[2]: each slice spans q_dims[1] x k_dims[1].
  if (bias_dims[1] == q_dims[2]) {
    return q_dims[1] * k_dims[1];
  }

  // Otherwise the only accepted size is 1, which maps to a zero stride.
  PADDLE_ENFORCE_EQ(
      bias_dims[1],
      1,
      phi::errors::InvalidArgument(
          "The second dim of attn_bias should be 1 or num_heads."));
  return 0;
}
#define PD_MEA_CHECK_OVERFLOW(__dst, ...) \
do { \
auto __src = (__VA_ARGS__); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册