未验证 提交 cc262c55 编写于 作者: S sneaxiy 提交者: GitHub

Fix mea segmentation fault error (#55408)

* fix mea seg fault develop

* fix bias_grad seg fault
上级 96ff6103
......@@ -728,7 +728,7 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
value_grad->set_dtype(value.dtype());
value_grad->set_layout(value.layout());
if (bias) {
if (bias && bias_grad) {
const int64_t bias_batch_size = bias.dims()[0];
const int64_t bias_seq_length = bias.dims()[1];
const int64_t bias_num_head = bias.dims()[2];
......
......@@ -202,10 +202,13 @@ void MemoryEfficientAttentionForwardKernel(
if (bias) {
p.attn_bias_ptr = phi::SafeGetTensorPtr<scalar_t>(bias);
const auto& bias_dims = bias.get().dims();
PD_MEA_CHECK_OVERFLOW(
p.bias_strideB,
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(
p.bias_strideH,
GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
} else {
p.attn_bias_ptr = nullptr;
......
......@@ -486,10 +486,13 @@ void MemoryEfficientAttentionBackwardKernel(
if (bias) {
p.bias_ptr = phi::SafeGetTensorPtr<scalar_t>(bias);
const auto& bias_dims = bias.get().dims();
PD_MEA_CHECK_OVERFLOW(
p.bias_strideB,
GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideH, q_dims[1] * k_dims[1]);
GetMemoryEfficientBiasStrideB(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(
p.bias_strideH,
GetMemoryEfficientBiasStrideH(bias_dims, q_dims, k_dims));
PD_MEA_CHECK_OVERFLOW(p.bias_strideM, k_dims[1]);
VLOG(3) << "p.bias_ptr" << p.bias_ptr;
if (bias_grad) {
......
......@@ -62,6 +62,30 @@ inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
return 0;
}
inline int64_t GetMemoryEfficientBiasStrideH(const phi::DDim &bias_dims,
const phi::DDim &q_dims,
const phi::DDim &k_dims) {
int bias_dims_rank = bias_dims.size();
if (bias_dims_rank == 2) {
return 0;
} else {
PADDLE_ENFORCE_EQ(bias_dims_rank,
4,
phi::errors::InvalidArgument(
"The rank of attn_bias should be 2 or 4."));
if (bias_dims[1] != q_dims[2]) {
PADDLE_ENFORCE_EQ(
bias_dims[1],
1,
phi::errors::InvalidArgument(
"The second dim of attn_bias should be 1 or num_heads."));
return 0;
} else {
return q_dims[1] * k_dims[1];
}
}
}
#define PD_MEA_CHECK_OVERFLOW(__dst, ...) \
do { \
auto __src = (__VA_ARGS__); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册