diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
index 4990fea3e028b9fa7b042a5d88b01edb538aa16a..e1a91a8a8bb020497069f3de91123b8ff336602b 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention.cu
@@ -17,6 +17,7 @@
 #include "paddle/fluid/platform/errors.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h"
 
 namespace phi {
 namespace fusion {
@@ -195,7 +196,8 @@ void MemoryEfficientAttentionForwardKernel(
 
   if (bias) {
     p.attn_bias_ptr = SafeGetTensorPtr(bias);
-    p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
+    p.bias_strideB =
+        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims);
     p.bias_strideH = q_dims[1] * k_dims[1];
     p.bias_strideM = k_dims[1];
   } else {
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
index b8fe4e63ece7ae6066974280adfe70832c9684d8..2baa7b07d98492b7d918a50ef61cce1ad1de7007 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
@@ -31,22 +31,53 @@
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, TypeVar
 
+DEFAULT_ARCH = [50, 70, 75, 80]
 MAX_ARCH = 90
 ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION"
 
+assert sorted(DEFAULT_ARCH) == DEFAULT_ARCH
+
+
+def find_arch_range(min_arch, max_arch):
+    assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH
+    assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH
+    assert min_arch <= max_arch
+    n = len(DEFAULT_ARCH)
+
+    # Default to the last supported arch: if min_arch >= DEFAULT_ARCH[-1]
+    # (e.g. 80..89) no loop iteration matches, and 0 would wrongly widen
+    # the range back to DEFAULT_ARCH[0].
+    start_idx = n - 1
+    for i in range(n - 1):
+        if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]:
+            start_idx = i
+            break
+
+    end_idx = n
+    for i in range(n - 1):
+        if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]:
+            end_idx = i + 1
+    return DEFAULT_ARCH[start_idx:end_idx]
+
+
+def find_max_arch(arch):
+    arch = list(sorted(arch))
+    idx = DEFAULT_ARCH.index(arch[-1])
+    if idx == len(DEFAULT_ARCH) - 1:
+        return MAX_ARCH
+    else:
+        return DEFAULT_ARCH[idx + 1]
+
 
 def convert_to_arch_list(arch):
     arch = arch.lower().strip()
     if arch == "all":
-        return [50, 70, 75, 80]
+        return DEFAULT_ARCH
 
-    arch = [int(s.strip()) for s in arch.split(' ') if s.strip()]
+    arch = [int(s.strip()) for s in arch.split(';') if s.strip()]
     arch = list(set(arch))
     arch.sort()
-    for each_arch in arch:
-        assert each_arch < MAX_ARCH
-    return arch
+    return find_arch_range(arch[0], arch[-1])
 
 
 def parse_args():
@@ -64,7 +95,9 @@ def parse_args():
         default=convert_to_arch_list("All"),
         help="The CUDA architecture to be generated.",
     )
-    return parser.parse_args()
+    args = parser.parse_args()
+    args.max_arch = find_max_arch(args.cuda_arch)
+    return args
 
 
 args = parse_args()
@@ -170,7 +203,7 @@ class FwdKernel:
     def get_all(cls) -> List["FwdKernel"]:
         kernels: List[FwdKernel] = []
         for aligned, dtype, (sm, sm_max) in itertools.product(
-            [True, False], DTYPES.keys(), zip(SM, SM[1:] + [MAX_ARCH])
+            [True, False], DTYPES.keys(), zip(SM, SM[1:] + [args.max_arch])
         ):
             # Remove some kernels we don't use
             if dtype == "bf16" and sm < 80:
@@ -280,7 +313,7 @@ class BwdKernel:
         ) in itertools.product(
             [True, False],
             DTYPES.keys(),
-            zip(SM, SM[1:] + [MAX_ARCH]),
+            zip(SM, SM[1:] + [args.max_arch]),
             [True, False],
             [32, 64, 128, 2**16],
         ):
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
index 4aee8205033739950b1e05299226e43b02f48e79..ac9eb64c120acfb289a006ea821ca124ac0400b8 100644
--- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_backward.cu
@@ -18,6 +18,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen/memory_efficient_attention.h"
+#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h"
 
 #include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/cum_kernel.h"
@@ -481,7 +482,8 @@ void MemoryEfficientAttentionBackwardKernel(
 
   if (bias) {
     p.bias_ptr = SafeGetTensorPtr(bias);
-    p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
+    p.bias_strideB =
+        GetMemoryEfficientBiasStrideB(bias.get().dims(), q_dims, k_dims);
     p.bias_strideH = q_dims[1] * k_dims[1];
     p.bias_strideM = k_dims[1];
     VLOG(3) << "p.bias_ptr" << p.bias_ptr;
diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f3a70da1970e07203a88bfe9f3cd306d93bc08e
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_utils.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/ddim.h"
+
+namespace phi {
+namespace fusion {
+namespace cutlass_internal {
+
+inline int64_t GetMemoryEfficientBiasStrideB(const phi::DDim &bias_dims,
+                                             const phi::DDim &q_dims,
+                                             const phi::DDim &k_dims) {
+  int bias_dims_rank = bias_dims.size();
+  if (bias_dims_rank != 2) {
+    PADDLE_ENFORCE_EQ(bias_dims_rank,
+                      4,
+                      phi::errors::InvalidArgument(
+                          "The rank of attn_bias should be 2 or 4."));
+  }
+  PADDLE_ENFORCE_EQ(
+      bias_dims[bias_dims_rank - 1],
+      k_dims[1],
+      phi::errors::InvalidArgument("The last dim of attn_bias should be "
+                                   "equal to the sequence length of key."));
+  PADDLE_ENFORCE_EQ(bias_dims[bias_dims_rank - 2],
+                    q_dims[1],
+                    phi::errors::InvalidArgument(
+                        "The 2nd last dim of attn_bias should be equal to "
+                        "the sequence length of query."));
+
+  if (bias_dims_rank == 2) {
+    return 0;
+  }
+
+  if (bias_dims[0] == q_dims[0] && bias_dims[1] == q_dims[2]) {
+    return q_dims[2] * q_dims[1] * k_dims[1];
+  }
+
+  PADDLE_ENFORCE_EQ(
+      bias_dims[0],
+      1,
+      phi::errors::InvalidArgument(
+          "The first dim of attn_bias should be 1 or batch size."));
+  PADDLE_ENFORCE_EQ(
+      bias_dims[1],
+      1,
+      phi::errors::InvalidArgument(
+          "The second dim of attn_bias should be 1 or num_heads."));
+  return 0;
+}
+
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi