From 8c0bacd45eb385941f6d8663a500aa6c46b0a038 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:10:03 +0800 Subject: [PATCH] Add fused_attention_op: add impl wrappers. (#35903) (#36673) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 功能:本PR的目标是提高attention模块的计算性能。 为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op; 为了减少防存开销,本PR采取了两种优化方法: (1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次; (2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据; --- .../elementwise/elementwise_op_impl.cu.h | 3 +- .../operators/fused/attention_layer_norm.h | 2 +- .../fluid/operators/fused/attn_bias_add.cu.h | 6 +- paddle/fluid/operators/fused/attn_gemm.h | 159 +++++++++ paddle/fluid/operators/fused/fmha_ref.h | 324 ++++++++++++++++++ paddle/fluid/operators/layer_norm_kernel.cu.h | 1 - 6 files changed, 487 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/operators/fused/attn_gemm.h create mode 100644 paddle/fluid/operators/fused/fmha_ref.h diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 81dff94730..e4074cc7d7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller { template struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, InT **args, OutT *result) { + __device__ inline void operator()(Functor func, InT (*args)[VecSize], + OutT *result) { kps::ElementwiseTernary( result, args[0], args[1], args[2], func); } diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h index d234a0f085..43491a9faf 100644 --- a/paddle/fluid/operators/fused/attention_layer_norm.h +++ b/paddle/fluid/operators/fused/attention_layer_norm.h @@ -50,7 +50,7 @@ class AttnLayerNorm { } } - void ComputeBackward(const T* x_data, const T* y_data, + void ComputeBackward(const T* x_data, const T* d_y_data, const LayerNormParamType* scale_data, const LayerNormParamType* mean_data, const LayerNormParamType* var_data, T* d_x_data, diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 27b903ff62..18ae932c93 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #define LAUNCH_BOUNDS(BlockDim) #endif +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" @@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType; template using ReduceParamType = typename CudnnDataType::BatchNormParamType; -template -struct AddFunctor { - inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; } -}; - template __global__ void BroadcastKernelBinary( diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h new file mode 100644 index 0000000000..a2001d0a81 --- /dev/null +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/attn_bias_add.cu.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +// support gemm-nt and gemm-nn, which is used in fused_attention_op. +template +class AttnMatMul { + public: + // (m, n, k) = bsz_seq, output_size, input_size + AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA, + bool transB, int bsz_seq, int output_size, int input_size, + bool compute_bias) + : dev_ctx_(dev_ctx), + transA_(transA), + transB_(transB), + bsz_seq_(bsz_seq), + output_size_(output_size), + input_size_(input_size), + compute_bias_(compute_bias) {} + + ~AttnMatMul() {} + + void ComputeForward(const T* weight_data, const T* input_data, + const T* bias_data, T* output_data, T* bias_out_data) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // here: (transa, transb): nt, input * weight. + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + if (transA_) { + transA = CblasTrans; + } + if (transB_) { + transB = CblasTrans; + } + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = math::GetBlas(dev_ctx_); + blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, + input_data, weight_data, beta, output_data); + if (compute_bias_) { + // compute output + bias + LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data, + bias_data, bias_out_data); + } + } + + void ComputeBackward(const T* input, const T* weight, const T* d_output, + T* d_input, T* d_weight, T* d_bias) { + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + auto blas = math::GetBlas(dev_ctx_); + + CBLAS_TRANSPOSE dB_transA = CblasNoTrans; + CBLAS_TRANSPOSE dB_transB = CblasNoTrans; + CBLAS_TRANSPOSE dA_transA = CblasNoTrans; + CBLAS_TRANSPOSE dA_transB = CblasNoTrans; + int dB_m = 1; + int dB_n = 1; + int dB_k = 1; + int dA_m = 1; + int dA_n = 1; + int dA_k = 1; + + T* dB_input_1_ptr = nullptr; + T* dB_input_2_ptr = nullptr; + T* dB_output_ptr = d_weight; + + T* dA_input_1_ptr = nullptr; + T* dA_input_2_ptr = nullptr; + T* dA_output_ptr = d_input; + + if (!transA_) { + // fw: gemm-nt + if (transB_) { + // bw: gemm-tn, dB = (dC)^t * A + dB_transA = CblasTrans; + dB_transB = CblasNoTrans; + dB_m = output_size_; + dB_n = input_size_; + dB_k = bsz_seq_; + + // bw: gemm-nn, dA = dC * B + dA_transA = CblasNoTrans; + dA_transB = CblasNoTrans; + dA_m = bsz_seq_; + dA_n = input_size_; + dA_k = output_size_; + + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output, + input, beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, + weight, beta, dA_output_ptr); + } else { // fw: gemm-nn + // bw: gemm-tn, dB = A^t * dC + dB_transA = CblasTrans; + dB_transB = CblasNoTrans; + dB_m = input_size_; + dB_n = output_size_; + dB_k = bsz_seq_; + + // bw: gemm-nt, dA = dC * B^t + dA_transA = CblasNoTrans; + dA_transB = CblasTrans; + dA_m = bsz_seq_; + dA_n = input_size_; + dA_k = output_size_; + + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input, + d_output, beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, + weight, beta, dA_output_ptr); + } + } else if (transB_) { + PADDLE_THROW(platform::errors::InvalidArgument( + "AttnMatMul wrapper do not support (transA=T, transB=T)" + "parameters.")); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "AttnMatMul wrapper do not support (transA=T, transB=N)" + "parameters.")); + } + if (compute_bias_) { + LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias); + } + } + + private: + const platform::CUDADeviceContext& dev_ctx_; + + bool transA_; + bool transB_; + + int bsz_seq_; + int output_size_; + int input_size_; + + int compute_bias_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h new file mode 100644 index 0000000000..bef0052a00 --- /dev/null +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -0,0 +1,324 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" +#include "paddle/fluid/operators/transpose_op.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class AttnDropoutParam { + public: + AttnDropoutParam() { + is_test_ = false; + dropout_implementation_ = "downgrade_in_infer"; + dropout_prob_ = 0.5; + is_upscale_in_train_ = false; + is_fix_seed_ = false; + seed_val_ = 0; + seed_ = nullptr; + } + AttnDropoutParam(bool is_test, const std::string dropout_implementation, + float dropout_prob, bool is_upscale_in_train, + bool is_fix_seed, int seed_val, const Tensor* seed) { + is_test_ = is_test; + dropout_implementation_ = dropout_implementation; + dropout_prob_ = dropout_prob; + is_upscale_in_train_ = is_upscale_in_train; + is_fix_seed_ = is_fix_seed; + seed_val_ = seed_val; + seed_ = seed; + } + bool is_test_; + std::string dropout_implementation_; + float dropout_prob_; + bool is_upscale_in_train_; + bool is_fix_seed_; + int seed_val_; + const Tensor* seed_; +}; + +template +class FMHARef { + public: + FMHARef(const platform::CUDADeviceContext& dev_ctx, int64_t batch_size, + int64_t seq_len, int64_t num_head, int64_t head_dim, + AttnDropoutParam param) + : dev_ctx_(dev_ctx), + batch_size_(batch_size), + seq_len_(seq_len), + num_head_(num_head), + head_dim_(head_dim), + dropout_param_(param) {} + + ~FMHARef() {} + + void ComputeForward(const Tensor& qkv_input_tensor, + const Tensor& src_mask_tensor, + Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor, + Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor, + Tensor* dropout_mask_out_tensor, + Tensor* dropout_out_tensor, Tensor* qktv_out_tensor, + Tensor* fmha_out_tensor) { + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 1, 3, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + int ndims = 5; + std::vector perm_1 = {2, 0, 3, 1, 4}; + TransposeGPUKernelDriver(dev_ctx_, ndims, qkv_input_tensor, perm_1, + transpose_2_out_tensor); + + T* qkv_data = transpose_2_out_tensor->data(); + T* qk_out_data = qk_out_tensor->data(); + T* qktv_out_data = qktv_out_tensor->data(); + T* softmax_out_data = softmax_out_tensor->data(); + T* dropout_out_data = dropout_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + int k_size = q_size; + T* q_ptr = qkv_data; + T* k_ptr = q_ptr + q_size; + T* v_ptr = k_ptr + k_size; + + // q*k^t, batched_gemm + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + auto blas = math::GetBlas(dev_ctx_); + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = seq_len_; + int gemm_k = head_dim_; + T alpha = static_cast(1.0 / sqrt(head_dim_)); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr, + k_ptr, beta, qk_out_data, gemm_batch_size, stride_a, + stride_b); + + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(&src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; + int softmax_axis = -1; + if (&src_mask_tensor != nullptr) { + LaunchElementwiseCudaKernel( + dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor()); + SoftmaxForwardCUDAKernelDriver(dev_ctx_, *src_mask_out_tensor, + softmax_axis, softmax_out_tensor); + } else { + SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out_tensor, softmax_axis, + softmax_out_tensor); + } + + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + alpha = static_cast(1.0); + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + + if (dropout_param_.dropout_prob_) { + DropoutFwGPUKernelDriver( + dev_ctx_, dropout_param_.is_test_, + static_cast( + dropout_param_.dropout_implementation_), + dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, + dropout_param_.is_fix_seed_, dropout_param_.seed_val_, + static_cast(*softmax_out_tensor), dropout_param_.seed_, + dropout_mask_out_tensor, dropout_out_tensor); + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + dropout_out_data, v_ptr, beta, qktv_out_data, + gemm_batch_size, stride_a, stride_b); + } else { + // softmax_out * v, batched_gemm + // output shape: [batch_size, num_heads, seq_len, head_dim] + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + softmax_out_data, v_ptr, beta, qktv_out_data, + gemm_batch_size, stride_a, stride_b); + } + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + std::vector perm_3 = {0, 2, 1, 3}; + ndims = 4; + TransposeGPUKernelDriver(dev_ctx_, ndims, *qktv_out_tensor, perm_3, + fmha_out_tensor); + } + + void ComputeBackward( + const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor, + const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor, + const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor, + const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor, + Tensor* qktv_out_grad_tensor, Tensor* dropout_out_grad_tensor, + Tensor* softmax_out_grad_tensor, Tensor* src_mask_out_grad_tensor, + Tensor* qk_out_grad_tensor, Tensor* transpose_2_out_grad_tensor, + Tensor* src_mask_grad_tensor, Tensor* qkv_input_grad_tensor) { + auto blas = math::GetBlas(dev_ctx_); + int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + int k_size = q_size; + int softmax_axis = -1; + + T* qkv_grad_data = transpose_2_out_grad_tensor->data(); + T* q_grad_ptr = qkv_grad_data; + T* k_grad_ptr = q_grad_ptr + q_size; + T* v_grad_ptr = k_grad_ptr + k_size; + const T* qkv_data = transpose_2_out_tensor.data(); + const T* q_ptr = qkv_data; + const T* k_ptr = q_ptr + q_size; + const T* v_ptr = k_ptr + k_size; + + const T* softmax_out_data = softmax_out_tensor.data(); + T* softmax_out_grad_data = softmax_out_grad_tensor->data(); + const T* dropout_out_data = dropout_out_tensor.data(); + T* dropout_out_grad_data = dropout_out_grad_tensor->data(); + T* qktv_out_grad_data = qktv_out_grad_tensor->data(); + + // transpose bw + int ndims = 4; + std::vector perm_3 = {0, 2, 1, 3}; + TransposeGPUKernelDriver(dev_ctx_, ndims, fmha_out_grad_tensor, perm_3, + qktv_out_grad_tensor); + + // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) = + // qktv_out_data(out) + CBLAS_TRANSPOSE transA = CblasTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = head_dim_; + int gemm_k = seq_len_; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + // bw: dy = x^t * dout + if (dropout_param_.dropout_prob_) { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + dropout_out_data, qktv_out_grad_data, beta, v_grad_ptr, + gemm_batch_size, stride_a, stride_b); + } else { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + softmax_out_data, qktv_out_grad_data, beta, v_grad_ptr, + gemm_batch_size, stride_a, stride_b); + } + // bw: dx = dout * y^t + transA = CblasNoTrans; + transB = CblasTrans; + gemm_m = seq_len_; + gemm_n = seq_len_; + gemm_k = head_dim_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + if (dropout_param_.dropout_prob_) { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qktv_out_grad_data, v_ptr, beta, dropout_out_grad_data, + gemm_batch_size, stride_a, stride_b); + } else { + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qktv_out_grad_data, v_ptr, beta, softmax_out_grad_data, + gemm_batch_size, stride_a, stride_b); + } + // dropout bw + if (dropout_param_.dropout_prob_) { + DropoutGradGPUKernelDriver( + dev_ctx_, static_cast( + dropout_param_.dropout_implementation_), + dropout_param_.dropout_prob_, + static_cast(*dropout_out_grad_tensor), + dropout_mask_out_tensor, softmax_out_grad_tensor->numel(), + softmax_out_grad_tensor); + } + + if (&src_mask_tensor != nullptr) { + SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, + *softmax_out_grad_tensor, softmax_axis, + src_mask_out_grad_tensor); + + // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + + // src_mask + // Special case when dy is not needed and dx doesn't reduce + if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr && + qk_out_tensor.dims() == src_mask_out_tensor.dims()) { + VLOG(4) << "Special case when dy is not needed and dx doesn't " + "reduce"; + framework::TensorCopy(*src_mask_out_grad_tensor, dev_ctx_.GetPlace(), + dev_ctx_, qk_out_grad_tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only used for the backward elementwise_add op when" + "dy is not needed and dx is not reduce")); + return; + } + + } else { + SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, + *softmax_out_grad_tensor, softmax_axis, + qk_out_grad_tensor); + } + + T* qk_out_grad_data = qk_out_grad_tensor->data(); + alpha = static_cast(1.0 / sqrt(head_dim_)); + // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out + // bw: dy (seq_len * head_dim) = (dout)^t * x + transA = CblasTrans; + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size, + stride_a, stride_b); + // dx (seq_len * head_dim) = dout * y + transA = CblasNoTrans; + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = seq_len_; + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, + qk_out_grad_data, k_ptr, beta, q_grad_ptr, gemm_batch_size, + stride_a, stride_b); + + // transpose bw + ndims = 5; + std::vector perm_1 = {1, 3, 0, 2, 4}; + TransposeGPUKernelDriver(dev_ctx_, ndims, *transpose_2_out_grad_tensor, + perm_1, qkv_input_grad_tensor); + } + + private: + const platform::CUDADeviceContext& dev_ctx_; + + int64_t batch_size_; + int64_t seq_len_; + int64_t num_head_; + int64_t head_dim_; + + AttnDropoutParam dropout_param_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 06c1eaf881..4280c86ca9 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -35,7 +35,6 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template -- GitLab