Unverified commit 1149a378, authored by zhouweiwei2014, committed by GitHub

[Sparse] optimize sparse attention (#44743)

Parent: c28bb981
@@ -56,6 +56,7 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
...
@@ -68,6 +68,7 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
...
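Review note: the newly wrapped cusparseSpMM_preprocess is cuSPARSE's optional one-time preprocessing step for SpMM, intended to speed up subsequent cusparseSpMM calls on the same operands. For reference, a minimal host-side sketch of the usual call order outside Paddle (assumes the handle and the matA/matB/matC descriptors are already created; error checking omitted):

#include <cuda_runtime.h>
#include <cusparse.h>

// Hedged sketch, not Paddle code: bufferSize -> preprocess -> SpMM.
void SpmmWithPreprocess(cusparseHandle_t handle,
                        cusparseSpMatDescr_t matA,   // sparse CSR operand
                        cusparseDnMatDescr_t matB,   // dense operand
                        cusparseDnMatDescr_t matC,   // dense result
                        const float* alpha,
                        const float* beta) {
  size_t buffer_size = 0;
  cusparseSpMM_bufferSize(handle,
                          CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE,
                          alpha, matA, matB, beta, matC,
                          CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, &buffer_size);
  void* dbuffer = nullptr;
  cudaMalloc(&dbuffer, buffer_size);
  // One-time analysis of the sparsity pattern; later cusparseSpMM calls with
  // the same descriptors and buffer can reuse its result.
  cusparseSpMM_preprocess(handle,
                          CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE,
                          alpha, matA, matB, beta, matC,
                          CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dbuffer);
  cusparseSpMM(handle,
               CUSPARSE_OPERATION_NON_TRANSPOSE,
               CUSPARSE_OPERATION_NON_TRANSPOSE,
               alpha, matA, matB, beta, matC,
               CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dbuffer);
  cudaFree(dbuffer);
}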
@@ -48,6 +48,15 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
   }
 }
 
+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCsrTensor& x) {
+  // TODO(zhouwei): will change to 'CUSPARSE_SPMM_CSR_ALG2' when support batch
+  return CUSPARSE_SPMM_CSR_ALG2;
+}
+
+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCooTensor& x) {
+  return CUSPARSE_SPMM_ALG_DEFAULT;
+}
+
 /************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/
 template <typename T, typename IntT>
@@ -324,7 +333,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
         &beta,
         out_descriptor.descriptor(),
         gpu_type,
-        CUSPARSE_SPMM_ALG_DEFAULT,
+        GetSpMMAlgorithm(mat_a),
         &buffer_size);
  });
@@ -341,7 +350,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
        &beta,
        out_descriptor.descriptor(),
        gpu_type,
-       CUSPARSE_SPMM_ALG_DEFAULT,
+       GetSpMMAlgorithm(mat_a),
        tmp_buffer_ptr);
  });
}
...
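Review note: the two GetSpMMAlgorithm overloads select the algorithm from the static type of the sparse operand, so the CSR path can use the CSR-specific CUSPARSE_SPMM_CSR_ALG2 while COO stays on the format-agnostic CUSPARSE_SPMM_ALG_DEFAULT. A stripped-down illustration of the same dispatch pattern, with hypothetical tag types standing in for Paddle's tensor classes:

#include <cusparse.h>

// Hedged sketch: overload resolution on the operand type picks the cuSPARSE
// SpMM algorithm at compile time, with no runtime branching.
struct CsrOperand {};   // stand-in for SparseCsrTensor
struct CooOperand {};   // stand-in for SparseCooTensor

inline cusparseSpMMAlg_t PickSpMMAlgorithm(const CsrOperand&) {
  return CUSPARSE_SPMM_CSR_ALG2;      // CSR-specific algorithm
}
inline cusparseSpMMAlg_t PickSpMMAlgorithm(const CooOperand&) {
  return CUSPARSE_SPMM_ALG_DEFAULT;   // works with any sparse format
}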
@@ -43,21 +43,14 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
   int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
   if (row_nnz == 0) return;
 
-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-
-  T mul_result = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
+  T mul = 0;
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    mul += out_values[row_first + idx] * dout_values[row_first + idx];
   }
-  T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
+  T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);
 
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
                                  out_values[row_first + idx] / scale;
   }
 }
@@ -96,8 +89,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
   int N = q_dim[q_rank - 1];
   int batch_nnz = softmax.nnz() / batch_num;
 
-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);
 
   AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
       softmax.non_zero_crows().data<int64_t>(),
...
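Review note: both loops now stride by blockDim.x over the row's non-zeros instead of unrolling a fixed kIteration count with an early break, so the kernel no longer needs a compile-time buffer size. The arithmetic is unchanged; per non-zero row it still implements the standard softmax backward identity (written here for reference, with y the softmax output restricted to the row's non-zeros and s the scale parameter):

\frac{\partial L}{\partial x_i}
  = \frac{y_i}{s}\left(\frac{\partial L}{\partial y_i}
      - \sum_{j \in \mathrm{nnz(row)}} \frac{\partial L}{\partial y_j}\, y_j\right),
\qquad y = \operatorname{softmax}\!\left(\frac{x}{s}\right),

which is exactly dx = (dout - mul_sum) * out / scale, with mul_sum the warp-reduced dot product of out and dout.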
@@ -26,30 +26,7 @@ limitations under the License. */
 namespace phi {
 namespace sparse {
 
-#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
-  case size: {                                                 \
-    constexpr int HINT = size;                                 \
-    __VA_ARGS__();                                             \
-    break;                                                     \
-  }
-
-#define VISIT_ATTN_SFOTMAX(SIZE, NAME, ...)                                 \
-  [&] {                                                                     \
-    const auto& __size__ = SIZE;                                            \
-    switch (__size__) {                                                     \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__)   \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__)   \
-      default:                                                              \
-        PD_THROW("function " #NAME " is not implemented for columns>512 "); \
-    }                                                                       \
-  }()
-
-template <typename T, int BufferSize>
+template <typename T>
 __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                      const int64_t* x_cols,
                                      const T* x_values,
@@ -58,7 +35,6 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                      T* out_values,
                                      int M,
                                      int total_row_num,
-                                     float scale,
                                      int num_heads,
                                      int batch_nnz) {
   // out = exp(x-x_max) / sum(exp(x-x_max))
@@ -72,17 +48,10 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
   int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
   if (row_nnz == 0) return;
 
-  T buffer[BufferSize] = {0};
-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-
   T max_val = -std::numeric_limits<T>::infinity();
-  for (int i = 0; i < kIteration; ++i) {
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     bool mask = false;
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
     int col_idx = static_cast<int>(x_cols[row_first + idx]);
 
     if (kp_mask != nullptr &&
         kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
       mask = true;
@@ -92,37 +61,30 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
     }
 
     if (!mask) {
-      buffer[i] = x_values[row_first + idx] / scale;
-      if (buffer[i] > max_val) {
-        max_val = buffer[i];
+      T val = x_values[row_first + idx];
+      if (val > max_val) {
+        max_val = val;
       }
+      out_values[row_first + idx] = val;
+    } else {
+      // Note corner case: when all elements of the row are masked, result
+      // may be wrong because of exp('-inf' - '-inf'), just ignore now.
+      out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
     }
   }
   T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
 
-  auto functor = phi::funcs::CudaExpFunctor<T>();
   T exp_sum = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      T exp = functor(buffer[i] - row_max_val);
-      exp_sum += exp;
-      buffer[i] = exp;
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    auto functor = phi::funcs::CudaExpFunctor<T>();
+    T exp = functor(out_values[row_first + idx] - row_max_val);
+    exp_sum += exp;
+    out_values[row_first + idx] = exp;
   }
   T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
 
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      out_values[row_first + idx] = buffer[i] / row_exp_sum;
-    } else {
-      out_values[row_first + idx] = static_cast<T>(0);
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
   }
 }
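Review note: the rewritten kernel streams each CSR row in three warp-stride passes (masked max, exponentiate-and-sum, normalize) and keeps intermediates in out_values, which is what makes the per-thread buffer[BufferSize], the BufferSize template parameter, and the VISIT_ATTN_SFOTMAX dispatch above unnecessary. A self-contained sketch of the same one-warp-per-row pattern, simplified to float with no kp_mask/attn_mask handling and with hypothetical helper names:

#include <cstdint>
#include <cuda_runtime.h>
#include <math_constants.h>

// Hedged sketch, not the Paddle kernel: one 32-thread warp handles one CSR
// row; intermediates live in the output array rather than a register buffer.
__device__ float WarpAllReduceMax(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset));
  return v;
}

__device__ float WarpAllReduceSum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xFFFFFFFF, v, offset);
  return v;
}

__global__ void CsrRowSoftmax(const int64_t* crows, const float* x,
                              float* out, int num_rows) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;  // one warp per row
  if (row >= num_rows) return;
  int first = static_cast<int>(crows[row]);
  int nnz = static_cast<int>(crows[row + 1]) - first;
  if (nnz == 0) return;

  // Pass 1: row maximum via a warp-stride loop (no per-thread buffer).
  float max_val = -CUDART_INF_F;
  for (int i = threadIdx.x; i < nnz; i += blockDim.x)
    max_val = fmaxf(max_val, x[first + i]);
  max_val = WarpAllReduceMax(max_val);

  // Pass 2: exponentiate into `out` and accumulate the row sum.
  float sum = 0.f;
  for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
    float e = expf(x[first + i] - max_val);
    out[first + i] = e;
    sum += e;
  }
  sum = WarpAllReduceSum(sum);

  // Pass 3: normalize in place.
  for (int i = threadIdx.x; i < nnz; i += blockDim.x)
    out[first + i] /= sum;
}

With the launch shape used in this PR, this sketch would be invoked as CsrRowSoftmax<<<dim3((num_rows + 7) / 8), dim3(32, 8)>>>(crows, x, out, num_rows), i.e. eight warps per block and one row per warp.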
@@ -219,35 +181,25 @@ void FusedAttentionCsrKernel(
             "shape of 'attn_mask' must be [seq_len, seq_len]"));
   }
 
-  /* Step1: SDD Matmul, reuse */
+  /* Step1: SDD Matmul, reuse matmul */
   SparseCsrTensor sdd_result;
   EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
   sparse_blas.SDDMM(false,
                     true,
-                    static_cast<T>(1),
+                    static_cast<T>(1 / std::sqrt(N)),
                     query,
                     key,
                     static_cast<T>(0),
                     &sdd_result);
 
-  /* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */
   EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);
 
-  int buffer_size;
-  if (M < 128) {
-    buffer_size = (M + 32 - 1) / 32;
-  } else {
-    buffer_size = ((M + 128 - 1) / 128) * 4;
-  }
-
-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);
 
   int batch_nnz = sdd_result.nnz() / batch_num;
-  VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
-    AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
+  AttnSoftmaxGpuKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
       sdd_result.non_zero_crows().data<int64_t>(),
       sdd_result.non_zero_cols().data<int64_t>(),
      sdd_result.non_zero_elements().data<T>(),
@@ -256,12 +208,9 @@ void FusedAttentionCsrKernel(
      softmax->mutable_non_zero_elements()->data<T>(),
      M,
      total_row_num,
-      std::sqrt(N),
      q_dim[1],
      batch_nnz);
-  });
 
-  /* Step3: DSD Matmul, reuse */
   softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
   MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
 #else
...
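Review note: the attention scale is now folded into the SDDMM alpha instead of being applied inside the softmax kernel. SDDMM (cusparseSDDMM under the hood) computes the dense-dense product sampled at the output's sparsity pattern and scaled by alpha, so passing alpha = 1/sqrt(N) scales Q K^T up front and the softmax kernel can drop its scale argument:

\operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{N}}\right)
  = \operatorname{softmax}\!\left(\alpha\, QK^{\top}\right)
  \quad\text{with}\quad \alpha = \frac{1}{\sqrt{N}},

where N is the head dimension (q_dim[q_rank - 1]) and the softmax runs row-wise over the kept non-zeros. This is also why std::sqrt(N) disappears from the kernel launch arguments above.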
@@ -37,7 +37,7 @@ def get_cuda_version():
 @unittest.skipIf(
     not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.7"
 )
 class TestSparseAttentionAPI1(unittest.TestCase):
...