From d595928e00ff6de2c4a58677c33060de58f17cd2 Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Fri, 18 Nov 2022 19:38:25 +0800 Subject: [PATCH] Fused QKVBiasAdd and Transpose with Split Q, KV (#47680) * fused qkvBiasAdd and transpose with split qkv * fix typo * fix format * fix name * add annotation * fix comment --- paddle/fluid/operators/fused/fmha_ref.h | 174 ++++++++++++++++++ .../fused/fused_multi_transformer_op.cu | 97 ++++++---- .../fused/fused_multi_transformer_op.cu.h | 134 +++++++++++++- 3 files changed, 367 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 4854f81eae4..66176c9e754 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -258,6 +258,180 @@ class FMHARef { dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); } + void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor, + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor) { + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 3, 1, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + T* qk_out_data = qk_out_tensor->data(); + T* qktv_out_data = qktv_out_tensor->data(); + T* softmax_out_data = softmax_out_tensor->data(); + T* dropout_out_data = dropout_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + auto out_seq_len = seq_len_; + if (cache_kv_tensor) { + // kv [2, bs, num_head, seq_len, head_dim] + phi::funcs::ConcatFunctor concat; + // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] + concat(dev_ctx_, + {*cache_kv_tensor, *kv_transpose_out_tensor}, + 3, + cache_kv_out_tensor); + out_seq_len = cache_kv_out_tensor->dims()[3]; + } + + int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + T* q_ptr = q_transpose_out_tensor->data(); + T* k_ptr = nullptr; + T* v_ptr = nullptr; + + if (cache_kv_tensor) { + int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_ptr = cache_kv_out_tensor->data(); + v_ptr = k_ptr + k_size; + } else { + int64_t k_size = q_size; + k_ptr = kv_transpose_out_tensor->data(); + v_ptr = k_ptr + k_size; + } + + { + // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for + // float16 calculation, INF may appear in QK^T if we do not scale before. + float alpha = 1.0 / sqrt(head_dim_); + auto functor = phi::funcs::ScaleFunctor(alpha); + std::vector ins = {q_transpose_out_tensor}; + std::vector outs = {q_transpose_out_tensor}; + phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); + } + + // q*k^t, batched_gemm + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + auto blas = phi::funcs::GetBlas(dev_ctx_); + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = out_seq_len; + int gemm_k = head_dim_; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + q_ptr, + k_ptr, + beta, + qk_out_data, + gemm_batch_size, + stride_a, + stride_b); + int softmax_axis = -1; + if (src_mask_tensor != nullptr) { + if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { + LaunchFusedSoftmaxMaskKernel(qk_out_data, + src_mask_tensor->data(), + softmax_out_data, + batch_size_, + num_head_, + seq_len_, + dev_ctx_.stream()); + } else { + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; + phi::funcs::BroadcastKernel( + dev_ctx_, + ins, + &outs, + elewise_add_axis, + phi::funcs::AddFunctor()); + + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); + } + } else { + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); + } + + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = out_seq_len; + alpha = static_cast(1.0); + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + + if (dropout_param_.dropout_prob_) { + DropoutFwGPUKernelDriver( + static_cast(dev_ctx_), + dropout_param_.is_test_, + dropout_param_.dropout_prob_, + dropout_param_.is_upscale_in_train_, + dropout_param_.is_fix_seed_, + dropout_param_.seed_val_, + static_cast(*softmax_out_tensor), + dropout_param_.seed_, + dropout_mask_out_tensor, + dropout_out_tensor, + false); + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + dropout_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } else { + // softmax_out * v, batched_gemm + // output shape: [batch_size, num_heads, seq_len, head_dim] + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + softmax_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + std::vector perm_3 = {0, 2, 1, 3}; + TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + } + void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, const phi::DenseTensor* src_mask_tensor, const phi::DenseTensor& softmax_out_tensor, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 1274e247e69..f52bc2a7f54 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -59,13 +59,16 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set + // compute_bias as false. auto qkv_compute = AttnMatMul(dev_ctx, false, trans_qkvw, bsz_seq, output_size, input_size, - compute_bias); + /*compute_bias=*/false); + Tensor qkv_out; qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); auto *qkv_out_data = @@ -110,10 +113,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { out_seq_len += cache_offset; } - Tensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + Tensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); @@ -305,19 +313,29 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { Tensor *pre_cache_kv_out_tmp = cache_offset > 0 ? &pre_cache_kv_out : nullptr; Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr; - fmha_compute.ComputeForward(qkv_out, - pre_cache_kv_tensor, - src_mask, - &transpose_out_2, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + fmha_compute.ComputeForwardWithoutTranspose(qkv_out, + pre_cache_kv_tensor, + src_mask, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); const T *k_ptr = nullptr; const T *v_ptr = nullptr; @@ -329,11 +347,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { v_ptr = k_ptr + k_size; } else { // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - k_ptr = q_ptr + q_size; + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; v_ptr = k_ptr + k_size; } @@ -358,18 +374,29 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + fmha_compute.ComputeForwardWithoutTranspose(qkv_out, + cache_kv, + src_mask, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index e0795616fd9..e6f4461f0c1 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -1,12 +1,9 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -1155,6 +1152,137 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, cache_v, v, num_head, dim_head, seq_len, max_seq_len); } +template +__global__ void add_fusedQKV_bias_transpose_split_kernel( + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = batch_size * seq_len * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + if (ComputeBias) { + phi::Load(&qkv_bias[bias_idx], &bias_vec); +#pragma unroll + for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + src_vec[unroll_idx] += bias_vec[unroll_idx]; + } + } + const int32_t token_idx = linear_index / fused_hidden_size; + // const int32_t token_padded_idx = token_idx + (padding_offset == nullptr ? + // 0 : padding_offset[token_idx]); + const int32_t target_batch_id = token_idx / seq_len; + const int32_t seq_id = token_idx % seq_len; + + // equal to: + // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size; + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { + phi::Store( + src_vec, + &q_buf[target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } else { + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Store( + src_vec, + &kv_buf[kv_store_offset + + target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } + } +} + +inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { + constexpr int kBlockSize = 128; + constexpr int kNumWaves = 16; + + const int device_id = paddle::platform::GetCurrentDeviceId(); + const int sm_count = paddle::platform::GetGPUMultiProcessors(device_id); + const int max_thread_per_multiprocessor = + paddle::platform::GetGPUMultiProcessors(device_id); + + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * max_thread_per_multiprocessor / + kBlockSize * kNumWaves)); + return cudaSuccess; +} + +template +void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head, + bool compute_bias) { + const int32_t token_num = batch_size * seq_len; + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + if (compute_bias) { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } else { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } +} + } // namespace } // namespace operators -- GitLab