Unverified commit 019e1cf5 authored by sneaxiy, committed by GitHub

Fix memory efficient attention bug (#52117)

* fix mea compile error

* support 2-D bias

* add inline to avoid compile error

* polish codes
Parent a5b88cba
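The "support 2-D bias" item lets attn_bias be a single (query_len, key_len) matrix shared by every batch and head, instead of requiring the full 4-D shape; the kernels now derive the batch stride from the bias rank via a new helper. A minimal NumPy sketch of the intended broadcast semantics (sizes are illustrative, not taken from this PR):

import numpy as np

# Illustrative sizes: batch, num_heads, query_len, key_len
B, H, M, N = 2, 4, 8, 8

# A 4-D bias stores a separate (M, N) slice per (batch, head) pair,
# so the per-batch stride is H * M * N elements.
bias_4d = np.random.randn(B, H, M, N).astype(np.float32)

# A 2-D bias is a single (M, N) block; a batch stride of 0 makes every
# batch read that same block, which is equivalent to broadcasting:
bias_2d = np.random.randn(M, N).astype(np.float32)
assert np.broadcast_to(bias_2d, (B, H, M, N)).shape == bias_4d.shape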
@@ -17,6 +17,7 @@
 #include "paddle/fluid/platform/errors.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h"
 
 namespace phi {
 namespace fusion {
@@ -195,7 +196,8 @@ void MemoryEfficientAttentionForwardKernel(
   if (bias) {
     p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
-    p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
+    p.bias_strideB =
+        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims);
     p.bias_strideH = q_dims[1] * k_dims[1];
     p.bias_strideM = k_dims[1];
   } else {
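For context, the kernel presumably indexes the bias as bias_ptr[b * bias_strideB + h * bias_strideH + m * bias_strideM + n]; only the batch stride depends on the bias rank, which is exactly what the new helper computes. A small sketch of that addressing (hypothetical helper, flat row-major storage assumed):

def bias_at(bias_flat, b, h, m, n, strideB, strideH, strideM):
    # Element (b, h, m, n) of the attention bias under the given strides.
    # strideB == 0 means every batch reads the same (num_heads, M, N) block.
    return bias_flat[b * strideB + h * strideH + m * strideM + n]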
@@ -31,21 +31,50 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, TypeVar
 
 DEFAULT_ARCH = [50, 70, 75, 80]
 MAX_ARCH = 90
 ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION"
 
+assert sorted(DEFAULT_ARCH) == DEFAULT_ARCH
+
+
+def find_arch_range(min_arch, max_arch):
+    assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH
+    assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH
+    assert min_arch <= max_arch
+    n = len(DEFAULT_ARCH)
+
+    start_idx = 0
+    for i in range(n - 1):
+        if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]:
+            start_idx = i
+            break
+
+    end_idx = n
+    for i in range(n - 1):
+        if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]:
+            end_idx = i + 1
+
+    return DEFAULT_ARCH[start_idx:end_idx]
+
+
+def find_max_arch(arch):
+    arch = list(sorted(arch))
+    idx = DEFAULT_ARCH.index(arch[-1])
+    if idx == len(DEFAULT_ARCH) - 1:
+        return MAX_ARCH
+    else:
+        return DEFAULT_ARCH[idx + 1]
+
+
 def convert_to_arch_list(arch):
     arch = arch.lower().strip()
     if arch == "all":
-        return [50, 70, 75, 80]
+        return DEFAULT_ARCH
 
-    arch = [int(s.strip()) for s in arch.split(' ') if s.strip()]
+    arch = [int(s.strip()) for s in arch.split(';') if s.strip()]
     arch = list(set(arch))
     arch.sort()
-    for each_arch in arch:
-        assert each_arch < MAX_ARCH
-    return arch
+    return find_arch_range(arch[0], arch[-1])
 
 
 def parse_args():
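Worked by hand from the helpers above (assuming the arch string arrives semicolon-separated, e.g. from a CMake list):

# With DEFAULT_ARCH = [50, 70, 75, 80] and MAX_ARCH = 90:
convert_to_arch_list("All")    # -> [50, 70, 75, 80]
convert_to_arch_list("75;80")  # -> find_arch_range(75, 80) -> [75, 80]
convert_to_arch_list("70")     # -> find_arch_range(70, 70) -> [70]

find_max_arch([75, 80])        # -> 90 (80 is the last DEFAULT_ARCH entry)
find_max_arch([70])            # -> 75 (the next DEFAULT_ARCH entry after 70)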
@@ -64,7 +93,9 @@ def parse_args():
         default=convert_to_arch_list("All"),
         help="The CUDA architecture to be generated.",
     )
-    return parser.parse_args()
+    args = parser.parse_args()
+    args.max_arch = find_max_arch(args.cuda_arch)
+    return args
 
 
 args = parse_args()
@@ -170,7 +201,7 @@ class FwdKernel:
     def get_all(cls) -> List["FwdKernel"]:
         kernels: List[FwdKernel] = []
         for aligned, dtype, (sm, sm_max) in itertools.product(
-            [True, False], DTYPES.keys(), zip(SM, SM[1:] + [MAX_ARCH])
+            [True, False], DTYPES.keys(), zip(SM, SM[1:] + [args.max_arch])
         ):
             # Remove some kernels we don't use
             if dtype == "bf16" and sm < 80:
@@ -280,7 +311,7 @@ class BwdKernel:
         ) in itertools.product(
             [True, False],
             DTYPES.keys(),
-            zip(SM, SM[1:] + [MAX_ARCH]),
+            zip(SM, SM[1:] + [args.max_arch]),
             [True, False],
             [32, 64, 128, 2**16],
         ):
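In both generators, zip(SM, SM[1:] + [args.max_arch]) pairs each architecture with the next selected one; sm_max presumably serves as the exclusive upper bound of each kernel's architecture guard, and the fix makes that bound track the user-selected arch list instead of the hard-coded MAX_ARCH. For example, assuming SM is the full default list:

SM = [50, 70, 75, 80]   # assumed: the selected arch list
max_arch = 90           # find_max_arch(SM) when 80 is included
print(list(zip(SM, SM[1:] + [max_arch])))
# [(50, 70), (70, 75), (75, 80), (80, 90)]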
@@ -18,6 +18,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen/memory_efficient_attention.h"
+#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h"
 
 #include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/cum_kernel.h"
@@ -481,7 +482,8 @@ void MemoryEfficientAttentionBackwardKernel(
   if (bias) {
     p.bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
-    p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
+    p.bias_strideB =
+        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims);
     p.bias_strideH = q_dims[1] * k_dims[1];
     p.bias_strideM = k_dims[1];
     VLOG(3) << "p.bias_ptr" << p.bias_ptr;
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"

namespace phi {
namespace fusion {
namespace cutlass_internal {

// Computes the batch stride of the attention bias. A 4-D bias of shape
// [batch, num_heads, query_len, key_len] gets the full per-batch stride;
// a 2-D [query_len, key_len] bias or a broadcastable
// [1, 1, query_len, key_len] bias gets stride 0, so every batch reads
// the same block.
inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
                                             const phi::DDim &q_dims,
                                             const phi::DDim &k_dims) {
  int bias_dims_rank = bias_dims.size();
  if (bias_dims_rank != 2) {
    PADDLE_ENFORCE_EQ(bias_dims_rank,
                      4,
                      phi::errors::InvalidArgument(
                          "The rank of attn_bias should be 2 or 4."));
  }

  PADDLE_ENFORCE_EQ(
      bias_dims[bias_dims_rank - 1],
      k_dims[1],
      phi::errors::InvalidArgument("The last dim of attn_bias should be "
                                   "equal to the sequence length of key."));
  PADDLE_ENFORCE_EQ(bias_dims[bias_dims_rank - 2],
                    q_dims[1],
                    phi::errors::InvalidArgument(
                        "The 2nd last dim of attn_bias should be equal to "
                        "the sequence length of query."));

  if (bias_dims_rank == 2) {
    return 0;
  }

  if (bias_dims[0] == q_dims[0] && bias_dims[1] == q_dims[2]) {
    return q_dims[2] * q_dims[1] * k_dims[1];
  }

  PADDLE_ENFORCE_EQ(
      bias_dims[0],
      1,
      phi::errors::InvalidArgument(
          "The first dim of attn_bias should be 1 or batch size."));
  PADDLE_ENFORCE_EQ(
      bias_dims[1],
      1,
      phi::errors::InvalidArgument(
          "The second dim of attn_bias should be 1 or num_heads."));
  return 0;
}

}  // namespace cutlass_internal
}  // namespace fusion
}  // namespace phi
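For readability, the helper's stride logic mirrored in Python with worked values (the q/k layout [batch, seq_len, num_heads, head_dim] is inferred from the stride arithmetic above):

def bias_stride_b(bias_dims, q_dims, k_dims):
    # Python mirror of GetMemoryEfficientBiasStrideB, for illustration only.
    rank = len(bias_dims)
    assert rank in (2, 4)
    assert bias_dims[-1] == k_dims[1]  # last dim == key sequence length
    assert bias_dims[-2] == q_dims[1]  # 2nd-last dim == query sequence length
    if rank == 2:
        return 0                       # one (M, N) block shared by all batches
    if bias_dims[0] == q_dims[0] and bias_dims[1] == q_dims[2]:
        return q_dims[2] * q_dims[1] * k_dims[1]  # distinct bias per batch
    assert bias_dims[0] == 1 and bias_dims[1] == 1
    return 0                           # broadcastable [1, 1, M, N] bias

q_dims, k_dims = [2, 8, 4, 64], [2, 8, 4, 64]
assert bias_stride_b([8, 8], q_dims, k_dims) == 0
assert bias_stride_b([2, 4, 8, 8], q_dims, k_dims) == 4 * 8 * 8
assert bias_stride_b([1, 1, 8, 8], q_dims, k_dims) == 0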