未验证 提交 d595928e 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Fused QKVBiasAdd and Transpose with Split Q, KV (#47680)

* fused qkvBiasAdd and transpose with split qkv

* fix typo

* fix format

* fix name

* add annotation

* fix comment
上级 e5408835
......@@ -258,6 +258,180 @@ class FMHARef {
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor,
const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 3, 1, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
T* softmax_out_data = softmax_out_tensor->data<T>();
T* dropout_out_data = dropout_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
auto out_seq_len = seq_len_;
if (cache_kv_tensor) {
// kv [2, bs, num_head, seq_len, head_dim]
phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat(dev_ctx_,
{*cache_kv_tensor, *kv_transpose_out_tensor},
3,
cache_kv_out_tensor);
out_seq_len = cache_kv_out_tensor->dims()[3];
}
int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
T* q_ptr = q_transpose_out_tensor->data<T>();
T* k_ptr = nullptr;
T* v_ptr = nullptr;
if (cache_kv_tensor) {
int64_t k_size = cache_kv_out_tensor->numel() / 2;
k_ptr = cache_kv_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
} else {
int64_t k_size = q_size;
k_ptr = kv_transpose_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
}
{
// NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
// float16 calculation, INF may appear in QK^T if we do not scale before.
float alpha = 1.0 / sqrt(head_dim_);
auto functor = phi::funcs::ScaleFunctor<T>(alpha);
std::vector<const phi::DenseTensor*> ins = {q_transpose_out_tensor};
std::vector<phi::DenseTensor*> outs = {q_transpose_out_tensor};
phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
}
// q*k^t, batched_gemm
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = out_seq_len;
int gemm_k = head_dim_;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
q_ptr,
k_ptr,
beta,
qk_out_data,
gemm_batch_size,
stride_a,
stride_b);
int softmax_axis = -1;
if (src_mask_tensor != nullptr) {
if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
src_mask_tensor->data<T>(),
softmax_out_data,
batch_size_,
num_head_,
seq_len_,
dev_ctx_.stream());
} else {
std::vector<const phi::DenseTensor*> ins;
std::vector<phi::DenseTensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
}
} else {
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
}
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = out_seq_len;
alpha = static_cast<T>(1.0);
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
dropout_param_.dropout_prob_,
dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_,
dropout_param_.seed_val_,
static_cast<const phi::DenseTensor&>(*softmax_out_tensor),
dropout_param_.seed_,
dropout_mask_out_tensor,
dropout_out_tensor,
false);
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
dropout_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
} else {
// softmax_out * v, batched_gemm
// output shape: [batch_size, num_heads, seq_len, head_dim]
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
softmax_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std::vector<int> perm_3 = {0, 2, 1, 3};
TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor& softmax_out_tensor,
......
......@@ -59,13 +59,16 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
// Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
// compute_bias as false.
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
bsz_seq,
output_size,
input_size,
compute_bias);
/*compute_bias=*/false);
Tensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
auto *qkv_out_data =
......@@ -110,10 +113,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
out_seq_len += cache_offset;
}
Tensor transpose_out_2, qk_out;
transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
auto *transpose_out_2_data =
dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));
Tensor q_transpose_out, kv_transpose_out, qk_out;
q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *q_transpose_out_data =
dev_ctx.Alloc<T>(&q_transpose_out, q_transpose_out.numel() * sizeof(T));
kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}});
auto *kv_transpose_out_data = dev_ctx.Alloc<T>(
&kv_transpose_out, kv_transpose_out.numel() * sizeof(T));
qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
......@@ -305,10 +313,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
Tensor *pre_cache_kv_out_tmp =
cache_offset > 0 ? &pre_cache_kv_out : nullptr;
Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr;
fmha_compute.ComputeForward(qkv_out,
qkv_bias_add_transpose_split<T>(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
bsz,
num_head,
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
pre_cache_kv_tensor,
src_mask,
&transpose_out_2,
&q_transpose_out,
&kv_transpose_out,
pre_cache_kv_out_tmp,
&qk_out,
src_mask_tmp,
......@@ -317,7 +336,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_out,
&qktv_out,
&fmha_out);
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
......@@ -329,11 +347,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
v_ptr = k_ptr + k_size;
} else {
// [3, bsz, num_head, seq_len, head_dim]
T *qkv_data = transpose_out_2_data;
int64_t q_size = bsz * seq_len * num_head * dim_head;
int64_t k_size = q_size;
const T *q_ptr = qkv_data;
k_ptr = q_ptr + q_size;
int64_t k_size = bsz * seq_len * num_head * dim_head;
const T *q_ptr = q_transpose_out_data;
k_ptr = kv_transpose_out_data;
v_ptr = k_ptr + k_size;
}
......@@ -358,10 +374,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dim_head);
} else { // not generation
// TODO(wangxi): can remove dropout in inference
fmha_compute.ComputeForward(qkv_out,
qkv_bias_add_transpose_split<T>(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
bsz,
num_head,
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
cache_kv,
src_mask,
&transpose_out_2,
&q_transpose_out,
&kv_transpose_out,
cache_kv_out,
&qk_out,
nullptr,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -1155,6 +1152,137 @@ void write_cache_kv(const phi::GPUContext &dev_ctx,
cache_v, v, num_head, dim_head, seq_len, max_seq_len);
}
template <typename T, int VecSize, bool ComputeBias>
__global__ void add_fusedQKV_bias_transpose_split_kernel(
T *q_buf,
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int32_t elem_cnt,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head) {
const int32_t offset = batch_size * seq_len * head_num * size_per_head;
const int32_t hidden_size = head_num * size_per_head;
const int32_t fused_hidden_size = 3 * hidden_size;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
LoadT bias_vec;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
phi::Load<T, VecSize>(&qkv[linear_index], &src_vec);
int32_t bias_idx = linear_index % fused_hidden_size;
if (ComputeBias) {
phi::Load<T, VecSize>(&qkv_bias[bias_idx], &bias_vec);
#pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
src_vec[unroll_idx] += bias_vec[unroll_idx];
}
}
const int32_t token_idx = linear_index / fused_hidden_size;
// const int32_t token_padded_idx = token_idx + (padding_offset == nullptr ?
// 0 : padding_offset[token_idx]);
const int32_t target_batch_id = token_idx / seq_len;
const int32_t seq_id = token_idx % seq_len;
// equal to:
// const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
const int32_t qkv_id = bias_idx / hidden_size;
const int32_t head_id = (linear_index % hidden_size) / size_per_head;
const int32_t size_id = linear_index % size_per_head;
if (qkv_id == 0) {
phi::Store<T, VecSize>(
src_vec,
&q_buf[target_batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head + seq_id * size_per_head +
size_id]);
} else {
const int32_t kv_store_offset = (qkv_id - 1) * offset;
phi::Store<T, VecSize>(
src_vec,
&kv_buf[kv_store_offset +
target_batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head + seq_id * size_per_head +
size_id]);
}
}
}
inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
constexpr int kBlockSize = 128;
constexpr int kNumWaves = 16;
const int device_id = paddle::platform::GetCurrentDeviceId();
const int sm_count = paddle::platform::GetGPUMultiProcessors(device_id);
const int max_thread_per_multiprocessor =
paddle::platform::GetGPUMultiProcessors(device_id);
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * max_thread_per_multiprocessor /
kBlockSize * kNumWaves));
return cudaSuccess;
}
template <typename T>
void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
T *q_buf,
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int batch_size,
const int head_num,
const int seq_len,
const int size_per_head,
bool compute_bias) {
const int32_t token_num = batch_size * seq_len;
const int32_t elem_cnt = token_num * head_num * size_per_head * 3;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(size_per_head % PackSize,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d",
size_per_head,
PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t blocksize = 128;
int32_t grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
if (compute_bias) {
add_fusedQKV_bias_transpose_split_kernel<T, PackSize, true>
<<<grid_size, blocksize, 0, dev_ctx.stream()>>>(q_buf,
kv_buf,
qkv,
qkv_bias,
elem_cnt,
batch_size,
seq_len,
token_num,
head_num,
size_per_head);
} else {
add_fusedQKV_bias_transpose_split_kernel<T, PackSize, false>
<<<grid_size, blocksize, 0, dev_ctx.stream()>>>(q_buf,
kv_buf,
qkv,
qkv_bias,
elem_cnt,
batch_size,
seq_len,
token_num,
head_num,
size_per_head);
}
}
} // namespace
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册