From 3d7e211898d89eb5ec66275d998f1affd20b3bab Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Sun, 18 Sep 2022 19:15:48 +0800 Subject: [PATCH] Add INT8 support for fused_multi_transformer_op (#45284) --- .../ir_params_sync_among_devices_pass.cc | 3 +- paddle/fluid/operators/fused/CMakeLists.txt | 2 + .../operators/fused/attention_layer_norm.h | 30 +- paddle/fluid/operators/fused/attn_gemm_int8.h | 189 +++ paddle/fluid/operators/fused/cublaslt.h | 211 +++ .../operators/fused/fused_dropout_act_bias.h | 89 +- .../operators/fused/fused_dropout_common.h | 1 + .../operators/fused/fused_dropout_helper.h | 171 ++- .../fused_layernorm_residual_dropout_bias.h | 285 ++-- .../fused/fused_multi_transformer_int8_op.cc | 369 ++++++ .../fused/fused_multi_transformer_int8_op.cu | 670 ++++++++++ .../fused/fused_multi_transformer_op.cu | 1152 +--------------- .../fused/fused_multi_transformer_op.h | 1161 +++++++++++++++++ .../fused/fused_residual_dropout_bias.h | 158 ++- .../operators/fused/quant_dequant_kernel.h | 136 ++ paddle/fluid/operators/layer_norm_kernel.cu.h | 77 +- paddle/fluid/platform/dynload/cublasLt.h | 42 +- paddle/fluid/pybind/op_function_generator.h | 8 + paddle/phi/backends/dynload/cublasLt.h | 42 +- paddle/phi/backends/dynload/dynamic_loader.cc | 2 +- .../fluid/tests/unittests/CMakeLists.txt | 6 + .../test_fused_multi_transformer_int8_op.py | 792 +++++++++++ 22 files changed, 4168 insertions(+), 1428 deletions(-) create mode 100644 paddle/fluid/operators/fused/attn_gemm_int8.h create mode 100644 paddle/fluid/operators/fused/cublaslt.h create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_op.h create mode 100644 paddle/fluid/operators/fused/quant_dequant_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index 168b99f3d7..7f63eeaad2 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { auto var_data_type = var_node->Var()->GetDataType(); VLOG(5) << "var_name is " << var_name << ", data type is " << var_data_type; - if (var_data_type == paddle::framework::proto::VarType::FP16) { + if (var_data_type == paddle::framework::proto::VarType::FP16 && + t->dtype() != paddle::experimental::DataType::FLOAT16) { framework::Tensor half_tensor; half_tensor.set_type(paddle::experimental::DataType::FLOAT16); half_tensor.Resize(t->dims()); diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index f2a4b68765..9a14d35b59 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -23,6 +23,7 @@ register_operators( fused_transformer_op fused_feedforward_op fused_multi_transformer_op + fused_multi_transformer_int8_op fused_bias_dropout_residual_layer_norm_op resnet_unit_op fused_gemm_epilogue_op @@ -119,6 +120,7 @@ if(WITH_GPU OR WITH_ROCM) # fused_attention_op op_library(fused_attention_op) op_library(fused_multi_transformer_op) + op_library(fused_multi_transformer_int8_op) op_library(fused_bias_dropout_residual_layer_norm_op) endif() # resnet_unit needs cudnn 8.0 above diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h index baed3ca7a1..e54bca8a89 100644 --- a/paddle/fluid/operators/fused/attention_layer_norm.h +++ b/paddle/fluid/operators/fused/attention_layer_norm.h @@ -19,7 +19,8 @@ limitations under the License. */ namespace paddle { namespace operators { -template +// NOTE: T must be the same as OutType in ComputeBackward +template class AttnLayerNorm { public: AttnLayerNorm(const phi::GPUContext& dev_ctx, @@ -33,17 +34,28 @@ class AttnLayerNorm { ~AttnLayerNorm() {} - void ComputeForward(const T* x_data, + void ComputeForward(const InType* x_data, const LayerNormParamType* scale_data, const LayerNormParamType* bias_data, - T* y_data, + OutType* y_data, LayerNormParamType* mean_data, - LayerNormParamType* var_data) { + LayerNormParamType* var_data, + const float* dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { auto stream = dev_ctx_.stream(); switch (GetDesiredBlockDim(feature_size_)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward, kBlockDim> + LayerNormForward, + kBlockDim, + false, + InType, + OutType> <<>>(x_data, scale_data, bias_data, @@ -51,7 +63,13 @@ class AttnLayerNorm { mean_data, var_data, epsilon_, - feature_size_)); + feature_size_, + dequant_out_scale_data, + quant_out_scale_offset, + quant_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound)); default: PADDLE_THROW(platform::errors::InvalidArgument( "Feature_size must be larger than 1")); diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h new file mode 100644 index 0000000000..ba114df908 --- /dev/null +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -0,0 +1,189 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/operators/fused/cublaslt.h" +#include "paddle/fluid/operators/fused/quant_dequant_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class AttnMatmulINT8 { + public: + AttnMatmulINT8( + const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) + : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { + auto helper = std::make_shared(m, k, n); + helpers_.emplace_back(helper); + } + ~AttnMatmulINT8() {} + + // This function is used to execute GEMM, with input and output's types are + // both T. + void ComputeForward(const framework::Tensor* weight, + const framework::Tensor* input, + framework::Tensor* input_tmp, + const framework::Tensor* bias, + framework::Tensor* output, + framework::Tensor* output_tmp, + framework::Tensor* bias_out, + const float quant_in_scale, + const framework::Tensor* dequant_out_scale, + const int quant_out_scale_offset, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { + quantize_kernel_launcher(input->data(), + input_tmp->data(), + quant_in_scale, + m_, + k_, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx_.stream()); + + helpers_[0]->GEMM(input_tmp->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream()); + + dequantize_kernel_launcher(output_tmp->data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + quant_in_scale, + dequant_out_scale->data(), + quant_out_scale_offset); + + if (compute_bias_) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + PADDLE_ENFORCE_EQ(cudaGetLastError(), + cudaSuccess, + platform::errors::Fatal( + "cuda error occured after computing bias. " + "But it does not mean this error is caused by " + "bias computing")); + } + } + + // This function is used to execute GEMM, with input and output's types are + // both INT8. + void ComputeForwardINT8ToINT8(const framework::Tensor* weight, + framework::Tensor* input, + const framework::Tensor* bias, + framework::Tensor* output, + framework::Tensor* bias_out) { + helpers_[0]->GEMM(input->data(), + weight->data(), + output->data(), + dev_ctx_.stream()); + } + + // This function is used to execute GEMM, with input and output's types are + // INT8 and T. + void ComputeForwardINT8ToT(const framework::Tensor* weight, + const float quant_in_scale, + framework::Tensor* input, + const framework::Tensor* bias, + framework::Tensor* output, + framework::Tensor* output_tmp, + framework::Tensor* bias_out, + const framework::Tensor* dequant_out_scale, + const int quant_out_scale_offset) { + helpers_[0]->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream()); + + dequantize_kernel_launcher(output_tmp->data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + quant_in_scale, + dequant_out_scale->data(), + quant_out_scale_offset); + + if (compute_bias_) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + PADDLE_ENFORCE_EQ(cudaGetLastError(), + cudaSuccess, + platform::errors::Fatal( + "cuda error occured after computing bias. " + "But it does not mean this error is caused by " + "bias computing")); + } + } + + // This function is used to execute GEMM, with input and output's types are T + // and INT8. + void ComputeForwardTToINT8(const framework::Tensor* weight, + const float quant_in_scale, + const framework::Tensor* input, + framework::Tensor* input_tmp, + const framework::Tensor* bias, + framework::Tensor* output, + framework::Tensor* bias_out, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { + quantize_kernel_launcher(input->data(), + input_tmp->data(), + quant_in_scale, + m_, + k_, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx_.stream()); + + helpers_[0]->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx_.stream()); + } + + private: + const phi::GPUContext& dev_ctx_; + + int m_; // m + int n_; // n + int k_; // k + + int compute_bias_; + std::vector> helpers_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/cublaslt.h b/paddle/fluid/operators/fused/cublaslt.h new file mode 100644 index 0000000000..b9cc6b56f1 --- /dev/null +++ b/paddle/fluid/operators/fused/cublaslt.h @@ -0,0 +1,211 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" + +namespace dyl = paddle::platform::dynload; + +namespace paddle { +namespace operators { +class CublasLtHelper { + public: + CublasLtHelper(int m, int k, int n) + : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + cublasStatus_t status; + // handle and matmul desc + status = dyl::cublasLtCreate(&handle_); +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32I; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +#endif + + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + +#if CUBLAS_VER_MAJOR < 11 + status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); +#else + status = dyl::cublasLtMatmulDescCreate( + &matmul_desc_, cudaComputeType, CUDA_R_32I); +#endif + + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatmulDescCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + cublasOperation_t op_transpose = CUBLAS_OP_T; + status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &op_transpose, + sizeof(op_transpose)); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatmulDescSetAttribute execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + // matrix desc + status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + } + ~CublasLtHelper() { + if (handle_) dyl::cublasLtDestroy(handle_); + if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_); + if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_); + if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_); + if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_); + } + + void GEMM(int8_t* A_dev, + const int8_t* B_dev, + int32_t* C_dev, + cudaStream_t stream) { + cublasStatus_t status; + +#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 + cublasLtMatmulAlgo_t algo; + int algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 15; + int splitK_val = 0; + int reductionScheme = 0; +#if CUDA_VERSION >= 11000 + int stages = 23; +#endif + +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32I; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +#endif + + dyl::cublasLtMatmulAlgoInit(handle_, + cudaComputeType, + CUDA_R_32I, + CUDA_R_8I, + CUDA_R_8I, + CUDA_R_32I, + CUDA_R_32I, + algoId, + &algo); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), + sizeof(customOption)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); + dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), + sizeof(splitK_val)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), + sizeof(int)); +#if CUDA_VERSION >= 11000 + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); +#endif +#endif + status = dyl::cublasLtMatmul(handle_, + matmul_desc_, + &alpha_, + B_dev, + B_desc_, + A_dev, + A_desc_, + &beta_, + C_dev, + C_desc_, + C_dev, + C_desc_, +#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 + &algo, +#else + nullptr, +#endif + nullptr, + 0, + stream); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + platform::errors::External( + "cublasLtMatmul execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + } + + private: + cublasLtHandle_t handle_; + cublasLtMatmulDesc_t matmul_desc_; + cublasLtMatrixLayout_t A_desc_; + cublasLtMatrixLayout_t B_desc_; + cublasLtMatrixLayout_t C_desc_; + int32_t alpha_; + int32_t beta_; + + int m_; + int k_; + int n_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index 732da5fa52..6b2cdfb6a8 100644 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -60,19 +60,32 @@ struct GeluGradFunctor { * the src, mask and dst shape is (rows, cols) * the bias shape is (1, cols) */ -template -__global__ void FusedDropoutActBias(Functor act, - const uint64_t seed, - const uint64_t rows, - const uint64_t cols, - const int increment, - const float dropout_prob, - const bool is_upscale_in_train, - const bool is_test, - const T *__restrict__ src, - const T *__restrict__ bias, - T *dst, - MaskType *mask) { +template +__global__ void FusedDropoutActBias( + Functor act, + const uint64_t seed, + const uint64_t rows, + const uint64_t cols, + const int increment, + const float dropout_prob, + const bool is_upscale_in_train, + const bool is_test, + const InType *__restrict__ src, + const T *__restrict__ bias, + OutType *dst, + MaskType *mask, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -90,7 +103,9 @@ __global__ void FusedDropoutActBias(Functor act, VecSize, false, true, - Functor>(r, + Functor, + InType, + OutType>(r, i, cols, &state, @@ -104,7 +119,14 @@ __global__ void FusedDropoutActBias(Functor act, is_test, nullptr, nullptr, - act); + act, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } } } @@ -112,7 +134,11 @@ __global__ void FusedDropoutActBias(Functor act, /** * @brief dst = dropout(activation(src + bias)); */ -template +template void LaunchDropoutActBias(Functor act_functor, const uint64_t seed, const uint32_t rows, @@ -121,14 +147,21 @@ void LaunchDropoutActBias(Functor act_functor, const float dropout_prob, const bool is_upscale_in_train, const bool is_test, - const T *src, + const InType *src, const T *bias, - T *dst, + OutType *dst, MaskType *mask_data, - const phi::GPUContext &ctx) { + const phi::GPUContext &ctx, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { - SetZero(ctx, dst, rows * cols); + SetZero(ctx, reinterpret_cast(dst), rows * cols); SetZero(ctx, mask_data, rows * cols); return; } @@ -137,7 +170,7 @@ void LaunchDropoutActBias(Functor act_functor, const int real_vec_size = cols % VecSize == 0 ? VecSize : 1; const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); if (cols % VecSize == 0) { - FusedDropoutActBias + FusedDropoutActBias <<>>( act_functor, seed, @@ -150,9 +183,13 @@ void LaunchDropoutActBias(Functor act_functor, src, bias, dst, - mask_data); + mask_data, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } else { - FusedDropoutActBias + FusedDropoutActBias <<>>( act_functor, seed, @@ -165,7 +202,11 @@ void LaunchDropoutActBias(Functor act_functor, src, bias, dst, - mask_data); + mask_data, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } } diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 0f37d242eb..1b8dc4bb32 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 208b2a58bc..2d1491fefb 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -109,7 +109,10 @@ struct DropoutParam { } }; -template +template class FusedDropoutHelper { private: int GetIncrement(const phi::GPUContext& ctx) { @@ -140,25 +143,34 @@ class FusedDropoutHelper { // out = residual + dropout( src + bias ) void ResidualDropoutBias(const phi::GPUContext& ctx, - const T* src, + const InType* src, const T* residual, const T* bias, - T* out, - MaskType* mask) { + OutType* out, + MaskType* mask, + const float quant_last_in_scale = 1.0, + const float* dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0) { auto increment = GetIncrement(ctx); - LaunchResidualDropoutBias(rows_, - cols_, - increment, - dropout_param_.seed, - dropout_param_.dropout_prob, - dropout_param_.is_test, - dropout_param_.is_upscale_in_train, - src, - residual, - bias, - mask, - out, - ctx); + LaunchResidualDropoutBias( + rows_, + cols_, + increment, + dropout_param_.seed, + dropout_param_.dropout_prob, + dropout_param_.is_test, + dropout_param_.is_upscale_in_train, + src, + residual, + bias, + mask, + out, + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } void ResidualDropoutBiasGrad(const phi::GPUContext& ctx, @@ -189,15 +201,22 @@ class FusedDropoutHelper { // out = dropout(activation(src + bias)) void DropoutActBias(const phi::GPUContext& ctx, - const T* src, + const InType* src, const T* bias, const std::string& act_method, - T* out, - MaskType* mask) { + OutType* out, + MaskType* mask, + const float quant_last_in_scale = 1.0, + const float* dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { auto increment = GetIncrement(ctx); if (act_method == "gelu") { GeluFunctor gelu; - LaunchDropoutActBias>( + LaunchDropoutActBias, InType, OutType>( gelu, dropout_param_.seed, rows_, @@ -210,23 +229,40 @@ class FusedDropoutHelper { bias, out, mask, - ctx); + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } else if (act_method == "relu") { phi::funcs::ReluFunctor relu; - LaunchDropoutActBias>( - relu, - dropout_param_.seed, - rows_, - cols_, - increment, - dropout_param_.dropout_prob, - dropout_param_.is_upscale_in_train, - dropout_param_.is_test, - src, - bias, - out, - mask, - ctx); + LaunchDropoutActBias, + InType, + OutType>(relu, + dropout_param_.seed, + rows_, + cols_, + increment, + dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, + dropout_param_.is_test, + src, + bias, + out, + mask, + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Currently only supports gelu or relu activation functions!")); @@ -283,8 +319,12 @@ class FusedDropoutHelper { DropoutParam dropout_param_; }; -template -class FusedDropoutLayerNormHelper : public FusedDropoutHelper { +template +class FusedDropoutLayerNormHelper + : public FusedDropoutHelper { public: FusedDropoutLayerNormHelper() {} FusedDropoutLayerNormHelper(const int rows, @@ -301,23 +341,24 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { const int cols, const DropoutParam& dropout_param, const float epsilon) - : FusedDropoutHelper(ctx, rows, cols, dropout_param) { + : FusedDropoutHelper( + ctx, rows, cols, dropout_param) { using U = LayerNormParamType; epsilon_ = epsilon; } // call layer_norm void LayerNorm(const phi::GPUContext& ctx, - const T* src, + const InType* src, const LayerNormParamType* gamma, const LayerNormParamType* beta, - T* out, + OutType* out, LayerNormParamType* mean, LayerNormParamType* variance) { using U = LayerNormParamType; switch (GetDesiredBlockDim(this->cols_)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward + LayerNormForward <<rows_, kBlockDim, 0, ctx.stream()>>>( src, gamma, beta, out, mean, variance, epsilon_, this->cols_)); } @@ -349,17 +390,25 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { // out = layernorm(residual + dropout(src + bias)) template , bool is_same_type = false> - void LayernormResidualDropoutBias(const phi::GPUContext& ctx, - const T* src, - const T* residual, - const T* bias, - const P* gamma, - const P* beta, - T* dropout_out, - MaskType* mask, - T* out, - LayerNormParamType* mean, - LayerNormParamType* variance) { + void LayernormResidualDropoutBias( + const phi::GPUContext& ctx, + const InType* src, + const T* residual, + const T* bias, + const P* gamma, + const P* beta, + T* dropout_out, + MaskType* mask, + OutType* out, + LayerNormParamType* mean, + LayerNormParamType* variance, + const float quant_last_in_scale = 1.0, + const float* dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { using U = LayerNormParamType; int vec_size = MAX_CACHE_BYTES / sizeof(T); if (this->cols_ % vec_size != 0) { @@ -368,7 +417,12 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { int threads = GetDesiredBlockDim(this->cols_ / vec_size); int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size; increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); - LaunchLayernormResidualDropoutBias( + LaunchLayernormResidualDropoutBias( this->rows_, this->cols_, increment, @@ -387,7 +441,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { out, mean, variance, - ctx); + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } template , bool is_same_type = false> diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index 7bb3498567..137943afbf 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -418,7 +418,9 @@ template + int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA, + typename InType = T, + typename OutType = T> __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( int rows, int cols, @@ -428,7 +430,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( const bool is_test, const uint64_t increment, const float epsilon, - const T *__restrict__ x_ptr, + const InType *__restrict__ x_ptr, const T *__restrict__ residual_ptr, const T *__restrict__ bias_ptr, const ScaleT *__restrict__ gamma_ptr, @@ -437,10 +439,20 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr, - T *__restrict__ y_ptr) { + OutType *__restrict__ y_ptr, + const float quant_last_in_scale = 1.0, + const float *__restrict__ quant_out_scale_ptr = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { __shared__ U smem[WARPS_M * WARPS_N]; using Vec = phi::AlignedVector; using Vec_scale = phi::AlignedVector; + using Vec_in_type = phi::AlignedVector; + using Vec_out_type = phi::AlignedVector; + using Vec_float = phi::AlignedVector; using MaskStoreT = phi::AlignedVector; const int tidx = threadIdx.x; @@ -481,12 +493,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( constexpr U rn = 1.f / U(ELTS_PER_ROW); for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) { Vec x[LDGS]; + Vec_in_type x_input[LDGS]; Vec residual[LDGS]; + Vec_float dequant_out_scale[LDGS]; + #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - phi::Load(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]); phi::Load(residual_ptr + row * ELTS_PER_ROW + col * VecSize, &residual[it]); + phi::Load(x_ptr + row * ELTS_PER_ROW + col * VecSize, + &x_input[it]); + if (quant_out_scale_ptr != nullptr) { + phi::Load( + quant_out_scale_ptr + quant_out_scale_offset + col * VecSize, + &dequant_out_scale[it]); + } col += THREADS_PER_ROW; } @@ -520,10 +541,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( #pragma unroll for (int jt = 0; jt < VecSize; jt++) { // dropout(x) + residual - x[it][jt] = (x[it][jt] + bias[it][jt]) * - static_cast(mask_vec[it][jt]) * factor + - residual[it][jt]; - xf[it * VecSize + jt] = U(x[it][jt]); + if (std::is_same::value) { + T tmp = (static_cast(static_cast(x_input[it][jt]) * + quant_last_in_scale / + dequant_out_scale[it][jt]) + + bias[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + x[it][jt] = tmp; + xf[it * VecSize + jt] = U(tmp); + } else { + x[it][jt] = (static_cast(x_input[it][jt]) + bias[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } } } } else { @@ -532,8 +564,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( #pragma unroll for (int jt = 0; jt < VecSize; jt++) { // dropout(x) + residual - x[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor + - residual[it][jt]; + if (std::is_same::value) { + // for int32 input, we need to dequantize. + T tmp = static_cast(static_cast(x_input[it][jt]) * + quant_last_in_scale / + dequant_out_scale[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + x[it][jt] = tmp; + } else { + x[it][jt] = static_cast(x_input[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + } xf[it * VecSize + jt] = U(x[it][jt]); } } @@ -626,6 +669,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( var_out_ptr[row] = var_local * rn; } + Vec_out_type x_output[LDGS]; + #pragma unroll for (int it = 0; it < LDGS; it++) { #pragma unroll @@ -638,12 +683,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( U tmp = rsigma * (static_cast(xf[it * VecSize + jt]) - mu_local); x[it][jt] = static_cast(static_cast(gamma[it][jt]) * tmp + static_cast(beta[it][jt])); + + if (std::is_same::value) + x_output[it][jt] = quant_helper(x[it][jt], + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } } #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - phi::Store(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize); + if (std::is_same::value) { + phi::Store( + x_output[it], y_ptr + row * ELTS_PER_ROW + col * VecSize); + } else { + phi::Store( + x[it], + reinterpret_cast(y_ptr) + row * ELTS_PER_ROW + col * VecSize); + } col += THREADS_PER_ROW; } } @@ -668,7 +727,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( template + bool ScaleBiasWithSameTypeX = false, + typename InType = T, + typename OutType = T> void LaunchLayernormResidualDropoutBias( const uint32_t rows, const uint32_t cols, @@ -678,18 +739,26 @@ void LaunchLayernormResidualDropoutBias( const float epsilon, const bool is_upscale_in_train, const bool is_test, - const T *src, + const InType *src, const T *residual, const T *bias, const LayerNormScaleBiasT *scale, const LayerNormScaleBiasT *layernorm_bias, MaskType *mask_data, T *dst, - T *layernorm_dst, + OutType *layernorm_dst, LayerNormParamType *mean, LayerNormParamType *var, - const phi::GPUContext &ctx) { + const phi::GPUContext &ctx, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { // dropout_prob == 1.0f + // NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0 if (std::abs(dropout_prob - 1.0f) < 1e-5) { auto cuda_place = ctx.GetPlace(); memory::Copy(cuda_place, @@ -705,14 +774,15 @@ void LaunchLayernormResidualDropoutBias( switch (GetDesiredBlockDim(cols)) { FIXED_BLOCK_DIM_CASE( LayerNormForward - <<>>(dst, - scale, - layernorm_bias, - layernorm_dst, - mean, - var, - epsilon, - cols)); + <<>>( + dst, + scale, + layernorm_bias, + reinterpret_cast(layernorm_dst), + mean, + var, + epsilon, + cols)); default: PADDLE_THROW(platform::errors::InvalidArgument( "Product from begin_norm_axis to end must be larger than 1")); @@ -722,44 +792,63 @@ void LaunchLayernormResidualDropoutBias( return; } -#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \ - case (cols): { \ - constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \ - constexpr int WARPS_M = 4 / WARPS_N; \ - const int THREADS_PER_WARP = 32; \ - const int BYTES_PER_LDG = 16; \ - const int VecSize = BYTES_PER_LDG / sizeof(T); \ - const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \ - const int ROWS_PER_CTA = WARPS_M; \ - const int grid = \ - static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); \ - fused_fast_ln_fwd_kernel< \ - T, \ - U, \ - LayerNormScaleBiasT, \ - uint8_t, \ - VecSize, \ - WARPS_M, \ - WARPS_N, \ - BYTES_PER_LDG, \ - cols><<>>(rows, \ - cols, \ - seed, \ - dropout_prob, \ - is_upscale_in_train, \ - is_test, \ - increment, \ - epsilon, \ - src, \ - residual, \ - bias, \ - scale, \ - layernorm_bias, \ - mask_data, \ - mean, \ - var, \ - dst, \ - layernorm_dst); \ +#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \ + case (cols): { \ + constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \ + constexpr int WARPS_M = 4 / WARPS_N; \ + const int THREADS_PER_WARP = 32; \ + const int BYTES_PER_LDG = 16; \ + const int VecSize = BYTES_PER_LDG / sizeof(T); \ + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \ + const int ROWS_PER_CTA = WARPS_M; \ + const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \ + const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \ + const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \ + const int grid = \ + static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); \ + fused_fast_ln_fwd_kernel< \ + T, \ + U, \ + LayerNormScaleBiasT, \ + uint8_t, \ + VecSize, \ + WARPS_M, \ + WARPS_N, \ + BYTES_PER_LDG, \ + cols, \ + THREADS_PER_WARP, \ + THREADS_PER_ROW, \ + THREADS_PER_CTA, \ + ROWS_PER_CTA, \ + ELTS_PER_ROW_PER_CTA, \ + LDGS, \ + InType, \ + OutType> \ + <<>>(rows, \ + cols, \ + seed, \ + dropout_prob, \ + is_upscale_in_train, \ + is_test, \ + increment, \ + epsilon, \ + src, \ + residual, \ + bias, \ + scale, \ + layernorm_bias, \ + mask_data, \ + mean, \ + var, \ + dst, \ + layernorm_dst, \ + quant_last_in_scale, \ + dequant_out_scale_data, \ + quant_out_scale_offset, \ + quant_next_in_scale, \ + quant_round_type, \ + quant_max_bound, \ + quant_min_bound); \ } break #define LAUNCH_FUSED_FAST_LN_KERNEL \ @@ -784,24 +873,25 @@ void LaunchLayernormResidualDropoutBias( if (cols % VecSize != 0) { int blockDim = GetDesiredBlockDim(cols); FusedLayernormResidualDropoutBias - <<>>(rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - is_test, - increment, - epsilon, - src, - residual, - bias, - scale, - layernorm_bias, - mask_data, - dst, - layernorm_dst, - mean, - var); + <<>>( + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + reinterpret_cast(src), + residual, + bias, + scale, + layernorm_bias, + mask_data, + dst, + reinterpret_cast(layernorm_dst), + mean, + var); } else { if (can_call_fast_ln_kernel) { switch (cols) { @@ -819,24 +909,25 @@ void LaunchLayernormResidualDropoutBias( VecSize, U, ScaleBiasWithSameTypeX> - <<>>(rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - is_test, - increment, - epsilon, - src, - residual, - bias, - scale, - layernorm_bias, - mask_data, - dst, - layernorm_dst, - mean, - var); + <<>>( + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + reinterpret_cast(src), + residual, + bias, + scale, + layernorm_bias, + mask_data, + dst, + reinterpret_cast(layernorm_dst), + mean, + var); } } } diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc new file mode 100644 index 0000000000..9572a87aba --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc @@ -0,0 +1,369 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerINT8Op"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // ffn + CHECK_INPUTS(FFN1Weight); + CHECK_INPUTS(FFN2Weight); + + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + if (ctx->Attrs().Get("ring_id") == -1) { + if (trans_qkvw) { + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], + y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + + } else { + PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], + y_dim[0], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(dim_embed, 3, num_head, dim_head)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + } + } + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[1], + x_dim[0], + paddle::platform::errors::InvalidArgument( + "The second dim of CacheKV must be equal with " + "batch size %d, but got %d", + x_dim[0], + c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + PADDLE_ENFORCE_GT( + c_dim[3], + 0, + paddle::platform::errors::InvalidArgument( + "The forth dim of CacheKV must be greater than 0, but got %d", + c_dim[3])); // cache_seq_len + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, + const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerINT8OpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + + AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op") + .AsDispensable() + .AsDuplicable(); + + AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") + .AsDispensable() + .AsDuplicable(); + + AddInput("QKVOutScale", + "QKVOutScale is used to dequantize qkv output tensor." + "In order to keep consistent with the PTQ/QAT calculation logic," + "QKVOutScale should be max_bound * max_bound / max_range." + "Here max_range is per-channel weight scale." + "The shape of QKVOutScale is [num_layers, num_channels]") + .AsDispensable(); + AddInput("OutLinearOutScale", + "OutLinearOutScale is used to dequantize out_linear output tensor." + "The definition and shape is the same as QKVOutScale") + .AsDispensable(); + AddInput("FFN1OutScale", + "FFN1OutScale is used to dequantize ffn1 output tensor." + "The definition and shape is the same as QKVOutScale") + .AsDispensable(); + AddInput("FFN2OutScale", + "FFN2OutScale is used to dequantize ffn2 output tensor." + "The definition and shape is the same as QKVOutScale") + .AsDispensable(); + + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, + true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", + true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); + + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + + AddAttr("num_head", "num_head").SetDefault(0); + AddAttr("dim_head", "dim_head").SetDefault(0); + AddAttr("dim_ffn", "dim_ffn").SetDefault(0); + + AddAttr>( + "qkv_in_scale", + "qkv_in_scale is used to quantize qkv input tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of qkv_in_scale should be num_layers, which is equal to " + "QKVW.dims()[0]") + .SetDefault({}); + AddAttr>( + "out_linear_in_scale", + "out_linear_in_scale is used to quantize out_linear input tensor." + "the size of out_linear_in_scale is the same as qkv_in_scale") + .SetDefault({}); + AddAttr>( + "ffn1_in_scale", + "ffn1_in_scale is used to quantize ffn1 input tensor." + "the size of ffn1_in_scale is the same as qkv_in_scale") + .SetDefault({}); + AddAttr>( + "ffn2_in_scale", + "ffn2_in_scale is used to quantize ffn2 input tensor." + "the size of ffn2_in_scale is the same as qkv_in_scale") + .SetDefault({}); + + AddAttr( + "quant_round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(-2.5)=-3") + .SetDefault(1); + AddAttr( + "quant_max_bound", + "(float, default 127.0) the max bound of float type to int type") + .SetDefault(127.0); + AddAttr( + "quant_min_bound", + "(float, default -127.0) the min bound of float type to int type") + .SetDefault(-127.0); + + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer_int8, + ops::FusedMultiTransformerINT8Op, + ops::FusedMultiTransformerINT8OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu new file mode 100644 index 0000000000..8e200275f8 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -0,0 +1,670 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/attn_gemm_int8.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" + +namespace paddle { +namespace operators { + +template +class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + + // quant input scales, vector, size = num_layers + auto qkv_in_scale = ctx.Attr>("qkv_in_scale"); + auto out_linear_in_scale = + ctx.Attr>("out_linear_in_scale"); + auto ffn1_in_scale = ctx.Attr>("ffn1_in_scale"); + auto ffn2_in_scale = ctx.Attr>("ffn2_in_scale"); + + // quant round type and bound + auto quant_round_type = ctx.Attr("quant_round_type"); + auto quant_max_bound = ctx.Attr("quant_max_bound"); + auto quant_min_bound = ctx.Attr("quant_min_bound"); + + // dequant output scales, tensor, size = [num_layers, n], n is gemm output + // size + auto *qkv_out_scale = ctx.Input("QKVOutScale"); + auto *out_linear_out_scale = ctx.Input("OutLinearOutScale"); + auto *ffn1_out_scale = ctx.Input("FFN1OutScale"); + auto *ffn2_out_scale = ctx.Input("FFN2OutScale"); + + int qkv_out_scale_n = qkv_out_scale->dims()[1]; + int out_linear_out_scale_n = out_linear_out_scale->dims()[1]; + int ffn1_out_scale_n = ffn1_out_scale->dims()[1]; + int ffn2_out_scale_n = ffn2_out_scale->dims()[1]; + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + auto ln_compute = + AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + ln_mean.Resize({{bsz_seq}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{bsz_seq}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + AttnMatmulINT8 qkv_compute( + dev_ctx, bsz_seq, output_size, input_size, compute_bias); + Tensor qkv_out; + qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + // auto *time_step = ctx.Input("TimeStep"); + + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_EQ(time_step->place(), + platform::CPUPlace(), + platform::errors::PreconditionNotMet( + "The place of input(TimeStep) must be CPUPlace.")); + // cache_seq_len + int time_step_value = time_step->data()[0]; + PADDLE_ENFORCE_GT(time_step_value, + 0, + platform::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + platform::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_value; + } + + Tensor transpose_out_2, qk_out; + transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + auto *transpose_out_2_data = + dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + Tensor softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *attn_dropout_mask_out_data = dev_ctx.Alloc( + &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); + attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *attn_dropout_data_data = dev_ctx.Alloc( + &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + AttnMatmulINT8 out_linear_compute( + dev_ctx, bsz_seq, dim_embed, hidden_size, false); + + // 5. ln(residual + bias) + DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper + fused_dropout_layernorm_helper( + dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + FusedDropoutLayerNormHelper + fused_dropout_layernorm_helper_for_post_layernorm( + dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + if (pre_layer_norm) { + bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + } + dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); + auto *dropout_mask_out_data = dev_ctx.Alloc( + &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + + // 6. ffn matmul1 + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + int dim_ffn = ffn1_weight_dim[0]; + AttnMatmulINT8 ffn1_linear_compute( + dev_ctx, bsz_seq, dim_ffn, dim_embed, false); + Tensor ffn1_out; + ffn1_out.Resize({{bsz_seq, dim_ffn}}); + auto *ffn1_out_data = + dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // 7. ffn act + bias + DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + FusedDropoutHelper fused_act_dropout_helper_for_post_layernorm( + dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + Tensor ffn1_dropout_out, ffn1_dropout_mask; + ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + auto *ffn1_dropout_out_data = dev_ctx.Alloc( + &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); + ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); + auto *ffn1_dropout_mask_data = dev_ctx.Alloc( + &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); + + // 8. ffn2 matmul + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + AttnMatmulINT8 ffn2_linear_compute( + dev_ctx, bsz_seq, dim_embed, dim_ffn, false); + + // 9. ffn2 residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper + ffn2_fused_dropout_helper( + dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + FusedDropoutLayerNormHelper + ffn2_fused_dropout_dequant_helper( + dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + FusedDropoutLayerNormHelper + ffn2_fused_dropout_helper_for_post_layernorm( + dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + + // []. init workspace for cublasLt transform + Tensor input_workspace, output_workspace; + // for input and output transform data is CUBLASLT_ORDER_COL32 format, + int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), + n_max = std::max({output_size, dim_embed, dim_ffn}); + + input_workspace.Resize( + {{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}}); + dev_ctx.Alloc(&input_workspace, + input_workspace.numel() * sizeof(int8_t)); + output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}}); + dev_ctx.Alloc(&output_workspace, + output_workspace.numel() * sizeof(int32_t)); + + // calc + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + Tensor *from_tensor = out; + Tensor tmp_out; + tmp_out.Resize({{bsz, seq_len, dim_embed}}); + auto *tmp_out_data = + dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + + auto *x_data = input_x->data(); + Tensor *buf0 = nullptr; + Tensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (pre_layer_norm) { + buf1 = out; + } else { + buf0 = &tmp_out; + buf1 = out; + } + + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + input_workspace.data(), + ln_mean_data, + ln_var_data, + nullptr, + 0, + qkv_in_scale[i], + quant_round_type, + quant_max_bound, + quant_min_bound); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1"; +#endif + + // step2. qkv + const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + if (!pre_layer_norm && i == 0) { + qkv_compute.ComputeForward(qkv_weights[i], + input_x, + &input_workspace, + bias, + &qkv_out, + &output_workspace, + &qkv_out, + qkv_in_scale[i], + qkv_out_scale, + i * qkv_out_scale_n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else if (!pre_layer_norm) { + qkv_compute.ComputeForward(qkv_weights[i], + buf1, + &input_workspace, + bias, + &qkv_out, + &output_workspace, + &qkv_out, + qkv_in_scale[i], + qkv_out_scale, + i * qkv_out_scale_n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], + qkv_in_scale[i], + &input_workspace, + bias, + &qkv_out, + &output_workspace, + &qkv_out, + qkv_out_scale, + i * qkv_out_scale_n); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2"; +#endif + + // step3. fmha + const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + cache_kv_out, + &fmha_out, + bsz, + max_seq_len, + num_head, + dim_head, + time_step->data()[0], + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + // TODO(wangxi): can remove dropout in inference + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + fmha_compute.ComputeForward(qkv_out, + cache_kv, + src_mask, + &transpose_out_2, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3"; +#endif + + if (pre_layer_norm) { + out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], + out_linear_in_scale[i], + &fmha_out, + &input_workspace, + nullptr, + &output_workspace, + nullptr, + quant_round_type, + quant_max_bound, + quant_min_bound); + AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); + } else { + out_linear_compute.ComputeForward(out_linear_weights[i], + &fmha_out, + &input_workspace, + nullptr, + buf0, + &output_workspace, + nullptr, + out_linear_in_scale[i], + out_linear_out_scale, + i * out_linear_out_scale_n, + quant_round_type, + quant_max_bound, + quant_min_bound); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // inplace + // non-inplace: buf1 -> input_workspace + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + output_workspace.data(), + x_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + input_workspace.data(), + ln_mean_data, + ln_var_data, + out_linear_in_scale[i], + out_linear_out_scale->data(), + i * out_linear_out_scale_n, + ffn1_in_scale[i], + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + auto *residual_data = (i == 0 ? x_data : buf1->data()); + fused_dropout_layernorm_helper_for_post_layernorm + .LayernormResidualDropoutBias(dev_ctx, + buf0->data(), + residual_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + + // step6. ffn matmul1 + + if (pre_layer_norm) { + ffn1_linear_compute.ComputeForwardINT8ToINT8(ffn1_weights[i], + &input_workspace, + nullptr, + &output_workspace, + nullptr); + } else { + ffn1_linear_compute.ComputeForward(ffn1_weights[i], + buf1, + &input_workspace, + nullptr, + &ffn1_out, + &output_workspace, + nullptr, + ffn1_in_scale[i], + ffn1_out_scale, + i * ffn1_out_scale_n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step6"; +#endif + + // step7. act bias + // TODO(wangxi): remove dropout mask in inference + if (pre_layer_norm) { + fused_act_dropout_helper.DropoutActBias( + dev_ctx, + output_workspace.data(), + ffn1_biases[i]->data(), + "gelu", + input_workspace.data(), + ffn1_dropout_mask_data, + ffn1_in_scale[i], + ffn1_out_scale->data(), + i * ffn1_out_scale_n, + ffn2_in_scale[i], + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + fused_act_dropout_helper_for_post_layernorm.DropoutActBias( + dev_ctx, + ffn1_out_data, + ffn1_biases[i]->data(), + "gelu", + ffn1_dropout_out_data, + ffn1_dropout_mask_data); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step7"; +#endif + + // step8. ffn matmul2 + if (pre_layer_norm) { + ffn2_linear_compute.ComputeForwardINT8ToINT8(ffn2_weights[i], + &input_workspace, + nullptr, + &output_workspace, + nullptr); + } else { + ffn2_linear_compute.ComputeForward(ffn2_weights[i], + &ffn1_dropout_out, + &input_workspace, + nullptr, + buf0, + &output_workspace, + nullptr, + ffn2_in_scale[i], + ffn2_out_scale, + i * ffn2_out_scale_n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.0"; +#endif + + if (pre_layer_norm) { + AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); + } else { + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.1"; +#endif + + // step9. residual bias + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + output_workspace.data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf1->data(), + dropout_mask_out_data, + input_workspace.data(), + ln_mean_data, + ln_var_data, + ffn2_in_scale[i], + ffn2_out_scale->data(), + i * ffn2_out_scale_n, + qkv_in_scale[i + 1], + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + ffn2_fused_dropout_dequant_helper.ResidualDropoutBias( + dev_ctx, + output_workspace.data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + buf1->data(), + dropout_mask_out_data, + ffn2_in_scale[i], + ffn2_out_scale->data(), + i * ffn2_out_scale_n, + 1.0); + } + } else { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + ffn2_fused_dropout_helper_for_post_layernorm + .LayernormResidualDropoutBias(dev_ctx, + buf0->data(), + buf1->data(), + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step9"; +#endif + if (pre_layer_norm) { + x_data = buf1->data(); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_int8, + ops::FusedMultiTransformerINT8OpKernel, + ops::FusedMultiTransformerINT8OpKernel); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 04681f3d7a..5cf22885aa 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -1,1161 +1,19 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -// This file has been adapted from FasterTransformer file: -// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu -// We add License in the head. - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attention_layer_norm.h" -#include "paddle/fluid/operators/fused/attn_gemm.h" -#include "paddle/fluid/operators/fused/fmha_ref.h" -#include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/ProcessGroup.h" -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; - -// for debug -// #define _DEBUG_FUSED_MULTI_TRANSFORMER - -template -static void AllReduce(framework::Tensor &tensor, // NOLINT - const int ring_id, - const phi::GPUContext &ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup *pg = map->get(ring_id); - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(tensor); - paddle::distributed::AllreduceOptions opts; - opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t numel = tensor.numel(); - const void *sendbuff = tensor.data(); - auto place = ctx.GetPlace(); - void *recvbuff = ctx.Alloc(&tensor, tensor.numel() * sizeof(T)); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - auto stream = ctx.stream(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream)); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - -namespace { - -namespace plat = paddle::platform; -using float16 = plat::float16; - -#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#define MMHA_USE_FP32_ACUM_FOR_OUT - -template -struct Masked_multihead_attention_params { - // output buffer, [B, 1(seq_len), num_head * dim_head] - T *out; - // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] - const T *qkv; - // bias, [3, num_head, dim_head] - const T *qkv_bias; - // TODO(wangxi): optimize with input_lengths and max_input_len? - // [bsz, 1, 1, time_step(cache_seq_length)+1] - const T *attn_mask; - - // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] - // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first - // v [B, num_head, max_seq_len, dim_head] - T *cache_kv; - - int batch_size; - int num_head; - int timestep; // cache_seq_length - int max_seq_length; - - // 1.f / sqrt(Dh) - float inv_sqrt_dh; -}; - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -// clang-format off - -template struct Qk_vec_ {}; -template <> struct Qk_vec_ { using Type = float; }; -template <> struct Qk_vec_ { using Type = float2; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint2; }; -template <> struct Qk_vec_ { using Type = uint4; }; - -template struct K_vec_ {}; -template <> struct K_vec_ { using Type = float; }; -template <> struct K_vec_ { using Type = float2; }; -template <> struct K_vec_ { using Type = float4; }; -template <> struct K_vec_ { using Type = uint32_t; }; -template <> struct K_vec_ { using Type = uint2; }; -template <> struct K_vec_ { using Type = uint4; }; - -template struct V_vec_ {}; -template <> struct V_vec_ { using Type = float; }; -template <> struct V_vec_ { using Type = float2; }; -template <> struct V_vec_ { using Type = float4; }; -template <> struct V_vec_ { using Type = uint32_t; }; -template <> struct V_vec_ { using Type = uint2; }; -template <> struct V_vec_ { using Type = uint4; }; - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template struct V_vec_acum_fp32_ {}; -// template <> struct V_vec_acum_fp32_ { using Type = float; }; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -template <> struct V_vec_acum_fp32_ { using Type = float4; }; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; -template <> struct V_vec_acum_fp32_ { using Type = Float8_; }; -#endif - -// clang-format on - -inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -inline __device__ float2 half2_to_float2(uint32_t v) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -inline __device__ uint32_t float2_to_half2(float2 f) { - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" - : "=r"(tmp.u32) - : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -inline __device__ float add(float a, float b) { return a + b; } - -inline __device__ float2 add(float2 a, float2 b) { - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -inline __device__ float4 add(float4 a, float4 b) { - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -inline __device__ uint16_t add(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -inline __device__ uint32_t add(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -inline __device__ uint2 add(uint2 a, uint2 b) { - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -inline __device__ uint4 add(uint4 a, uint4 b) { - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -inline __device__ float2 add(uint32_t a, float2 fb) { - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -inline __device__ Float8_ add(uint4 a, Float8_ fb) { - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -template -inline __device__ Acc mul(A a, B b); - -template <> -inline __device__ float mul(float a, float b) { - return a * b; -} - -template <> -inline __device__ float2 mul(float2 a, float2 b) { - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -template <> -inline __device__ float4 mul(float4 a, float4 b) { - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -template <> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -template <> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -template <> -inline __device__ uint2 mul(uint2 a, uint2 b) { - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -template <> -inline __device__ uint4 mul(uint4 a, uint4 b) { - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -template <> -inline __device__ uint32_t mul(uint32_t a, float b) { - float2 tmp = half2_to_float2(a); - float2 tmp_res; - tmp_res.x = tmp.x * b; - tmp_res.y = tmp.y * b; - uint32_t res = float2_to_half2(tmp_res); - return res; -} - -template <> -inline __device__ uint2 mul(uint2 a, float b) { - uint2 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - return res; -} - -template <> -inline __device__ uint4 mul(uint4 a, float b) { - uint4 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - res.z = mul(a.z, b); - res.w = mul(a.w, b); - return res; -} - -template <> -inline __device__ float2 mul(float2 a, float b) { - float2 res; - res.x = a.x * b; - res.y = a.y * b; - return res; -} - -template <> -inline __device__ float4 mul(float4 a, float b) { - float4 res; - res.x = a.x * b; - res.y = a.y * b; - res.z = a.z * b; - res.w = a.w * b; - return res; -} - -inline __device__ float sum(float v) { return v; } -inline __device__ float sum(float2 v) { return v.x + v.y; } -inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -inline __device__ float sum(uint16_t v) { return half_to_float(v); } -inline __device__ float sum(uint32_t v) { - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -inline __device__ float sum(uint2 v) { - uint32_t c = add(v.x, v.y); - return sum(c); -} - -inline __device__ float sum(uint4 v) { - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); - return sum(c); -} - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -inline __device__ constexpr uint32_t shfl_mask(int threads) { - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -template -inline __device__ __host__ T div_up(T m, T n) { - return (m + n - 1) / n; -} - -inline __device__ float fma(float a, float b, float c) { return a * b + c; } - -inline __device__ float2 fma(float2 a, float2 b, float2 c) { - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -inline __device__ float4 fma(float4 a, float4 b, float4 c) { - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(d) - : "r"(a), "r"(b), "r"(c)); - return d; -} - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -inline __device__ float2 fma(float a, float2 b, float2 c) { - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -inline __device__ float4 fma(float a, float4 b, float4 c) { - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -inline __device__ uint32_t h0_h0(uint16_t a) { - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { - return fma(h0_h0(a), b, c); -} - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -inline __device__ float cast_to_float(float u) { return u; } - -inline __device__ float2 cast_to_float(float2 u) { return u; } - -inline __device__ float4 cast_to_float(float4 u) { return u; } - -inline __device__ Float8_ cast_to_float(uint4 u) { - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - K_vec inv_q = mul(q[0], inv_sqrt_dh); - K_vec qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); - } - - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - return qk_dot_(q, k, inv_sqrt_dh); - } -}; - -template -inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - if (lane == 0) { - red_smem[warp] = sum; - } - __syncthreads(); - - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - return __shfl_sync(uint32_t(-1), sum, 0); -} - -inline __device__ void convert_from_float(float &dst, float src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT - float src) { - dst = static_cast(src); -} - -inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT - -template -inline __device__ void zero(T &dst) { // NOLINT - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -template -__global__ void masked_multihead_attention_kernel( - Masked_multihead_attention_params params) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - constexpr int WARP_SIZE = 32; - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - extern __shared__ char smem_[]; - - float *qk_smem = reinterpret_cast(smem_); - - char *logits_smem_ = smem_; - // fp32 accum for logits - float *logits_smem = reinterpret_cast(logits_smem_); - - T *out_smem = reinterpret_cast(smem_); - - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - using Qk_vec = typename Qk_vec_::Type; - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - const int bi = blockIdx.y; - const int hi = blockIdx.x; - const int bhi = bi * params.num_head + hi; - const int tid = threadIdx.x; - - float qk_max = -FLT_MAX; - float qk = 0; - - // qkv [B, S=1, 3, num_head, head_dim] - int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; - - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // Use block reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // cache_k, [B, num_head, head_dim / x, max_seq_len, x] - // x == 4/8 for FP32/FP16, 128bit, 16Byte - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - const T *q_base = params.qkv; - const T *k_base = params.qkv + params.num_head * Dh; - const T *q_bias_base = params.qkv_bias; - const T *k_bias_base = params.qkv_bias + params.num_head * Dh; - - if (tid < QK_VECS_PER_WARP) { - int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; - int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; - - Qk_vec q; - zero(q); - q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_base[qk_offset]) - : q; - Qk_vec k; - zero(k); - k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_base[qk_offset]) - : k; - - Qk_vec q_bias; - zero(q_bias); - q_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) - : q_bias; - Qk_vec k_bias; - zero(k_bias); - k_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) - : k_bias; - - q = add(q, q_bias); - // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 - // we may not require k_bias. - k = add(k, k_bias); - - *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; - - int co = tid / QK_VECS_IN_16B; - int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; - int offset = bhi * params.max_seq_length * Dh + - co * params.max_seq_length * QK_ELTS_IN_16B + - params.timestep * QK_ELTS_IN_16B + ci; - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.cache_kv[offset]) = k; - } - - qk = dot(q, k); - - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = - (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - if (tid == 0) { - // NOTE(wangxi): mask must be 0.0 - // T mask = params.attn_mask[ - // bi * (params.timestep + 1) + params.timestep]; - // qk += static_cast(mask); - qk *= params.inv_sqrt_dh; - qk_max = qk; - qk_smem[params.timestep] = qk; - } - __syncthreads(); - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("=======q_out=======\n"); - for (int i = 0; i < Dh; ++i) printf("%f ", static_cast(q_smem[i])); - printf("\n"); - } - __syncthreads(); -#endif - - using K_vec = typename K_vec_::Type; - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - int ko = tid / THREADS_PER_KEY; - int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); - - K_vec q[K_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < K_VECS_PER_THREAD; ++i) { - q[i] = *reinterpret_cast( - &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); - } - - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; - int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_seq_length + ti; - if (ti < params.timestep) { - k[ii] = - (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) - ? *reinterpret_cast( - &k_cache[jj * QK_ELTS_IN_16B]) - : k_vec_zero; - } - } - - // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) - // may overflow with FP16 in large model. - float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); - - // bool is_mask = false; - if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { - // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; - qk += static_cast(mask); - qk_max = fmaxf(qk_max, qk); - - qk_smem[ti] = qk; - } - } - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - const int warp = tid / WARP_SIZE; - const int lane = tid % WARP_SIZE; - - if (lane == 0) { - red_smem[warp] = qk_max; - } - - __syncthreads(); - - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("=======qk_out=======\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); - printf("qk_max=%f\n", qk_max); - } - __syncthreads(); -#endif - - float sum = 0.f; - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { - // bool is_mask = false; - // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); - float logit = __expf(qk_smem[ti] - qk_max); - sum += logit; - qk_smem[ti] = logit; - } - - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // FIXME(wangxi): need add 1.e-6f? - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { - convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); - } - __syncthreads(); - - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - using V_vec = typename V_vec_::Type; - - int vo = tid / THREADS_PER_VALUE; - int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; - - T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * - params.max_seq_length * Dh + - bhi * params.max_seq_length * Dh + vi]; - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - - V_vec_acum out; - zero(out); - - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { - V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti]; - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("======logits_out=====\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); - printf("\n"); - } - __syncthreads(); -#endif - - V_vec v_bias; - zero(v_bias); - if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { - V_vec v = *reinterpret_cast( - ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); - v_bias = *reinterpret_cast( - ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); - v = add(v, v_bias); - *reinterpret_cast(&v_cache[params.timestep * Dh]) = v; - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - out = fma(logits_smem[params.timestep], cast_to_float(v), out); -#else - out = fma(logits_smem[params.timestep], v, out); -#endif - } - - __syncthreads(); - - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; - active_groups /= 2) { - int midpoint = active_groups / 2; - - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float( - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), - out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = - add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), - out); -#else - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; -#endif - } - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - __syncthreads(); - if (bi == 0 && hi == 0 && tid == 0) { - printf("======fmha_out=====\n"); - for (int i = 0; i < Dh; ++i) - printf("%f ", static_cast(params.out[i])); - printf("\n"); - } -#endif -#else - assert(false); -#endif -} - -template -inline size_t smem_size_in_bytes( - const Masked_multihead_attention_params ¶ms, - int dim_head, - int threads_per_value, - int threads_per_block) { - size_t qk_sz = div_up(params.timestep + 1, 4) * 16; - size_t logits_sz = 0; - -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T); - } -#endif - size_t softmax_sz = qk_sz + logits_sz; - - int rows_per_red = threads_per_block / threads_per_value; - size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; - - return max(softmax_sz, red_sz); -} - -#define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, - const cudaStream_t &stream) { - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - if (params.timestep < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); - } else if (params.timestep < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); - } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); - } -} - -template -void fmha(const phi::GPUContext &dev_ctx, - const Tensor &qkv_tensor, - const Tensor &qkv_bias_tensor, - const Tensor &src_mask_tensor, - Tensor *cache_kv_tensor, - Tensor *out_tensor, - int batch_size, - int max_seq_length, - int num_head, - int dim_head, - int timestep, - float inv_sqrt_dh) { - Masked_multihead_attention_params params; - params.out = out_tensor->data(); - params.qkv = qkv_tensor.data(); - params.qkv_bias = qkv_bias_tensor.data(); - params.attn_mask = src_mask_tensor.data(); - params.cache_kv = cache_kv_tensor->data(); - - params.batch_size = batch_size; - params.num_head = num_head; - params.timestep = timestep; - params.max_seq_length = max_seq_length; - params.inv_sqrt_dh = inv_sqrt_dh; - - switch (dim_head) { - case 10: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 26: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 32: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 64: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 96: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 128: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 192: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Dim_head = %d is unsupport!", dim_head)); - } -} - -// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 -constexpr int VEC_16B = 16; - -template -__global__ void write_cache_k_kernel(T *cache_k, - const T *k, - const int num_head, - const int dim_head, - const int seq_len, - const int max_seq_len) { - const int bi = blockIdx.y; - const int hi = blockIdx.z; - constexpr int X_ELEMS = VEC_16B / sizeof(T); - - // [bsz, num_head, seq_len, dim_head/x, x] - auto k_src = reinterpret_cast( - k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); - // [bsz, num_head, dim_head/x, max_seq_len, x] - auto k_dst = reinterpret_cast( - cache_k + bi * num_head * max_seq_len * dim_head + - hi * max_seq_len * dim_head); - - const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; - // vec size - int dim_head_div_x = dim_head / X_ELEMS; - - // FIXME(wangxi): num_head is not need? - // if (out_idx >= num_head * dim_head_div_x * max_seq_len) return; - if (out_idx >= dim_head_div_x * max_seq_len) return; - - int idx = out_idx; - const int k_seq_len_id = idx % max_seq_len; - // idx = (idx - k_seq_len_id) / max_seq_len; - idx = idx / max_seq_len; - const int k_vec_id = idx % dim_head_div_x; - - if (k_seq_len_id < seq_len) { - k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id]; - } -} - -template -__global__ void write_cache_v_kernel(T *cache_v, - const T *v, - const int num_head, - const int dim_head, - const int seq_len, - const int max_seq_len) { - const int bi = blockIdx.y; - const int hi = blockIdx.z; - - // [bsz, num_head, seq_len, dim_head/x, x] - auto v_src = reinterpret_cast( - v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); - // [bsz, num_head, max_seq_len, dim_head/x, x] - auto v_dst = reinterpret_cast( - cache_v + bi * num_head * max_seq_len * dim_head + - hi * max_seq_len * dim_head); - - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - constexpr int X_ELEMS = VEC_16B / sizeof(T); - const int dim_head_div_x = dim_head / X_ELEMS; - - if (idx >= dim_head_div_x * seq_len) return; - - v_dst[idx] = v_src[idx]; -} - -template -void write_cache_kv(const phi::GPUContext &dev_ctx, - T *cache_k, - T *cache_v, - const T *k, - const T *v, - const int bsz, - const int num_head, - const int seq_len, - const int max_seq_len, - const int dim_head) { - constexpr int block_sz = 128; - constexpr int x = VEC_16B / sizeof(T); - - assert(dim_head % x == 0); - PADDLE_ENFORCE_EQ( - dim_head % x, - 0, - platform::errors::PreconditionNotMet( - "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); - - int max_size = max_seq_len * dim_head / x; - int size = seq_len * dim_head / x; - dim3 grid(div_up(max_size, block_sz), bsz, num_head); - dim3 grid_v(div_up(size, block_sz), bsz, num_head); - - // transpose [bsz, num_head, seq_len, dim_head/x, x]-> - // [bsz, num_head, dim_head/x, max_seq_len, x] - write_cache_k_kernel<<>>( - cache_k, k, num_head, dim_head, seq_len, max_seq_len); - - // copy [bsz, num_head, seq_len, dim_head/x, x]-> - // [bsz, num_head, max_seq_len, dim_head/x, x] - write_cache_v_kernel<<>>( - cache_v, v, num_head, dim_head, seq_len, max_seq_len); -} - -} // namespace - template class FusedMultiTransformerOpKernel : public framework::OpKernel { public: @@ -1480,11 +338,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { if (pre_layer_norm) { out_linear_compute.ComputeForward( out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, dev_ctx); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); } else { out_linear_compute.ComputeForward( out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, dev_ctx); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step4"; @@ -1563,9 +421,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { #endif if (pre_layer_norm) { - AllReduce(*buf1, ring_id, dev_ctx); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); } else { - AllReduce(*buf0, ring_id, dev_ctx); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.1"; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.h new file mode 100644 index 0000000000..761a31ce09 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.h @@ -0,0 +1,1161 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +// This file has been adapted from FasterTransformer file: +// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu +// We add License in the head. + +#include +#include + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/ProcessGroup.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +// for debug +// #define _DEBUG_FUSED_MULTI_TRANSFORMER + +template +static void AllReduce(framework::Tensor &tensor, // NOLINT + const int ring_id, + const int count, + const phi::GPUContext &ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup *pg = map->get(ring_id); + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(tensor); + paddle::distributed::AllreduceOptions opts; + opts.reduce_op = distributed::ReduceOp::SUM; + auto task = pg->AllReduce(in_tensor, out_tensor, opts); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + const void *sendbuff = tensor.data(); + auto place = ctx.GetPlace(); + void *recvbuff = tensor.mutable_data(place); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +namespace { // NOLINT + +namespace plat = paddle::platform; +using float16 = plat::float16; + +#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#define MMHA_USE_FP32_ACUM_FOR_OUT + +template +struct Masked_multihead_attention_params { + // output buffer, [B, 1(seq_len), num_head * dim_head] + T *out; + // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] + const T *qkv; + // bias, [3, num_head, dim_head] + const T *qkv_bias; + // TODO(wangxi): optimize with input_lengths and max_input_len? + // [bsz, 1, 1, time_step(cache_seq_length)+1] + const T *attn_mask; + + // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] + // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first + // v [B, num_head, max_seq_len, dim_head] + T *cache_kv; + + int batch_size; + int num_head; + int timestep; // cache_seq_length + int max_seq_length; + + // 1.f / sqrt(Dh) + float inv_sqrt_dh; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// clang-format off + +template struct Qk_vec_ {}; +template <> struct Qk_vec_ { using Type = float; }; +template <> struct Qk_vec_ { using Type = float2; }; +template <> struct Qk_vec_ { using Type = float4; }; +template <> struct Qk_vec_ { using Type = float4; }; +template <> struct Qk_vec_ { using Type = uint32_t; }; +template <> struct Qk_vec_ { using Type = uint32_t; }; +template <> struct Qk_vec_ { using Type = uint2; }; +template <> struct Qk_vec_ { using Type = uint4; }; + +template struct K_vec_ {}; +template <> struct K_vec_ { using Type = float; }; +template <> struct K_vec_ { using Type = float2; }; +template <> struct K_vec_ { using Type = float4; }; +template <> struct K_vec_ { using Type = uint32_t; }; +template <> struct K_vec_ { using Type = uint2; }; +template <> struct K_vec_ { using Type = uint4; }; + +template struct V_vec_ {}; +template <> struct V_vec_ { using Type = float; }; +template <> struct V_vec_ { using Type = float2; }; +template <> struct V_vec_ { using Type = float4; }; +template <> struct V_vec_ { using Type = uint32_t; }; +template <> struct V_vec_ { using Type = uint2; }; +template <> struct V_vec_ { using Type = uint4; }; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template struct V_vec_acum_fp32_ {}; +// template <> struct V_vec_acum_fp32_ { using Type = float; }; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +template <> struct V_vec_acum_fp32_ { using Type = float4; }; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; +template <> struct V_vec_acum_fp32_ { using Type = Float8_; }; +#endif + +// clang-format on + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +template +inline __device__ Acc mul(A a, B b); + +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 tmp_res; + tmp_res.x = tmp.x * b; + tmp_res.y = tmp.y * b; + uint32_t res = float2_to_half2(tmp_res); + return res; +} + +template <> +inline __device__ uint2 mul(uint2 a, float b) { + uint2 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + return res; +} + +template <> +inline __device__ uint4 mul(uint4 a, float b) { + uint4 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + res.z = mul(a.z, b); + res.w = mul(a.w, b); + return res; +} + +template <> +inline __device__ float2 mul(float2 a, float b) { + float2 res; + res.x = a.x * b; + res.y = a.y * b; + return res; +} + +template <> +inline __device__ float4 mul(float4 a, float b) { + float4 res; + res.x = a.x * b; + res.y = a.y * b; + res.z = a.z * b; + res.w = a.w * b; + return res; +} + +inline __device__ float sum(float v) { return v; } +inline __device__ float sum(float2 v) { return v.x + v.y; } +inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } +inline __device__ float sum(uint16_t v) { return half_to_float(v); } +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +inline __device__ constexpr uint32_t shfl_mask(int threads) { + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float cast_to_float(float u) { return u; } + +inline __device__ float2 cast_to_float(float2 u) { return u; } + +inline __device__ float4 cast_to_float(float4 u) { return u; } + +inline __device__ Float8_ cast_to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + K_vec inv_q = mul(q[0], inv_sqrt_dh); + K_vec qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } + + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + return qk_dot_(q, k, inv_sqrt_dh); + } +}; + +template +inline __device__ float block_sum(float *red_smem, float sum) { + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + if (lane == 0) { + red_smem[warp] = sum; + } + __syncthreads(); + + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + return __shfl_sync(uint32_t(-1), sum, 0); +} + +inline __device__ void convert_from_float(float &dst, float src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT + float src) { + dst = static_cast(src); +} + +inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT + +template +inline __device__ void zero(T &dst) { // NOLINT + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +template +__global__ void masked_multihead_attention_kernel( + Masked_multihead_attention_params params) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + constexpr int WARP_SIZE = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + + char *logits_smem_ = smem_; + // fp32 accum for logits + float *logits_smem = reinterpret_cast(logits_smem_); + + T *out_smem = reinterpret_cast(smem_); + + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + using Qk_vec = typename Qk_vec_::Type; + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + const int bi = blockIdx.y; + const int hi = blockIdx.x; + const int bhi = bi * params.num_head + hi; + const int tid = threadIdx.x; + + float qk_max = -FLT_MAX; + float qk = 0; + + // qkv [B, S=1, 3, num_head, head_dim] + int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; + + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // Use block reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // cache_k, [B, num_head, head_dim / x, max_seq_len, x] + // x == 4/8 for FP32/FP16, 128bit, 16Byte + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + const T *q_base = params.qkv; + const T *k_base = params.qkv + params.num_head * Dh; + const T *q_bias_base = params.qkv_bias; + const T *k_bias_base = params.qkv_bias + params.num_head * Dh; + + if (tid < QK_VECS_PER_WARP) { + int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; + int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + + Qk_vec q; + zero(q); + q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_base[qk_offset]) + : q; + Qk_vec k; + zero(k); + k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_base[qk_offset]) + : k; + + Qk_vec q_bias; + zero(q_bias); + q_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) + : q_bias; + Qk_vec k_bias; + zero(k_bias); + k_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) + : k_bias; + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; + + int co = tid / QK_VECS_IN_16B; + int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; + int offset = bhi * params.max_seq_length * Dh + + co * params.max_seq_length * QK_ELTS_IN_16B + + params.timestep * QK_ELTS_IN_16B + ci; + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.cache_kv[offset]) = k; + } + + qk = dot(q, k); + + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = + (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + if (tid == 0) { + // NOTE(wangxi): mask must be 0.0 + // T mask = params.attn_mask[ + // bi * (params.timestep + 1) + params.timestep]; + // qk += static_cast(mask); + qk *= params.inv_sqrt_dh; + qk_max = qk; + qk_smem[params.timestep] = qk; + } + __syncthreads(); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("=======q_out=======\n"); + for (int i = 0; i < Dh; ++i) printf("%f ", static_cast(q_smem[i])); + printf("\n"); + } + __syncthreads(); +#endif + + using K_vec = typename K_vec_::Type; + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + int ko = tid / THREADS_PER_KEY; + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); + + K_vec q[K_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < K_VECS_PER_THREAD; ++i) { + q[i] = *reinterpret_cast( + &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); + } + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_seq_length + ti; + if (ti < params.timestep) { + k[ii] = + (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + ? *reinterpret_cast( + &k_cache[jj * QK_ELTS_IN_16B]) + : k_vec_zero; + } + } + + // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) + // may overflow with FP16 in large model. + float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); + + // bool is_mask = false; + if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { + // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; + qk += static_cast(mask); + qk_max = fmaxf(qk_max, qk); + + qk_smem[ti] = qk; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + const int warp = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + + if (lane == 0) { + red_smem[warp] = qk_max; + } + + __syncthreads(); + + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("=======qk_out=======\n"); + for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); + printf("qk_max=%f\n", qk_max); + } + __syncthreads(); +#endif + + float sum = 0.f; + for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + // bool is_mask = false; + // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); + float logit = __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } + + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // FIXME(wangxi): need add 1.e-6f? + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + __syncthreads(); + + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + using V_vec = typename V_vec_::Type; + + int vo = tid / THREADS_PER_VALUE; + int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; + + T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bhi * params.max_seq_length * Dh + vi]; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + + V_vec_acum out; + zero(out); + + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { + V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti]; + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + } + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("======logits_out=====\n"); + for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); + printf("\n"); + } + __syncthreads(); +#endif + + V_vec v_bias; + zero(v_bias); + if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { + V_vec v = *reinterpret_cast( + ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); + v_bias = *reinterpret_cast( + ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + v = add(v, v_bias); + *reinterpret_cast(&v_cache[params.timestep * Dh]) = v; + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + out = fma(logits_smem[params.timestep], cast_to_float(v), out); +#else + out = fma(logits_smem[params.timestep], v, out); +#endif + } + + __syncthreads(); + + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + int midpoint = active_groups / 2; + + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float( + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), + out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = + add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), + out); +#else + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + __syncthreads(); + if (bi == 0 && hi == 0 && tid == 0) { + printf("======fmha_out=====\n"); + for (int i = 0; i < Dh; ++i) + printf("%f ", static_cast(params.out[i])); + printf("\n"); + } +#endif +#else + assert(false); +#endif +} + +template +inline size_t smem_size_in_bytes( + const Masked_multihead_attention_params ¶ms, + int dim_head, + int threads_per_value, + int threads_per_block) { + size_t qk_sz = div_up(params.timestep + 1, 4) * 16; + size_t logits_sz = 0; + +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS // NOLINT + if (sizeof(T) != 4) { + logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T); + } +#endif // NOLINT + size_t softmax_sz = qk_sz + logits_sz; + + int rows_per_red = threads_per_block / threads_per_value; + size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; + + return max(softmax_sz, red_sz); +} + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_head, params.batch_size); \ + masked_multihead_attention_kernel \ + <<>>(params) + +template +void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream) { + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + if (params.timestep < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); + } else if (params.timestep < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); + } else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + } +} + +template +void fmha(const phi::GPUContext &dev_ctx, + const Tensor &qkv_tensor, + const Tensor &qkv_bias_tensor, + const Tensor &src_mask_tensor, + Tensor *cache_kv_tensor, + Tensor *out_tensor, + int batch_size, + int max_seq_length, + int num_head, + int dim_head, + int timestep, + float inv_sqrt_dh) { + Masked_multihead_attention_params params; + params.out = out_tensor->data(); + params.qkv = qkv_tensor.data(); + params.qkv_bias = qkv_bias_tensor.data(); + params.attn_mask = src_mask_tensor.data(); + params.cache_kv = cache_kv_tensor->data(); + + params.batch_size = batch_size; + params.num_head = num_head; + params.timestep = timestep; + params.max_seq_length = max_seq_length; + params.inv_sqrt_dh = inv_sqrt_dh; + + switch (dim_head) { + case 10: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 26: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 32: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 64: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 96: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 128: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 192: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Dim_head = %d is unsupport!", dim_head)); + } +} + +// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 +constexpr int VEC_16B = 16; + +template +__global__ void write_cache_k_kernel(T *cache_k, + const T *k, + const int num_head, + const int dim_head, + const int seq_len, + const int max_seq_len) { + const int bi = blockIdx.y; + const int hi = blockIdx.z; + constexpr int X_ELEMS = VEC_16B / sizeof(T); + + // [bsz, num_head, seq_len, dim_head/x, x] + auto k_src = reinterpret_cast( + k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + // [bsz, num_head, dim_head/x, max_seq_len, x] + auto k_dst = reinterpret_cast( + cache_k + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + // vec size + int dim_head_div_x = dim_head / X_ELEMS; + + // FIXME(wangxi): num_head is not need? + // if (out_idx >= num_head * dim_head_div_x * max_seq_len) return; + if (out_idx >= dim_head_div_x * max_seq_len) return; + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + // idx = (idx - k_seq_len_id) / max_seq_len; + idx = idx / max_seq_len; + const int k_vec_id = idx % dim_head_div_x; + + if (k_seq_len_id < seq_len) { + k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id]; + } +} + +template +__global__ void write_cache_v_kernel(T *cache_v, + const T *v, + const int num_head, + const int dim_head, + const int seq_len, + const int max_seq_len) { + const int bi = blockIdx.y; + const int hi = blockIdx.z; + + // [bsz, num_head, seq_len, dim_head/x, x] + auto v_src = reinterpret_cast( + v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + // [bsz, num_head, max_seq_len, dim_head/x, x] + auto v_dst = reinterpret_cast( + cache_v + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + constexpr int X_ELEMS = VEC_16B / sizeof(T); + const int dim_head_div_x = dim_head / X_ELEMS; + + if (idx >= dim_head_div_x * seq_len) return; + + v_dst[idx] = v_src[idx]; +} + +template +void write_cache_kv(const phi::GPUContext &dev_ctx, + T *cache_k, + T *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head) { + constexpr int block_sz = 128; + constexpr int x = VEC_16B / sizeof(T); + + assert(dim_head % x == 0); + PADDLE_ENFORCE_EQ( + dim_head % x, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); + + int max_size = max_seq_len * dim_head / x; + int size = seq_len * dim_head / x; + dim3 grid(div_up(max_size, block_sz), bsz, num_head); + dim3 grid_v(div_up(size, block_sz), bsz, num_head); + + // transpose [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, dim_head/x, max_seq_len, x] + write_cache_k_kernel<<>>( + cache_k, k, num_head, dim_head, seq_len, max_seq_len); + + // copy [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, max_seq_len, dim_head/x, x] + write_cache_v_kernel<<>>( + cache_v, v, num_head, dim_head, seq_len, max_seq_len); +} + +} // namespace + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index c1131cae5d..f162d200ab 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -28,7 +28,9 @@ template + typename Functor, + typename InType = T, + typename OutType = T> __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, @@ -36,30 +38,45 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, - const T *__restrict__ src, + const InType *__restrict__ src, const T *__restrict__ residual, const T *__restrict__ bias, - T *dst, + OutType *dst, MaskType *mask, const bool is_test, typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val, - Functor act_func) { + Functor act_func, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { using LoadT = phi::AlignedVector; + using LoadInType = phi::AlignedVector; + using LoadFloat = phi::AlignedVector; using StoreT = phi::AlignedVector; + using StoreOutType = phi::AlignedVector; + using MaskStoreT = phi::AlignedVector; using U = typename details::MPTypeTrait::Type; - LoadT src_vec; + LoadInType src_vec; LoadT residual_vec; LoadT bias_vec; + LoadFloat quant_out_scale_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { bias_vec[ii] = static_cast(0); residual_vec[ii] = static_cast(0); } // vectorize load data from global - phi::Load(&src[row_id * cols + col_id], &src_vec); + phi::Load(&src[row_id * cols + col_id], &src_vec); + phi::Load( + &dequant_out_scale_data[quant_out_scale_offset + col_id], + &quant_out_scale_vec); if (residual) { phi::Load(&residual[row_id * cols + col_id], &residual_vec); } @@ -84,10 +101,18 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( } StoreT dest_vec; + StoreOutType dest_vec_out_type; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { - T tmp = src_vec[ii] + bias_vec[ii]; + T tmp; + if (std::is_same::value) { + T tmp0 = static_cast(static_cast(src_vec[ii]) * + quant_last_in_scale / quant_out_scale_vec[ii]); + tmp = tmp0 + bias_vec[ii]; + } else { + tmp = static_cast(src_vec[ii]) + bias_vec[ii]; + } if (Activation) { tmp = act_func(tmp); } @@ -98,10 +123,23 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( *mean_val += tmp; *var_val += (tmp * tmp); } + if (std::is_same::value) { + dest_vec_out_type[ii] = quant_helper(dest_vec[ii], + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } } // store result to global - phi::Store(dest_vec, &dst[row_id * cols + col_id]); + if (std::is_same::value) { + phi::Store(dest_vec_out_type, + &dst[row_id * cols + col_id]); + } else { + phi::Store(dest_vec, + reinterpret_cast(&dst[row_id * cols + col_id])); + } if (!is_test) { phi::Store(mask_vec, &mask[row_id * cols + col_id]); } @@ -114,19 +152,28 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( * is_test: only used in inference * mask: can be null if is_test=true */ -template -__global__ void FusedResidualDropoutBias(const size_t rows, - const size_t cols, - uint64_t seed, - const float dropout_prob, - const bool is_upscale_in_train, - const T *__restrict__ src, - const T *__restrict__ residual, - const T *__restrict__ bias, - MaskType *mask, - T *dst, - uint64_t increment, - const bool is_test) { +template +__global__ void FusedResidualDropoutBias( + const size_t rows, + const size_t cols, + uint64_t seed, + const float dropout_prob, + const bool is_upscale_in_train, + const InType *__restrict__ src, + const T *__restrict__ residual, + const T *__restrict__ bias, + MaskType *mask, + OutType *dst, + uint64_t increment, + const bool is_test, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -142,22 +189,27 @@ __global__ void FusedResidualDropoutBias(const size_t rows, VecSize, false, false, - phi::funcs::ReluFunctor>( - r, - i, - cols, - &state, - dropout_prob, - factor, - src, - residual, - bias, - dst, - mask, - is_test, - nullptr, - nullptr, - relu); + phi::funcs::ReluFunctor, + InType, + OutType>(r, + i, + cols, + &state, + dropout_prob, + factor, + src, + residual, + bias, + dst, + mask, + is_test, + nullptr, + nullptr, + relu, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } } } @@ -165,7 +217,10 @@ __global__ void FusedResidualDropoutBias(const size_t rows, /** * @brief dst = residual + dropout(src + bias); */ -template +template void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, const int increment, @@ -173,14 +228,19 @@ void LaunchResidualDropoutBias(const uint32_t rows, const float dropout_prob, const bool is_test, bool is_upscale_in_train, - const T *src, + const InType *src, const T *residual, const T *bias, MaskType *mask_data, - T *dst, - const phi::GPUContext &ctx) { + OutType *dst, + const phi::GPUContext &ctx, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_next_in_scale = 1.0) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { + // NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0 if (residual == dst) return; if (residual) { memory::Copy(ctx.GetPlace(), @@ -202,7 +262,7 @@ void LaunchResidualDropoutBias(const uint32_t rows, const int real_vec_size = cols % VecSize == 0 ? VecSize : 1; auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); if (cols % VecSize == 0) { - FusedResidualDropoutBias + FusedResidualDropoutBias <<>>( rows, cols, @@ -215,9 +275,13 @@ void LaunchResidualDropoutBias(const uint32_t rows, mask_data, dst, increment, - is_test); + is_test, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } else { - FusedResidualDropoutBias + FusedResidualDropoutBias <<>>( rows, cols, @@ -230,7 +294,11 @@ void LaunchResidualDropoutBias(const uint32_t rows, mask_data, dst, increment, - is_test); + is_test, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale); } } diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h new file mode 100644 index 0000000000..21b7b0f345 --- /dev/null +++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/operators/fake_quantize_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +__forceinline__ __device__ int8_t quant_helper(const T input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * inverse(scale) * static_cast(input); + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + return static_cast(quant_value); +} + +template +__global__ void quantize_kernel(const T* input, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char4 tmp; + tmp.x = quant_helper( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + tmp.z = quant_helper( + input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); + tmp.w = quant_helper( + input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); + output[(m_id * n + n_id) >> 2] = tmp; + } +} + +template +void quantize_kernel_launcher(const T* input, + int8_t* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound, + gpuStream_t stream) { + // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 + dim3 grid((n + 31) / 32, (m + 31) / 32); + dim3 block(32, 32); + + quantize_kernel<<>>(input, + (char4*)output, // NOLINT + scale, + m, + n, + round_type, + max_bound, + min_bound); +} + +// dequantize using weight scales and input scales +template +__global__ void dequantize_kernel(T* output, + const int32_t* input, + const int m, // hidden + const int n, // batch size + const float quant_in_scale, + const float* dequant_out_scale_data, + const int quant_out_scale_offset) { + int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden + int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id]; + output[n_id * m + m_id] = + static_cast(static_cast(input[n_id * m + m_id]) * + quant_in_scale / out_scale); + } +} + +template +void dequantize_kernel_launcher(const int32_t* input, + T* output, + const int batch_size, // m + const int hidden_units, // n + gpuStream_t stream, + const float quant_in_scale, + const float* dequant_out_scale_data, + const int quant_out_scale_offset) { + dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32); + dim3 block(32, 32); + + dequantize_kernel<<>>(output, + input, + hidden_units, + batch_size, + quant_in_scale, + dequant_out_scale_data, + quant_out_scale_offset); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 0c41429c61..899eae3efb 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -24,6 +24,7 @@ namespace cub = hipcub; #include +#include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/core/ddim.h" @@ -338,16 +339,24 @@ using LayerNormScaleBiasT = template + bool ScaleBiasWithSameTypeX = false, + typename InType = T, + typename OutType = T> __global__ void LayerNormForward( - const T *x, + const InType *x, const LayerNormScaleBiasT *scale, const LayerNormScaleBiasT *bias, - T *y, + OutType *y, U *mean, U *var, float epsilon, - int64_t feature_size) { + int64_t feature_size, + const float *dequant_out_scale_data = nullptr, + const int quant_out_scale_offset = 0, + const float quant_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { __shared__ U mean_share; __shared__ U var_share; __shared__ U shared_mean[32]; // threadIdx.x / warpSize <= kMaxBlockDim / @@ -387,28 +396,72 @@ __global__ void LayerNormForward( if (bias != nullptr) { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast(static_cast(scale[j]) * - (static_cast(x[i]) - mean_val) * invvar + - static_cast(bias[j])); + if (std::is_same::value) { + y[i] = quant_helper( + static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * invvar + + static_cast(bias[j])), + quant_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + y[i] = static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * + invvar + + static_cast(bias[j])); + } } } else { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast(static_cast(scale[j]) * - (static_cast(x[i]) - mean_val) * invvar); + if (std::is_same::value) { + y[i] = quant_helper( + static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * invvar), + quant_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + y[i] = + static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * invvar); + } } } } else { // scale == nullptr if (bias != nullptr) { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar + - static_cast(bias[j])); + if (std::is_same::value) { + y[i] = quant_helper( + static_cast((static_cast(x[i]) - mean_val) * invvar + + static_cast(bias[j])), + quant_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + y[i] = + static_cast((static_cast(x[i]) - mean_val) * invvar + + static_cast(bias[j])); + } } } else { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar); + if (std::is_same::value) { + y[i] = quant_helper( + static_cast((static_cast(x[i]) - mean_val) * invvar), + quant_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + y[i] = + static_cast((static_cast(x[i]) - mean_val) * invvar); + } } } } diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index 3a1d28072c..c3425ac604 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -40,26 +40,28 @@ namespace dynload { // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); \ + __macro(cublasLtMatmulAlgoInit); \ + __macro(cublasLtMatmulAlgoConfigSetAttribute); CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index ba0f872cb7..af080bd0b3 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,6 +71,12 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_multi_transformer_int8", + {"X", "LnScale", "LnBias", "QKVW", + "QKVBias", "CacheKV", "TimeStep", "SrcMask", + "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias", + "FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias", + "QKVOutScale", "OutLinearOutScale", "FFN1OutScale", "FFN2OutScale"}}, {"fused_bias_dropout_residual_layer_norm", {"X", "Residual", "Bias", "LnScale", "LnBias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, @@ -329,6 +335,7 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", "Var1Out", "Conv2", "SavedMean2", "SavedInvstd2", "Mean2Out", @@ -433,6 +440,7 @@ std::map> op_passing_outs_map = { {"split", {"Out"}}, {"concat", {"Out"}}, {"fused_multi_transformer", {"CacheKVOut"}}, + {"fused_multi_transformer_int8", {"CacheKVOut"}}, {"group_norm", {"Mean", "Variance"}}, {"resnet_basic_block", {"Mean1Out", "Var1Out", "Mean2Out", "Var2Out", "Mean3Out", "Var3Out"}}, diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 1e2a20ebdf..90492ff4ba 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -54,26 +54,28 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); \ + __macro(cublasLtMatmulAlgoInit); \ + __macro(cublasLtMatmulAlgoConfigSetAttribute); CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index 36a7869595..b804e93058 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -326,7 +326,7 @@ void* GetCublasDsoHandle() { void* GetCublasLtDsoHandle() { // APIs available after CUDA 10.1 -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10100 +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10010 return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so"); #else std::string warning_msg( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index bdd6d375bf..1daf55e630 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -72,6 +72,7 @@ if(NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS test_fused_attention_op) list(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op) + list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) @@ -141,6 +142,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_complex_matmul) list(REMOVE_ITEM TEST_OPS test_ops_nms) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias) + list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) endif() list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) @@ -1202,6 +1204,10 @@ endif() if(WITH_GPU OR WITH_ROCM) set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120) endif() +if(WITH_GPU AND NOT WIN32) + set_tests_properties(test_fused_multi_transformer_int8_op PROPERTIES TIMEOUT + 60) +endif() set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120) set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) set_tests_properties( diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py new file mode 100644 index 0000000000..00f25b4570 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py @@ -0,0 +1,792 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +import paddle.incubate.nn.functional as incubate_f +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle import tensor +from paddle.fluid import layers +import unittest +from op_test import OpTest +from paddle.fluid.framework import default_main_program +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid.layer_helper import LayerHelper +from paddle.nn.initializer import Constant +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid.framework import _non_static_mode, default_main_program +from paddle import _legacy_C_ops + +default_main_program().random_seed = 42 +np.random.seed(0) + + +def fused_multi_transformer_int8( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=True, + epsilon=1e-05, + cache_kvs=None, + time_step=None, + attn_mask=None, + dropout_rate=0.0, + activation="gelu", + training=False, + mode='upscale_in_train', + trans_qkvw=True, + ring_id=-1, + name=None, + qkv_out_scales=None, + out_linear_out_scales=None, + ffn1_out_scales=None, + ffn2_out_scales=None, + num_head=0, + dim_head=0, + dim_ffn=0, + qkv_in_scale=[], + out_linear_in_scale=[], + ffn1_in_scale=[], + ffn2_in_scale=[], +): + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_int8( + x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs, time_step, + attn_mask, linear_weights, linear_biases, ffn_ln_scales, ffn_ln_biases, + ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, qkv_out_scales, + out_linear_out_scales, ffn1_out_scales, ffn2_out_scales, cache_kvs, + 'num_head', num_head, 'dim_head', dim_head, 'dim_ffn', dim_ffn, + 'qkv_in_scale', qkv_in_scale, 'out_linear_in_scale', + out_linear_in_scale, 'ffn1_in_scale', ffn1_in_scale, 'ffn2_in_scale', + ffn2_in_scale, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_rate', dropout_rate, 'is_test', not training, + 'dropout_implementation', mode, 'act_method', activation, 'trans_qkvw', + trans_qkvw, 'ring_id', ring_id) + if cache_kvs is not None: + return final_out, cache_kv_out + return final_out + + +class TestFusedMultiTransformerInt8Op(unittest.TestCase): + + def setUp(self): + self.config() + self.generate_input_data() + + self.rtol = 1e-5 + # FIXME(wangxi): Because there is a problem with the test precision + # on A100, atol is temporarily set to 1e-2, and it will be + # changed back after the precision problem is solved. + self.atol = 1e-2 + # make sure local development precision + if "V100" in paddle.device.cuda.get_device_name(): + self.atol = 1e-4 + if self.x_type is np.float16: + self.atol = 1e-1 + + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_multi_transformer_int8" + # use autograd to check grad in this unittest. + self.__class__.no_need_check_grad = True + + paddle.set_default_dtype(np.float32) + self.norm = LayerNorm(self.embed_dim, + weight_attr=False, + bias_attr=False) + self.ffn_norm = LayerNorm(self.embed_dim, + weight_attr=False, + bias_attr=False) + + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + self.activation = getattr(F, self.act_method) + + def config(self): + # for debug + self.debug = False + + self.x_type = np.float32 + self.attn_mask_type = np.float64 + #self.attn_mask_type = np.bool + self.pre_layer_norm = True + self.has_attn_mask = True + + # has_cache_kv, gen_cache_kv, stage + # False, False, not generation + # True, True, generation context stage + # True, False, generation decoder stage + self.has_cache_kv = False + self.gen_cache_kv = False + + self.training = False + + self.layers = 3 + self.batch_size = 1 + self.query_length = 1 + self.cache_length = 1 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.act_method = 'gelu' + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + q_weight = np.random.randint(-64, 64, [self.embed_dim, self.embed_dim], + np.int32).astype('float64') + k_weight = np.random.randint(-64, 64, [self.kdim, self.embed_dim], + np.int32).astype('float64') + v_weight = np.random.randint(-64, 64, [self.vdim, self.embed_dim], + np.int32).astype('float64') + + self.q_weight_tensor = paddle.to_tensor(q_weight) + self.k_weight_tensor = paddle.to_tensor(k_weight) + self.v_weight_tensor = paddle.to_tensor(v_weight) + + out_weight = np.random.randint(-64, 64, + [self.embed_dim, self.embed_dim], + np.int32).astype('float64') + ffn1_weight = np.random.randint(-64, 64, + [self.embed_dim, 4 * self.embed_dim], + np.int32).astype('float64') + ffn2_weight = np.random.randint(-64, 64, + [4 * self.embed_dim, self.embed_dim], + np.int32).astype('float64') + + self.out_weight_tensor = paddle.to_tensor(out_weight) + self.ffn1_weight_tensor = paddle.to_tensor(ffn1_weight) + self.ffn2_weight_tensor = paddle.to_tensor(ffn2_weight) + + q_proj_bias = np.random.rand(self.embed_dim).astype(self.x_type) + k_proj_bias = np.random.rand(self.embed_dim).astype(self.x_type) + v_proj_bias = np.random.rand(self.embed_dim).astype(self.x_type) + + self.q_proj_bias_tensor = paddle.to_tensor(q_proj_bias) + self.k_proj_bias_tensor = paddle.to_tensor(k_proj_bias) + self.v_proj_bias_tensor = paddle.to_tensor(v_proj_bias) + + out_linear_proj_bias = np.random.rand(self.embed_dim).astype( + self.x_type) + ffn1_proj_bias = np.random.rand(4 * self.embed_dim).astype(self.x_type) + ffn2_proj_bias = np.random.rand(self.embed_dim).astype(self.x_type) + + self.out_linear_proj_bias_tensor = paddle.to_tensor( + out_linear_proj_bias) + self.ffn1_proj_bias_tensor = paddle.to_tensor(ffn1_proj_bias) + self.ffn2_proj_bias_tensor = paddle.to_tensor(ffn2_proj_bias) + + out_seq_len = self.key_length + + self.qkv_in_scales = [] + self.qkv_out_scales = [] + self.out_linear_in_scales = [] + self.out_linear_out_scales = [] + self.ffn1_in_scales = [] + self.ffn1_out_scales = [] + self.ffn2_in_scales = [] + self.ffn2_out_scales = [] + + if self.has_cache_kv: + self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads, + self.cache_length, + self.head_dim).astype(self.x_type) + + if self.gen_cache_kv: + self.cache_kv[:] = 0 + else: + out_seq_len += self.cache_length + else: + self.cache_kv = None + + if self.has_attn_mask: + # [B, n_head, seq_len, out_seq_len] + self.attn_mask = np.ones( + (self.batch_size, 1, self.query_length, out_seq_len), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + if self.has_cache_kv and not self.gen_cache_kv: + # NOTE: decoder stage, -1(out_seq_len) should no mask + self.attn_mask[:, :, :, -2] = 0.0 + self.attn_mask = (self.attn_mask - 1.0) * 1e4 + else: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e4 + elif self.attn_mask_type == np.bool_: + if self.has_cache_kv and not self.gen_cache_kv: + self.attn_mask[:, :, :, -2] = 0 + else: + self.attn_mask = np.tril(self.attn_mask) + else: + raise ValueError( + "'attn_mask_type' should be 'int64' or 'float64'.") + else: + self.attn_mask = None + + def fake_quant(self, input, scale): + quant_value = 127.0 * (1.0 / scale) * paddle.cast(input, 'float32') + quant_value = paddle.round(quant_value) + + # No need to clip here because scale is the max value + + return paddle.cast(quant_value, 'float64') + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + + cache_kvs = [] + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + for i in range(self.layers): + residual = tensor_query + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm(tensor_query) + max_v = paddle.max(paddle.abs(paddle.cast(ln1_out, 'float32')))[0] + # self.qkv_in_scales.append(127.0 / max_v) + self.qkv_in_scales.append(max_v) + self.qkv_out_scales.append(127.0 * 127.0) + # print('qkv_in_scales ', i, self.qkv_in_scales[i]) + # print('qkv_out_scales ', i, self.qkv_out_scales[i]) + + # quant ln1_out + ln1_out = self.fake_quant(ln1_out, self.qkv_in_scales[i]) + + q = paddle.nn.functional.linear(ln1_out, self.q_weight_tensor) + # de quant + q = paddle.cast( + paddle.cast(q, 'float32') * self.qkv_in_scales[i] / + self.qkv_out_scales[i], self.x_type) + + q = q + self.q_proj_bias_tensor + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = paddle.nn.functional.linear(ln1_out, self.k_weight_tensor) + k = paddle.cast( + paddle.cast(k, 'float32') * self.qkv_in_scales[i] / + self.qkv_out_scales[i], self.x_type) + k = k + self.k_proj_bias_tensor + v = paddle.nn.functional.linear(ln1_out, self.v_weight_tensor) + v = paddle.cast( + paddle.cast(v, 'float32') * self.qkv_in_scales[i] / + self.qkv_out_scales[i], self.x_type) + v = v + self.v_proj_bias_tensor + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + if self.has_cache_kv: + # [1, B, n_head, cache_seq_len, head_dim] + cache_k, cache_v = paddle.split(cache_kv, 2) + cache_k = paddle.squeeze(cache_k, axis=0) + cache_v = paddle.squeeze(cache_v, axis=0) + # [B, n_head, cache_seq_len + seq_len, head_dim] + # out_seq_len = cache_seq_len + seq_len + if self.debug: + print('q out is') + print(q_out[0, 0, :, :]) + print('cache k out seq=128') + print(k_out[0, 0, :, :]) + if self.gen_cache_kv: + cache_kvs.append((k_out, v_out)) + else: + k_out = paddle.concat([cache_k, k_out], axis=-2) + v_out = paddle.concat([cache_v, v_out], axis=-2) + + # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, out_seq_len] + qk_out = layers.matmul(x=q_out, + y=k_out, + transpose_y=True, + alpha=self.head_dim**-0.5) + + if self.debug: + print('qk out is') + print(qk_out[0][0][0]) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + if self.debug: + print('attn mask out is') + print(attn_mask_out[0][0][0]) + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.debug: + print('softmax out is') + print(softmax_out[0][0][0]) + if self.dropout_prob: + dropout_out = F.dropout(softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, head_dim] + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + if self.debug: + print('fmha out is') + print(fmha_out[0][0][0]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + + max_v = paddle.max(paddle.abs(paddle.cast(out_linear_in, + 'float32')))[0] + # self.out_linear_in_scales.append(127.0 / max_v) + + self.out_linear_in_scales.append(max_v) + self.out_linear_out_scales.append((127.0 * 127.0)) + out_linear_in = self.fake_quant(out_linear_in, + self.out_linear_in_scales[i]) + + out = paddle.nn.functional.linear(out_linear_in, + self.out_weight_tensor) + + out = paddle.cast( + paddle.cast(out, 'float32') * self.out_linear_in_scales[i] / + self.out_linear_out_scales[i], self.x_type) + + out = out + self.out_linear_proj_bias_tensor + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + attn_out = self.norm(residual_out) + else: + attn_out = residual_out + + ffn_ln_out = attn_out + if self.pre_layer_norm: + ffn_ln_out = self.ffn_norm(attn_out) + + max_v = paddle.max(paddle.abs(paddle.cast(ffn_ln_out, + 'float32')))[0] + self.ffn1_in_scales.append(max_v) + self.ffn1_out_scales.append((127.0 * 127.0)) + ffn_ln_out = self.fake_quant(ffn_ln_out, self.ffn1_in_scales[i]) + + ffn1_out = paddle.nn.functional.linear(ffn_ln_out, + self.ffn1_weight_tensor) + + ffn1_out = paddle.cast( + paddle.cast(ffn1_out, 'float32') * self.ffn1_in_scales[i] / + self.ffn1_out_scales[i], self.x_type) + + ffn1_out = ffn1_out + self.ffn1_proj_bias_tensor + ffn1_out = self.dropout(self.activation(ffn1_out)) + + max_v = paddle.max(paddle.abs(paddle.cast(ffn1_out, 'float32')))[0] + # self.ffn2_in_scales.append(127.0 / max_v) + self.ffn2_in_scales.append(max_v) + self.ffn2_out_scales.append((127.0 * 127.0)) + # print('ffn2_in_scales ', i, self.ffn2_in_scales[i]) + ffn1_out = self.fake_quant(ffn1_out, self.ffn2_in_scales[i]) + + ffn2_out = paddle.nn.functional.linear(ffn1_out, + self.ffn2_weight_tensor) + + ffn2_out = paddle.cast( + paddle.cast(ffn2_out, 'float32') * self.ffn2_in_scales[i] / + self.ffn2_out_scales[i], self.x_type) + ffn2_out = ffn2_out + self.ffn2_proj_bias_tensor + + residual_out = attn_out + self.dropout(ffn2_out) + # print("residual ", attn_out) + # print("residual_out ", residual_out) + final_out = residual_out + if not self.pre_layer_norm: + final_out = self.ffn_norm(residual_out) + + tensor_query = final_out + + if self.has_cache_kv and self.gen_cache_kv: + return final_out, cache_kvs + return final_out + + def GetFusedMultiTransformerOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + + ln_scale = paddle.ones([self.embed_dim], 'float32') + ln_bias = paddle.zeros([self.embed_dim], 'float32') + ffn_ln_scale = ln_scale + ffn_ln_bias = ln_bias + + q_proj_weight = self.q_weight_tensor.numpy().transpose((1, 0)) + k_proj_weight = self.k_weight_tensor.numpy().transpose((1, 0)) + v_proj_weight = self.v_weight_tensor.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_weight_tensor = paddle.to_tensor(qkv_weight) + qkv_weight_tensor = paddle.cast(qkv_weight_tensor, 'int8') + + out_weight_tensor = paddle.cast( + paddle.to_tensor(self.out_weight_tensor.numpy().transpose((1, 0))), + 'int8') + ffn1_weight_tensor = paddle.cast( + paddle.to_tensor(self.ffn1_weight_tensor.numpy().transpose((1, 0))), + 'int8') + ffn2_weight_tensor = paddle.cast( + paddle.to_tensor(self.ffn2_weight_tensor.numpy().transpose((1, 0))), + 'int8') + + qkv_bias = np.concatenate( + (self.q_proj_bias_tensor.numpy(), self.k_proj_bias_tensor.numpy(), + self.v_proj_bias_tensor.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + qkv_bias_tensor = paddle.to_tensor(qkv_bias) + + x = paddle.to_tensor(self.query, stop_gradient=True) + cache_kvs, cache_kv = None, None + time_step = None + if self.has_cache_kv: + cache_kvs = [] + + max_seq_length = (self.cache_length + 128) // 128 * 128 + cache_kv = np.zeros([ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) + + elems = 4 + if self.x_type is np.float16: + elems = 8 + + assert self.head_dim % elems == 0 + v_elems = self.head_dim // elems + + # [B, num_head, 128, head_dim] + # cache_k_tmp = self.cache_kv[0, :] + # [B, num_head, 128, head_dim / 4, 4] + cache_k_tmp = self.cache_kv[0].reshape([ + self.batch_size, self.num_heads, self.cache_length, v_elems, + elems + ]) + # [B, num_head, head_dim / 4, 128, 4] + cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4]) + + cache_kv[0, :].reshape([ + self.batch_size, self.num_heads, v_elems, max_seq_length, elems + ])[:, :, :, :self.cache_length, :] = cache_k_tmp + + cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1] + if self.gen_cache_kv: + assert self.query_length == self.cache_length + cache_kv[:] = 0 + else: + time_step = paddle.to_tensor([self.cache_length], + dtype='int32', + place=paddle.CPUPlace()) + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True) + else: + attn_mask = None + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None and self.attn_mask_type != np.bool_: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + + qkv_weights, qkv_biases = [], [] + out_weights, out_biases = [], [] + ln_scales, ln_biases = [], [] + ffn1_weights, ffn1_biases = [], [] + ffn2_weights, ffn2_biases = [], [] + ffn_ln_scales, ffn_ln_biases = [], [] + qkv_in_scale = [] + out_linear_in_scale = [] + ffn1_in_scale = [] + ffn2_in_scale = [] + + qkv_out_scales_tensor = paddle.ones([self.layers, 3 * self.embed_dim], + 'float32') + out_linear_out_scales_tensor = paddle.ones( + [self.layers, self.embed_dim], 'float32') + ffn1_out_scales_tensor = paddle.ones([self.layers, 4 * self.embed_dim], + 'float32') + ffn2_out_scales_tensor = paddle.ones([self.layers, self.embed_dim], + 'float32') + + for i in range(self.layers): + qkv_weights.append(qkv_weight_tensor) + qkv_biases.append(qkv_bias_tensor) + out_weights.append(out_weight_tensor) + out_biases.append(self.out_linear_proj_bias_tensor) + ln_scales.append(ln_scale) + ln_biases.append(ln_bias) + ffn1_weights.append(ffn1_weight_tensor) + ffn1_biases.append(self.ffn1_proj_bias_tensor) + ffn2_weights.append(ffn2_weight_tensor) + ffn2_biases.append(self.ffn2_proj_bias_tensor) + ffn_ln_scales.append(ffn_ln_scale) + ffn_ln_biases.append(ffn_ln_bias) + qkv_in_scale.append(self.qkv_in_scales[i]) + out_linear_in_scale.append(self.out_linear_in_scales[i]) + ffn1_in_scale.append(self.ffn1_in_scales[i]) + ffn2_in_scale.append(self.ffn2_in_scales[i]) + + qkv_out_scales_tensor[i, :] *= self.qkv_out_scales[i] + out_linear_out_scales_tensor[i, :] *= self.out_linear_out_scales[i] + ffn1_out_scales_tensor[i, :] *= self.ffn1_out_scales[i] + ffn2_out_scales_tensor[i, :] *= self.ffn2_out_scales[i] + + if self.has_cache_kv: + cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=True)) + + final_out = fused_multi_transformer_int8( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + out_weights, + out_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=self.pre_layer_norm, + epsilon=epsilon, + cache_kvs=cache_kvs, + time_step=time_step, + attn_mask=attn_mask, + dropout_rate=self.dropout_prob, + training=self.training, + mode='upscale_in_train', + trans_qkvw=True, + ring_id=-1, + name=None, + qkv_out_scales=qkv_out_scales_tensor, + out_linear_out_scales=out_linear_out_scales_tensor, + ffn1_out_scales=ffn1_out_scales_tensor, + ffn2_out_scales=ffn2_out_scales_tensor, + num_head=self.num_heads, + dim_head=self.head_dim, + dim_ffn=4 * self.embed_dim, + qkv_in_scale=qkv_in_scale, + out_linear_in_scale=out_linear_in_scale, + ffn1_in_scale=ffn1_in_scale, + ffn2_in_scale=ffn2_in_scale) + + if self.has_cache_kv: + return final_out[0], final_out[1] + + return final_out + + def test_fused_multi_transformer_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedMultiTransformerOut() + if self.has_cache_kv: + final_out, cache_kv_out = final_out + s = cache_kv_out[0].shape + bsz = s[1] + num_head = s[2] + max_seq_len = s[3] + head_dim = s[4] + elems = 8 if self.x_type is np.float16 else 4 + v_elems = head_dim // elems + + if self.debug: + print("cache_k out timestep=128") + print(cache_kv_out[0].reshape( + [2, bsz, num_head, v_elems, max_seq_len, + elems])[0, 0, 0, :, self.cache_length, :]) + + print("cache_v out timestep=128") + print(cache_kv_out[0][1, 0, 0, self.cache_length, :]) + + if self.gen_cache_kv: + final_out_ref, cache_kvs = final_out_ref + for i in range(self.layers): + cache_k_ref = cache_kvs[i][0] + cache_v_ref = cache_kvs[i][1] + + cache_k = cache_kv_out[i][0, :] + cache_k = cache_k.reshape( + [bsz, num_head, v_elems, max_seq_len, elems]) + cache_k = cache_k[:, :, :, :self.cache_length, :] + cache_k = cache_k.transpose([0, 1, 3, 2, 4]) + cache_k = cache_k.reshape( + [bsz, num_head, self.cache_length, head_dim]) + + cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :] + + np.testing.assert_allclose(cache_k_ref, + cache_k, + rtol=self.rtol, + atol=self.atol) + np.testing.assert_allclose(cache_v_ref, + cache_v, + rtol=self.rtol, + atol=self.atol) + if i == 0: + break + + np.testing.assert_allclose(final_out_ref, + final_out, + rtol=self.rtol, + atol=self.atol) + + +class TestFusedMultiTransformerInt8OpFp16(TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.x_type = np.float16 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerInt8OpCacheKV(TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + super().generate_input_data() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerInt8OpCacheKVFp16(TestFusedMultiTransformerInt8Op + ): + + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.x_type = np.float16 + + +class TestFusedMultiTransformerInt8OpGenCacheKV(TestFusedMultiTransformerInt8Op + ): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + + +class TestFusedMultiTransformerInt8OpGenCacheKVFp16( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.x_type = np.float16 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerInt8OpPostLayerNormFp16( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.x_type = np.float16 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +class TestFusedMultiTransformerInt8OpCacheKVPostLayerNorm( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +class TestFusedMultiTransformerInt8OpCacheKVPostLayerNormFp16( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.x_type = np.float16 + self.pre_layer_norm = False + + +class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNorm( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.pre_layer_norm = False + + +class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNormFp16( + TestFusedMultiTransformerInt8Op): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.x_type = np.float16 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +if __name__ == "__main__": + unittest.main() -- GitLab