未验证 提交 3d7e2118 编写于 作者: R RichardWooSJTU 提交者: GitHub

Add INT8 support for fused_multi_transformer_op (#45284)

上级 7f346a76
......@@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
if (var_data_type == paddle::framework::proto::VarType::FP16) {
if (var_data_type == paddle::framework::proto::VarType::FP16 &&
t->dtype() != paddle::experimental::DataType::FLOAT16) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
......
......@@ -23,6 +23,7 @@ register_operators(
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
......@@ -119,6 +120,7 @@ if(WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_multi_transformer_int8_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
......
......@@ -19,7 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
......@@ -33,17 +34,28 @@ class AttnLayerNorm {
~AttnLayerNorm() {}
void ComputeForward(const T* x_data,
void ComputeForward(const InType* x_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* bias_data,
T* y_data,
OutType* y_data,
LayerNormParamType<T>* mean_data,
LayerNormParamType<T>* var_data) {
LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();
switch (GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, LayerNormParamType<T>, kBlockDim>
LayerNormForward<T,
LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
......@@ -51,7 +63,13 @@ class AttnLayerNorm {
mean_data,
var_data,
epsilon_,
feature_size_));
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Feature_size must be larger than 1"));
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/operators/fused/cublaslt.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// INT8 GEMM helper for the fused attention / FFN kernels.
//
// Wraps a CublasLtHelper configured for an (m, k, n) problem and offers four
// entry points that differ only in whether the input still has to be
// quantized (T -> int8) and whether the int32 GEMM result is dequantized back
// to T (optionally followed by a broadcast bias add).
//
// T is the floating point element type used for the non-quantized tensors.
template <typename T>
class AttnMatmulINT8 {
 public:
  // dev_ctx:      GPU context whose stream every kernel is launched on. Only a
  //               reference is stored, so it must outlive this object.
  // m, n, k:      GEMM sizes; the helper is built as CublasLtHelper(m, k, n).
  // compute_bias: when true, the T-output entry points also compute
  //               bias_out = output + bias.
  AttnMatmulINT8(
      const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias)
      : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
    helpers_.emplace_back(std::make_shared<CublasLtHelper>(m, k, n));
  }

  ~AttnMatmulINT8() {}

  // T input -> T output.
  // Quantizes `input` into `input_tmp` (int8), runs the int8 GEMM into
  // `output_tmp` (int32), dequantizes into `output` (T) and, if compute_bias
  // was set, writes output + bias into `bias_out`.
  void ComputeForward(const framework::Tensor* weight,
                      const framework::Tensor* input,
                      framework::Tensor* input_tmp,
                      const framework::Tensor* bias,
                      framework::Tensor* output,
                      framework::Tensor* output_tmp,
                      framework::Tensor* bias_out,
                      const float quant_in_scale,
                      const framework::Tensor* dequant_out_scale,
                      const int quant_out_scale_offset,
                      const int quant_round_type = 1,
                      const float quant_max_bound = 127.0,
                      const float quant_min_bound = -127.0) {
    QuantizeInput(input,
                  input_tmp,
                  quant_in_scale,
                  quant_round_type,
                  quant_max_bound,
                  quant_min_bound);
    // From here on the work is exactly the int8-input / T-output path.
    ComputeForwardINT8ToT(weight,
                          quant_in_scale,
                          input_tmp,
                          bias,
                          output,
                          output_tmp,
                          bias_out,
                          dequant_out_scale,
                          quant_out_scale_offset);
  }

  // int8 input -> raw GEMM output.
  // Runs only the int8 GEMM; `output` receives the int32 accumulators (it is
  // accessed through data<int32_t>()). `bias` and `bias_out` are accepted for
  // signature symmetry but are not used here.
  void ComputeForwardINT8ToINT8(const framework::Tensor* weight,
                                framework::Tensor* input,
                                const framework::Tensor* bias,
                                framework::Tensor* output,
                                framework::Tensor* bias_out) {
    helpers_[0]->GEMM(input->data<int8_t>(),
                      weight->data<int8_t>(),
                      output->data<int32_t>(),
                      dev_ctx_.stream());
  }

  // int8 input -> T output.
  // Runs the int8 GEMM into `output_tmp` (int32), dequantizes into `output`
  // (T) and, if compute_bias was set, writes output + bias into `bias_out`.
  void ComputeForwardINT8ToT(const framework::Tensor* weight,
                             const float quant_in_scale,
                             framework::Tensor* input,
                             const framework::Tensor* bias,
                             framework::Tensor* output,
                             framework::Tensor* output_tmp,
                             framework::Tensor* bias_out,
                             const framework::Tensor* dequant_out_scale,
                             const int quant_out_scale_offset) {
    helpers_[0]->GEMM(input->data<int8_t>(),
                      weight->data<int8_t>(),
                      output_tmp->data<int32_t>(),
                      dev_ctx_.stream());
    dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
                                  output->data<T>(),
                                  m_,
                                  n_,
                                  dev_ctx_.stream(),
                                  quant_in_scale,
                                  dequant_out_scale->data<float>(),
                                  quant_out_scale_offset);
    if (compute_bias_) {
      // bias_out = output + bias
      std::vector<const framework::Tensor*> ins = {output, bias};
      std::vector<framework::Tensor*> outs = {bias_out};
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
      // cudaGetLastError reports any pending error on the device, so a
      // failure here may originate from an earlier launch.
      PADDLE_ENFORCE_EQ(cudaGetLastError(),
                        cudaSuccess,
                        platform::errors::Fatal(
                            "cuda error occured after computing bias. "
                            "But it does not mean this error is caused by "
                            "bias computing"));
    }
  }

  // T input -> raw GEMM output.
  // Quantizes `input` into `input_tmp` (int8) and runs the int8 GEMM;
  // `output` receives the int32 accumulators. `bias` and `bias_out` are not
  // used here.
  void ComputeForwardTToINT8(const framework::Tensor* weight,
                             const float quant_in_scale,
                             const framework::Tensor* input,
                             framework::Tensor* input_tmp,
                             const framework::Tensor* bias,
                             framework::Tensor* output,
                             framework::Tensor* bias_out,
                             const int quant_round_type = 1,
                             const float quant_max_bound = 127.0,
                             const float quant_min_bound = -127.0) {
    QuantizeInput(input,
                  input_tmp,
                  quant_in_scale,
                  quant_round_type,
                  quant_max_bound,
                  quant_min_bound);
    ComputeForwardINT8ToINT8(weight, input_tmp, bias, output, bias_out);
  }

 private:
  // Launches the T -> int8 quantization of the (m_, k_) input into input_tmp.
  void QuantizeInput(const framework::Tensor* input,
                     framework::Tensor* input_tmp,
                     const float quant_in_scale,
                     const int quant_round_type,
                     const float quant_max_bound,
                     const float quant_min_bound) {
    quantize_kernel_launcher<T>(input->data<T>(),
                                input_tmp->data<int8_t>(),
                                quant_in_scale,
                                m_,
                                k_,
                                quant_round_type,
                                quant_max_bound,
                                quant_min_bound,
                                dev_ctx_.stream());
  }

  const phi::GPUContext& dev_ctx_;
  int m_;  // m
  int n_;  // n
  int k_;  // k
  bool compute_bias_;  // whether the T-output paths add the bias
  std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
namespace dyl = paddle::platform::dynload;
namespace paddle {
namespace operators {
class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
status = dyl::cublasLtMatmulDescCreate(
&matmul_desc_, cudaComputeType, CUDA_R_32I);
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmulDescCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
cublasOperation_t op_transpose = CUBLAS_OP_T;
status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op_transpose,
sizeof(op_transpose));
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmulDescSetAttribute execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
// matrix desc
status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
~CublasLtHelper() {
if (handle_) dyl::cublasLtDestroy(handle_);
if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_);
if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_);
if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_);
if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_);
}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream) {
cublasStatus_t status;
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
cublasLtMatmulAlgo_t algo;
int algoId = 21;
int swizzle = 0;
int customOption = 0;
int tile = 15;
int splitK_val = 0;
int reductionScheme = 0;
#if CUDA_VERSION >= 11000
int stages = 23;
#endif
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType,
CUDA_R_32I,
CUDA_R_8I,
CUDA_R_8I,
CUDA_R_32I,
CUDA_R_32I,
algoId,
&algo);
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption),
sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val),
sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme),
sizeof(int));
#if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif
#endif
status = dyl::cublasLtMatmul(handle_,
matmul_desc_,
&alpha_,
B_dev,
B_desc_,
A_dev,
A_desc_,
&beta_,
C_dev,
C_desc_,
C_dev,
C_desc_,
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
&algo,
#else
nullptr,
#endif
nullptr,
0,
stream);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmul execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
private:
cublasLtHandle_t handle_;
cublasLtMatmulDesc_t matmul_desc_;
cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_;
int32_t alpha_;
int32_t beta_;
int m_;
int k_;
int n_;
};
} // namespace operators
} // namespace paddle
......@@ -60,8 +60,14 @@ struct GeluGradFunctor {
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActBias(Functor act,
template <typename T,
typename MaskType,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedDropoutActBias(
Functor act,
const uint64_t seed,
const uint64_t rows,
const uint64_t cols,
......@@ -69,10 +75,17 @@ __global__ void FusedDropoutActBias(Functor act,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const T *__restrict__ src,
const InType *__restrict__ src,
const T *__restrict__ bias,
T *dst,
MaskType *mask) {
OutType *dst,
MaskType *mask,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
......@@ -90,7 +103,9 @@ __global__ void FusedDropoutActBias(Functor act,
VecSize,
false,
true,
Functor>(r,
Functor,
InType,
OutType>(r,
i,
cols,
&state,
......@@ -104,7 +119,14 @@ __global__ void FusedDropoutActBias(Functor act,
is_test,
nullptr,
nullptr,
act);
act,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
}
......@@ -112,7 +134,11 @@ __global__ void FusedDropoutActBias(Functor act,
/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T, typename MaskType, typename Functor>
template <typename T,
typename MaskType,
typename Functor,
typename InType = T,
typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
const uint64_t seed,
const uint32_t rows,
......@@ -121,14 +147,21 @@ void LaunchDropoutActBias(Functor act_functor,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const T *src,
const InType *src,
const T *bias,
T *dst,
OutType *dst,
MaskType *mask_data,
const phi::GPUContext &ctx) {
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
SetZero<T>(ctx, dst, rows * cols);
SetZero<T>(ctx, reinterpret_cast<T *>(dst), rows * cols);
SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}
......@@ -137,7 +170,7 @@ void LaunchDropoutActBias(Functor act_functor,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedDropoutActBias<T, MaskType, VecSize, Functor>
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
......@@ -150,9 +183,13 @@ void LaunchDropoutActBias(Functor act_functor,
src,
bias,
dst,
mask_data);
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
} else {
FusedDropoutActBias<T, MaskType, 1, Functor>
FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
......@@ -165,7 +202,11 @@ void LaunchDropoutActBias(Functor act_functor,
src,
bias,
dst,
mask_data);
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
......
......@@ -109,7 +109,10 @@ struct DropoutParam {
}
};
template <typename T, typename MaskType>
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutHelper {
private:
int GetIncrement(const phi::GPUContext& ctx) {
......@@ -140,13 +143,18 @@ class FusedDropoutHelper {
// out = residual + dropout( src + bias )
void ResidualDropoutBias(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const T* residual,
const T* bias,
T* out,
MaskType* mask) {
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType>(rows_,
LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
rows_,
cols_,
increment,
dropout_param_.seed,
......@@ -158,7 +166,11 @@ class FusedDropoutHelper {
bias,
mask,
out,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
......@@ -189,15 +201,22 @@ class FusedDropoutHelper {
// out = dropout(activation(src + bias))
void DropoutActBias(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const T* bias,
const std::string& act_method,
T* out,
MaskType* mask) {
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
GeluFunctor<T> gelu;
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>>(
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>, InType, OutType>(
gelu,
dropout_param_.seed,
rows_,
......@@ -210,11 +229,21 @@ class FusedDropoutHelper {
bias,
out,
mask,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else if (act_method == "relu") {
phi::funcs::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, phi::funcs::ReluFunctor<T>>(
relu,
LaunchDropoutActBias<T,
MaskType,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(relu,
dropout_param_.seed,
rows_,
cols_,
......@@ -226,7 +255,14 @@ class FusedDropoutHelper {
bias,
out,
mask,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
......@@ -283,8 +319,12 @@ class FusedDropoutHelper {
DropoutParam dropout_param_;
};
template <typename T, typename MaskType>
class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutLayerNormHelper
: public FusedDropoutHelper<T, MaskType, InType, OutType> {
public:
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows,
......@@ -301,23 +341,24 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
const int cols,
const DropoutParam& dropout_param,
const float epsilon)
: FusedDropoutHelper<T, MaskType>(ctx, rows, cols, dropout_param) {
: FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) {
using U = LayerNormParamType<T>;
epsilon_ = epsilon;
}
// call layer_norm
void LayerNorm(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const LayerNormParamType<T>* gamma,
const LayerNormParamType<T>* beta,
T* out,
OutType* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance) {
using U = LayerNormParamType<T>;
switch (GetDesiredBlockDim(this->cols_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U, kBlockDim>
LayerNormForward<T, U, kBlockDim, false, InType, OutType>
<<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
}
......@@ -349,17 +390,25 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
// out = layernorm(residual + dropout(src + bias))
template <typename P = LayerNormParamType<T>, bool is_same_type = false>
void LayernormResidualDropoutBias(const phi::GPUContext& ctx,
const T* src,
void LayernormResidualDropoutBias(
const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
const P* gamma,
const P* beta,
T* dropout_out,
MaskType* mask,
T* out,
OutType* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance) {
LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using U = LayerNormParamType<T>;
int vec_size = MAX_CACHE_BYTES / sizeof(T);
if (this->cols_ % vec_size != 0) {
......@@ -368,7 +417,12 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
int threads = GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T, MaskType, U, is_same_type>(
LaunchLayernormResidualDropoutBias<T,
MaskType,
U,
is_same_type,
InType,
OutType>(
this->rows_,
this->cols_,
increment,
......@@ -387,7 +441,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
out,
mean,
variance,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
template <typename P = LayerNormParamType<T>, bool is_same_type = false>
......
......@@ -418,7 +418,9 @@ template <typename T,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
typename InType = T,
typename OutType = T>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
int rows,
int cols,
......@@ -428,7 +430,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
const bool is_test,
const uint64_t increment,
const float epsilon,
const T *__restrict__ x_ptr,
const InType *__restrict__ x_ptr,
const T *__restrict__ residual_ptr,
const T *__restrict__ bias_ptr,
const ScaleT *__restrict__ gamma_ptr,
......@@ -437,10 +439,20 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
U *__restrict__ mean_out_ptr,
U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr,
T *__restrict__ y_ptr) {
OutType *__restrict__ y_ptr,
const float quant_last_in_scale = 1.0,
const float *__restrict__ quant_out_scale_ptr = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using Vec_in_type = phi::AlignedVector<InType, VecSize>;
using Vec_out_type = phi::AlignedVector<OutType, VecSize>;
using Vec_float = phi::AlignedVector<float, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
......@@ -481,12 +493,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
constexpr U rn = 1.f / U(ELTS_PER_ROW);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
Vec_in_type x_input[LDGS];
Vec residual[LDGS];
Vec_float dequant_out_scale[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
phi::Load<T, VecSize>(residual_ptr + row * ELTS_PER_ROW + col * VecSize,
&residual[it]);
phi::Load<InType, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
&x_input[it]);
if (quant_out_scale_ptr != nullptr) {
phi::Load<float, VecSize>(
quant_out_scale_ptr + quant_out_scale_offset + col * VecSize,
&dequant_out_scale[it]);
}
col += THREADS_PER_ROW;
}
......@@ -520,20 +541,42 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = (x[it][jt] + bias[it][jt]) *
if (std::is_same<InType, int32_t>::value) {
T tmp = (static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) +
bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
x[it][jt] = tmp;
xf[it * VecSize + jt] = U(tmp);
} else {
x[it][jt] = (static_cast<T>(x_input[it][jt]) + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
if (std::is_same<InType, int32_t>::value) {
// for int32 input, we need to dequantize.
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
x[it][jt] = tmp;
} else {
x[it][jt] = static_cast<T>(x_input[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
}
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
......@@ -626,6 +669,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
var_out_ptr[row] = var_local * rn;
}
Vec_out_type x_output[LDGS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
......@@ -638,12 +683,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
U tmp = rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local);
x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
static_cast<U>(beta[it][jt]));
if (std::is_same<OutType, int8_t>::value)
x_output[it][jt] = quant_helper(x[it][jt],
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
if (std::is_same<OutType, int8_t>::value) {
phi::Store<OutType, VecSize>(
x_output[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
} else {
phi::Store<T, VecSize>(
x[it],
reinterpret_cast<T *>(y_ptr) + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
}
......@@ -668,7 +727,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
template <typename T,
typename MaskType,
typename U,
bool ScaleBiasWithSameTypeX = false>
bool ScaleBiasWithSameTypeX = false,
typename InType = T,
typename OutType = T>
void LaunchLayernormResidualDropoutBias(
const uint32_t rows,
const uint32_t cols,
......@@ -678,18 +739,26 @@ void LaunchLayernormResidualDropoutBias(
const float epsilon,
const bool is_upscale_in_train,
const bool is_test,
const T *src,
const InType *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask_data,
T *dst,
T *layernorm_dst,
OutType *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var,
const phi::GPUContext &ctx) {
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
// NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
auto cuda_place = ctx.GetPlace();
memory::Copy(cuda_place,
......@@ -705,10 +774,11 @@ void LaunchLayernormResidualDropoutBias(
switch (GetDesiredBlockDim(cols)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U, kBlockDim, ScaleBiasWithSameTypeX>
<<<rows, kBlockDim, 0, ctx.stream()>>>(dst,
<<<rows, kBlockDim, 0, ctx.stream()>>>(
dst,
scale,
layernorm_bias,
layernorm_dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var,
epsilon,
......@@ -731,6 +801,9 @@ void LaunchLayernormResidualDropoutBias(
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \
const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \
const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
fused_fast_ln_fwd_kernel< \
......@@ -742,7 +815,16 @@ void LaunchLayernormResidualDropoutBias(
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
THREADS_PER_WARP, \
THREADS_PER_ROW, \
THREADS_PER_CTA, \
ROWS_PER_CTA, \
ELTS_PER_ROW_PER_CTA, \
LDGS, \
InType, \
OutType> \
<<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
......@@ -759,7 +841,14 @@ void LaunchLayernormResidualDropoutBias(
mean, \
var, \
dst, \
layernorm_dst); \
layernorm_dst, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_out_scale_offset, \
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
} break
#define LAUNCH_FUSED_FAST_LN_KERNEL \
......@@ -784,7 +873,8 @@ void LaunchLayernormResidualDropoutBias(
if (cols % VecSize != 0) {
int blockDim = GetDesiredBlockDim(cols);
FusedLayernormResidualDropoutBias<T, uint8_t, 1, U, ScaleBiasWithSameTypeX>
<<<rows, blockDim, 0, ctx.stream()>>>(rows,
<<<rows, blockDim, 0, ctx.stream()>>>(
rows,
cols,
seed,
dropout_prob,
......@@ -792,14 +882,14 @@ void LaunchLayernormResidualDropoutBias(
is_test,
increment,
epsilon,
src,
reinterpret_cast<const T *>(src),
residual,
bias,
scale,
layernorm_bias,
mask_data,
dst,
layernorm_dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
} else {
......@@ -819,7 +909,8 @@ void LaunchLayernormResidualDropoutBias(
VecSize,
U,
ScaleBiasWithSameTypeX>
<<<rows, blockDim, 0, ctx.stream()>>>(rows,
<<<rows, blockDim, 0, ctx.stream()>>>(
rows,
cols,
seed,
dropout_prob,
......@@ -827,14 +918,14 @@ void LaunchLayernormResidualDropoutBias(
is_test,
increment,
epsilon,
src,
reinterpret_cast<const T *>(src),
residual,
bias,
scale,
layernorm_bias,
mask_data,
dst,
layernorm_dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Operator definition for `fused_multi_transformer_int8`.
//
// InferShape validates the presence and the shapes of the attention / FFN
// inputs and sets the shape of `Out` to the shape of `X`. The kernel data type
// is taken from input `X`.
class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel {
 private:
  // Operator name reported by the OP_INOUT_CHECK diagnostics below.
  static constexpr const char *OpName = "FusedMultiTransformerINT8Op";

 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks required inputs/outputs and the consistency of the shapes of X,
  // QKVW[0] and (when present) CacheKV[0]; then sets Out's dims to X's dims.
  void InferShape(framework::InferShapeContext *ctx) const override {
#define CHECK_INPUT(name) \
  OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
#define CHECK_INPUTS(name) \
  OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
#define CHECK_OUTPUT(name) \
  OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
#define CHECK_OUTPUTS(name) \
  OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)

    CHECK_INPUT(X);

    // attention
    CHECK_INPUTS(QKVW);
    CHECK_INPUTS(OutLinearW);

    // A TimeStep input requires CacheKV, and CacheKV requires CacheKVOut.
    if (ctx->HasInput("TimeStep")) {
      CHECK_INPUTS(CacheKV);
    }

    if (ctx->HasInputs("CacheKV")) {
      CHECK_OUTPUTS(CacheKVOut);
    }

    // ffn
    CHECK_INPUTS(FFN1Weight);
    CHECK_INPUTS(FFN2Weight);

    CHECK_OUTPUT(Out);

// The helper macros are only meaningful inside this function; do not leak
// such generic names into the rest of the translation unit.
#undef CHECK_INPUT
#undef CHECK_INPUTS
#undef CHECK_OUTPUT
#undef CHECK_OUTPUTS

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed] when trans_qkvw is
    //    true, otherwise [dim_embed, 3, num_head, dim_head]
    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputsDim("QKVW")[0];
    bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw");
    PADDLE_ENFORCE_EQ(
        x_dim.size(),
        3,
        platform::errors::InvalidArgument("The dimensions of x must be 3 "
                                          "(batch_size, seq_len, dim_embed), "
                                          "but received dimensions of "
                                          "Input is [%d]",
                                          x_dim.size()));
    PADDLE_ENFORCE_EQ(y_dim.size(),
                      4,
                      platform::errors::InvalidArgument(
                          "The dimensions of qkv_weight must be 4 "
                          "(3, num_head, dim_head, dim_embed), "
                          "but received dimensions of "
                          "Input is [%d]",
                          y_dim.size()));
    PADDLE_ENFORCE_EQ(
        x_dim[2],
        trans_qkvw ? y_dim[3] : y_dim[0],
        platform::errors::InvalidArgument(
            "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
            "true) or y_dim[0](trans_qkvw is false) "
            "must be equal. But received: the shape "
            "of input x = [%s], and the shape of "
            "input qkv_weight = [%s]",
            x_dim,
            y_dim));

    // num_head * dim_head == dim_embed is only enforced when ring_id is -1.
    if (ctx->Attrs().Get<int>("ring_id") == -1) {
      if (trans_qkvw) {
        PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
                          y_dim[3],
                          platform::errors::InvalidArgument(
                              "The dimensions of qkv_weight must be 4 "
                              "(3, num_head, dim_head, dim_embed), "
                              "and must satisfy the limitations: "
                              "(num_head * dim_head == dim_embed)"));

      } else {
        PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
                          y_dim[0],
                          platform::errors::InvalidArgument(
                              "The dimensions of qkv_weight must be 4 "
                              "(dim_embed, 3, num_head, dim_head), "
                              "and must satisfy the limitations: "
                              "(num_head * dim_head == dim_embed)"));
      }
    }

    if (ctx->HasInputs("CacheKV")) {
      // [2, batch_size, num_head, max_seq_len, head_size]
      // Only the first CacheKV tensor is validated here.
      const auto &c_dims = ctx->GetInputsDim("CacheKV");
      const auto &c_dim = c_dims[0];

      PADDLE_ENFORCE_EQ(
          c_dim.size(),
          5,
          paddle::platform::errors::InvalidArgument(
              "The CacheKV must be 5 dims, but got %d", c_dim.size()));
      PADDLE_ENFORCE_EQ(c_dim[0],
                        2,
                        paddle::platform::errors::InvalidArgument(
                            "The first dim of CacheKV must be 2, but got %d",
                            c_dim[0]));  // 2
      PADDLE_ENFORCE_EQ(c_dim[1],
                        x_dim[0],
                        paddle::platform::errors::InvalidArgument(
                            "The second dim of CacheKV must be equal with "
                            "batch size %d, but got %d",
                            x_dim[0],
                            c_dim[1]));  // batch_size
      PADDLE_ENFORCE_EQ(c_dim[2],
                        trans_qkvw ? y_dim[1] : y_dim[2],
                        paddle::platform::errors::InvalidArgument(
                            "The third dim of CacheKV must be equal with num "
                            "head %d, but got %d",
                            trans_qkvw ? y_dim[1] : y_dim[2],
                            c_dim[2]));  // num_head
      PADDLE_ENFORCE_GT(
          c_dim[3],
          0,
          paddle::platform::errors::InvalidArgument(
              "The fourth dim of CacheKV must be greater than 0, but got %d",
              c_dim[3]));  // cache_seq_len
      PADDLE_ENFORCE_EQ(c_dim[4],
                        trans_qkvw ? y_dim[2] : y_dim[3],
                        paddle::platform::errors::InvalidArgument(
                            "The fifth dim of CacheKV must be equal with head "
                            "size %d, but got %d",
                            trans_qkvw ? y_dim[2] : y_dim[3],
                            c_dim[4]));  // head_size
    }

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }

 protected:
  // The kernel is selected by the data type of input X and the current place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }

  // TimeStep is returned with the expected kernel type unchanged (the log
  // message states that it need not be transformed); every other variable
  // keeps its own place and layout and takes the expected data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "TimeStep") {
      VLOG(10) << "var_name:" << var_name << " need not to transform";
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};
// Declares the inputs, outputs and attributes (the op proto) of
// `fused_multi_transformer_int8`.
//
// Inputs marked AsDuplicable() are lists of tensors; the kernel reads one
// entry per transformer layer. The *OutScale inputs and the *_in_scale
// attributes carry the dequantization / quantization scales.
class FusedMultiTransformerINT8OpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("LnScale",
             "Scale is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDuplicable();
    AddInput("LnBias",
             "Bias is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDuplicable();
    AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
    AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
    AddInput("CacheKV", "(optional) The cached KV for generation inference.")
        .AsDispensable()
        .AsDuplicable();
    AddInput("TimeStep",
             "(optional, int) The time step for generation inference.")
        .AsDispensable();
    AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
        .AsDispensable();
    AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
    AddInput("OutLinearBias", "The out_linear bias tensor.")
        .AsDispensable()
        .AsDuplicable();

    AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
        .AsDuplicable();
    AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
        .AsDuplicable();
    AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
        .AsDuplicable();
    AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
        .AsDispensable()
        .AsDuplicable();
    AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
        .AsDuplicable();
    AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
        .AsDispensable()
        .AsDuplicable();

    // Dequantization scales of the four GEMM outputs.
    AddInput("QKVOutScale",
             "QKVOutScale is used to dequantize qkv output tensor. "
             "In order to keep consistent with the PTQ/QAT calculation logic, "
             "QKVOutScale should be max_bound * max_bound / max_range. "
             "Here max_range is per-channel weight scale. "
             "The shape of QKVOutScale is [num_layers, num_channels]")
        .AsDispensable();
    AddInput("OutLinearOutScale",
             "OutLinearOutScale is used to dequantize out_linear output "
             "tensor. "
             "The definition and shape is the same as QKVOutScale")
        .AsDispensable();
    AddInput("FFN1OutScale",
             "FFN1OutScale is used to dequantize ffn1 output tensor. "
             "The definition and shape is the same as QKVOutScale")
        .AsDispensable();
    AddInput("FFN2OutScale",
             "FFN2OutScale is used to dequantize ffn2 output tensor. "
             "The definition and shape is the same as QKVOutScale")
        .AsDispensable();

    AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
        .AsDispensable()
        .AsDuplicable();
    AddOutput("Out", "Result after the fused multi transformer layers.");

    AddAttr<bool>("pre_layer_norm",
                  "if true, the attention op uses pre_layer_norm architecture, "
                  "else, uses post_layer_norm architecture. "
                  "[default true].")
        .SetDefault(true);
    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
                            true,
                            platform::errors::InvalidArgument(
                                "'epsilon' in Op(LayerNorm) should be between "
                                "0.0 and 0.001, But received [%s].",
                                epsilon));
        });

    AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float &drop_p) {
          PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f,
                            true,
                            platform::errors::InvalidArgument(
                                "'dropout_rate' must be between 0.0 and 1.0."));
        });

    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<std::string>(
        "dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"] "
        "The meaning is the same as 'attn_dropout_implementation'.")
        .SetDefault("downgrade_in_infer")
        .AddCustomChecker([](const std::string &type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train",
              true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train"));
        });
    AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
    AddAttr<bool>(
        "trans_qkvw",
        "Whether the weights of qkv should be transposed. If true, "
        "the shape of weights of qkv should be "
        "[3, num_head, dim_head, dim_embed]. "
        "Otherwise the shape of weights of qkv should be "
        "[dim_embed, 3, num_head, dim_head]")
        .SetDefault(true);

    AddAttr<int>(
        "ring_id",
        "ring id for tensor model parallel. distributed training and inference")
        .SetDefault(-1);

    AddAttr<int>("num_head", "num_head").SetDefault(0);
    AddAttr<int>("dim_head", "dim_head").SetDefault(0);
    AddAttr<int>("dim_ffn", "dim_ffn").SetDefault(0);

    // Per-layer quantization scales of the four GEMM inputs.
    AddAttr<std::vector<float>>(
        "qkv_in_scale",
        "qkv_in_scale is used to quantize qkv input tensor. "
        "in_scale is generated by PTQ or QAT, which represents valid max range "
        "of this tensor. "
        "the size of qkv_in_scale should be num_layers, which is equal to "
        "the number of QKVW tensors")
        .SetDefault({});
    AddAttr<std::vector<float>>(
        "out_linear_in_scale",
        "out_linear_in_scale is used to quantize out_linear input tensor. "
        "the size of out_linear_in_scale is the same as qkv_in_scale")
        .SetDefault({});
    AddAttr<std::vector<float>>(
        "ffn1_in_scale",
        "ffn1_in_scale is used to quantize ffn1 input tensor. "
        "the size of ffn1_in_scale is the same as qkv_in_scale")
        .SetDefault({});
    AddAttr<std::vector<float>>(
        "ffn2_in_scale",
        "ffn2_in_scale is used to quantize ffn2 input tensor. "
        "the size of ffn2_in_scale is the same as qkv_in_scale")
        .SetDefault({});

    AddAttr<int>(
        "quant_round_type",
        "(int, default 1) The round type of fp32 to int. "
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2 "
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(-2.5)=-3")
        .SetDefault(1);
    AddAttr<float>(
        "quant_max_bound",
        "(float, default 127.0) the max bound of float type to int type")
        .SetDefault(127.0);
    AddAttr<float>(
        "quant_min_bound",
        "(float, default -127.0) the min bound of float type to int type")
        .SetDefault(-127.0);

    AddComment(R"DOC(fused multi transformer layers op)DOC");
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Registers the operator under the name `fused_multi_transformer_int8`,
// pairing the operator class with its proto maker. EmptyGradOpMaker is passed
// for both the static-graph (OpDesc) and the imperative (OpBase) modes;
// presumably this means no gradient op is generated -- confirm against the
// framework's EmptyGradOpMaker definition.
REGISTER_OPERATOR(
fused_multi_transformer_int8,
ops::FusedMultiTransformerINT8Op,
ops::FusedMultiTransformerINT8OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm_int8.h"
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h"
namespace paddle {
namespace operators {
// CUDA kernel of the `fused_multi_transformer_int8` operator.
//
// Compute() runs `QKVW.size()` transformer layers back to back. Every layer
// performs: (1) layer norm, (2) qkv GEMM, (3) attention, (4) out_linear GEMM,
// (5) bias + residual + layer norm, (6) ffn1 GEMM, (7) bias + activation,
// (8) ffn2 GEMM and (9) bias + residual (+ the next layer's layer norm).
//
// When `pre_layer_norm` is true the data between the GEMMs and the fused
// element-wise helpers stays in `input_workspace` (int8) and
// `output_workspace` (int32); otherwise every GEMM goes through
// AttnMatmulINT8::ComputeForward with tensors of type T.
//
// T is the floating point type of the operator (the kernel is registered for
// float and float16).
template <typename T>
class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
 public:
  // Executes all layers. Raises an enforce error when a quantization scale, a
  // per-layer bias or a decode-stage input that the computation reads is
  // missing.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto &dev_ctx = ctx.cuda_device_context();

    auto *time_step = ctx.Input<Tensor>("TimeStep");
    // 0. input
    auto *input_x = ctx.Input<Tensor>("X");
    const auto input_x_dims = input_x->dims();
    int bsz = input_x_dims[0];
    int seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int bsz_seq = bsz * seq_len;

    // quant input scales, vector, size = num_layers
    // Bound by reference: the attribute vectors are only read.
    const auto &qkv_in_scale = ctx.Attr<std::vector<float>>("qkv_in_scale");
    const auto &out_linear_in_scale =
        ctx.Attr<std::vector<float>>("out_linear_in_scale");
    const auto &ffn1_in_scale = ctx.Attr<std::vector<float>>("ffn1_in_scale");
    const auto &ffn2_in_scale = ctx.Attr<std::vector<float>>("ffn2_in_scale");

    // quant round type and bound
    auto quant_round_type = ctx.Attr<int>("quant_round_type");
    auto quant_max_bound = ctx.Attr<float>("quant_max_bound");
    auto quant_min_bound = ctx.Attr<float>("quant_min_bound");

    // dequant output scales, tensor, size = [num_layers, n], n is gemm output
    // size
    auto *qkv_out_scale = ctx.Input<Tensor>("QKVOutScale");
    auto *out_linear_out_scale = ctx.Input<Tensor>("OutLinearOutScale");
    auto *ffn1_out_scale = ctx.Input<Tensor>("FFN1OutScale");
    auto *ffn2_out_scale = ctx.Input<Tensor>("FFN2OutScale");

    // The scale inputs are declared dispensable in the op proto, but they are
    // dereferenced unconditionally below; fail with a clear message instead
    // of dereferencing a null pointer.
    PADDLE_ENFORCE_NOT_NULL(
        qkv_out_scale,
        platform::errors::InvalidArgument(
            "Input(QKVOutScale) of fused_multi_transformer_int8 must be "
            "provided."));
    PADDLE_ENFORCE_NOT_NULL(
        out_linear_out_scale,
        platform::errors::InvalidArgument(
            "Input(OutLinearOutScale) of fused_multi_transformer_int8 must be "
            "provided."));
    PADDLE_ENFORCE_NOT_NULL(
        ffn1_out_scale,
        platform::errors::InvalidArgument(
            "Input(FFN1OutScale) of fused_multi_transformer_int8 must be "
            "provided."));
    PADDLE_ENFORCE_NOT_NULL(
        ffn2_out_scale,
        platform::errors::InvalidArgument(
            "Input(FFN2OutScale) of fused_multi_transformer_int8 must be "
            "provided."));

    // Row length of each scale tensor; layer i starts at offset i * n.
    int qkv_out_scale_n = qkv_out_scale->dims()[1];
    int out_linear_out_scale_n = out_linear_out_scale->dims()[1];
    int ffn1_out_scale_n = ffn1_out_scale->dims()[1];
    int ffn2_out_scale_n = ffn2_out_scale->dims()[1];

    // 1. layer norm
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto ln_scales = ctx.MultiInput<Tensor>("LnScale");
    auto ln_biases = ctx.MultiInput<Tensor>("LnBias");

    // Reads T and writes int8_t (see the ComputeForward call in step1).
    auto ln_compute =
        AttnLayerNorm<T, T, int8_t>(dev_ctx, epsilon, bsz_seq, dim_embed);
    Tensor ln_mean, ln_var;
    ln_mean.Resize({{bsz_seq}});
    auto *ln_mean_data =
        dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
    ln_var.Resize({{bsz_seq}});
    auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));

    // 2. qkv
    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto qkv_weights = ctx.MultiInput<Tensor>("QKVW");
    auto qkv_biases = ctx.MultiInput<Tensor>("QKVBias");
    const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
    const auto qkv_w_dims = qkv_weights[0]->dims();
    int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
    int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    // The bias is only added by the GEMM when TimeStep is absent.
    bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
    // (transA, transB, compute_bias) = (false, trans_qkvw, false)
    AttnMatmulINT8<T> qkv_compute(
        dev_ctx, bsz_seq, output_size, input_size, compute_bias);
    Tensor qkv_out;
    qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
    auto *qkv_out_data =
        dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));

    // 3. fmha
    AttnDropoutParam attn_param(
        true, "upscale_in_train", 0.0, true, true, 0, nullptr);
    auto fmha_compute =
        FMHARef<T>(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param);
    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto cache_kvs = ctx.MultiInput<Tensor>("CacheKV");
    auto cache_kv_outs = ctx.MultiOutput<Tensor>("CacheKVOut");

    auto out_seq_len = seq_len;
    if (time_step) {
      PADDLE_ENFORCE_EQ(time_step->place(),
                        platform::CPUPlace(),
                        platform::errors::PreconditionNotMet(
                            "The place of input(TimeStep) must be CPUPlace."));
      // cache_seq_len
      int time_step_value = time_step->data<int>()[0];
      PADDLE_ENFORCE_GT(time_step_value,
                        0,
                        platform::errors::PreconditionNotMet(
                            "The value of time_step must > 0, but now is %d",
                            time_step_value));
      PADDLE_ENFORCE_EQ(
          seq_len,
          1,
          platform::errors::PreconditionNotMet(
              "In decode stage, the seq_len of input must be 1, but now is %d",
              seq_len));
      out_seq_len += time_step_value;
    }

    Tensor transpose_out_2, qk_out;
    transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
    auto *transpose_out_2_data =
        dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));

    qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
    auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));

    Tensor softmax_out;
    Tensor attn_dropout_mask_out, attn_dropout_out;
    Tensor qktv_out, fmha_out;
    softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
    auto *softmax_out_data =
        dev_ctx.Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));

    attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
    auto *attn_dropout_mask_out_data = dev_ctx.Alloc<T>(
        &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T));
    attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
    auto *attn_dropout_data_data = dev_ctx.Alloc<T>(
        &attn_dropout_out, attn_dropout_out.numel() * sizeof(T));

    qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
    auto *qktv_out_data =
        dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
    fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
    auto *fmha_out_data =
        dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));

    // 4. out_linear
    auto out_linear_weights = ctx.MultiInput<Tensor>("OutLinearW");
    auto out_linear_biases = ctx.MultiInput<Tensor>("OutLinearBias");
    int ring_id = ctx.Attr<int>("ring_id");
    // (transA, transB, compute_bias) = (false, false, false)
    AttnMatmulINT8<T> out_linear_compute(
        dev_ctx, bsz_seq, dim_embed, hidden_size, false);

    // 5. ln(residual + bias)
    DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
    // int32 in / int8 out variant, used on the pre_layer_norm path.
    FusedDropoutLayerNormHelper<T, uint8_t, int32_t, int8_t>
        fused_dropout_layernorm_helper(
            dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
    FusedDropoutLayerNormHelper<T, uint8_t>
        fused_dropout_layernorm_helper_for_post_layernorm(
            dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
    auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
    auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
    Tensor bias_dropout_residual_out, dropout_mask_out;
    T *bias_dropout_residual_out_data = nullptr;
    if (pre_layer_norm) {
      bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
      bias_dropout_residual_out_data =
          dev_ctx.Alloc<T>(&bias_dropout_residual_out,
                           bias_dropout_residual_out.numel() * sizeof(T));
    }
    dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
        &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));

    // 6. ffn matmul1
    auto ffn1_weights = ctx.MultiInput<Tensor>("FFN1Weight");
    auto ffn1_biases = ctx.MultiInput<Tensor>("FFN1Bias");
    auto ffn1_weight_dim = ffn1_weights[0]->dims();

    int dim_ffn = ffn1_weight_dim[0];
    AttnMatmulINT8<T> ffn1_linear_compute(
        dev_ctx, bsz_seq, dim_ffn, dim_embed, false);
    Tensor ffn1_out;
    ffn1_out.Resize({{bsz_seq, dim_ffn}});
    auto *ffn1_out_data =
        dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));

    // 7. ffn act + bias
    DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutHelper<T, uint8_t, int32_t, int8_t> fused_act_dropout_helper(
        dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper_for_post_layernorm(
        dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
    Tensor ffn1_dropout_out, ffn1_dropout_mask;
    ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}});
    auto *ffn1_dropout_out_data = dev_ctx.Alloc<T>(
        &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
    ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}});
    auto *ffn1_dropout_mask_data = dev_ctx.Alloc<uint8_t>(
        &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t));

    // 8. ffn2 matmul
    auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight");
    auto ffn2_biases = ctx.MultiInput<Tensor>("FFN2Bias");
    AttnMatmulINT8<T> ffn2_linear_compute(
        dev_ctx, bsz_seq, dim_embed, dim_ffn, false);

    // 9. ffn2 residual bias
    DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t, int32_t, int8_t>
        ffn2_fused_dropout_helper(
            dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
    // int32 in / T out variant, used for the last layer on the
    // pre_layer_norm path.
    FusedDropoutLayerNormHelper<T, uint8_t, int32_t, T>
        ffn2_fused_dropout_dequant_helper(
            dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
    FusedDropoutLayerNormHelper<T, uint8_t>
        ffn2_fused_dropout_helper_for_post_layernorm(
            dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);

    // []. init workspace for cublasLt transform
    Tensor input_workspace, output_workspace;
    // for input and output transform data is CUBLASLT_ORDER_COL32 format,
    // Both workspaces are sized for the largest GEMM of a layer, with the
    // dimensions rounded up to multiples of 32.
    int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn),
        n_max = std::max({output_size, dim_embed, dim_ffn});

    input_workspace.Resize(
        {{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}});
    dev_ctx.Alloc<int8_t>(&input_workspace,
                          input_workspace.numel() * sizeof(int8_t));
    output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}});
    dev_ctx.Alloc<int32_t>(&output_workspace,
                           output_workspace.numel() * sizeof(int32_t));

    // calc
    auto *out = ctx.Output<Tensor>("Out");
    auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    Tensor tmp_out;
    tmp_out.Resize({{bsz, seq_len, dim_embed}});
    auto *tmp_out_data =
        dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));

    auto *x_data = input_x->data<T>();
    Tensor *buf0 = nullptr;
    Tensor *buf1 = nullptr;

    // step0: x   --> buf1
    // step1: buf1 --> buf0
    // step2: buf0 --> buf1
    int layers = qkv_weights.size();
    if (pre_layer_norm) {
      // buf0 stays null: it is only used on the post_layer_norm path.
      buf1 = out;
    } else {
      buf0 = &tmp_out;
      buf1 = out;
    }

    // The loop below indexes the following per-layer containers with the
    // layer index without further checks. The scale attributes default to an
    // empty vector and the bias inputs are dispensable, so validate their
    // sizes once up front.
    const auto enforce_per_layer = [layers](size_t actual, const char *name) {
      PADDLE_ENFORCE_GE(
          actual,
          static_cast<size_t>(layers),
          platform::errors::InvalidArgument(
              "The size of %s must be at least the number of layers %d, "
              "but got %d.",
              name,
              layers,
              actual));
    };
    enforce_per_layer(qkv_in_scale.size(), "Attr(qkv_in_scale)");
    enforce_per_layer(out_linear_in_scale.size(), "Attr(out_linear_in_scale)");
    enforce_per_layer(ffn1_in_scale.size(), "Attr(ffn1_in_scale)");
    enforce_per_layer(ffn2_in_scale.size(), "Attr(ffn2_in_scale)");
    enforce_per_layer(out_linear_biases.size(), "Input(OutLinearBias)");
    enforce_per_layer(ffn1_biases.size(), "Input(FFN1Bias)");
    enforce_per_layer(ffn2_biases.size(), "Input(FFN2Bias)");

    for (int i = 0; i < layers; ++i) {
      // step1. layer_norm
      // Only the first layer needs an explicit layer norm: for later layers
      // it is produced by step9 of the previous layer.
      if (i == 0 && pre_layer_norm) {
        auto *ln_scale_data = ln_scales[i]->data<U>();
        auto *ln_bias_data = ln_biases[i]->data<U>();
        // TODO(wangxi): can remove mean var in inference
        ln_compute.ComputeForward(x_data,
                                  ln_scale_data,
                                  ln_bias_data,
                                  input_workspace.data<int8_t>(),
                                  ln_mean_data,
                                  ln_var_data,
                                  nullptr,
                                  0,
                                  qkv_in_scale[i],
                                  quant_round_type,
                                  quant_max_bound,
                                  quant_min_bound);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step1";
#endif

      // step2. qkv
      const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
      // NOTE: in decoder stage, bias is fused in fmha
      const Tensor *bias = time_step ? nullptr : qkv_bias;
      if (!pre_layer_norm && i == 0) {
        qkv_compute.ComputeForward(qkv_weights[i],
                                   input_x,
                                   &input_workspace,
                                   bias,
                                   &qkv_out,
                                   &output_workspace,
                                   &qkv_out,
                                   qkv_in_scale[i],
                                   qkv_out_scale,
                                   i * qkv_out_scale_n,
                                   quant_round_type,
                                   quant_max_bound,
                                   quant_min_bound);
      } else if (!pre_layer_norm) {
        qkv_compute.ComputeForward(qkv_weights[i],
                                   buf1,
                                   &input_workspace,
                                   bias,
                                   &qkv_out,
                                   &output_workspace,
                                   &qkv_out,
                                   qkv_in_scale[i],
                                   qkv_out_scale,
                                   i * qkv_out_scale_n,
                                   quant_round_type,
                                   quant_max_bound,
                                   quant_min_bound);
      } else {
        // pre_layer_norm: the int8 input is already in input_workspace.
        qkv_compute.ComputeForwardINT8ToT(qkv_weights[i],
                                          qkv_in_scale[i],
                                          &input_workspace,
                                          bias,
                                          &qkv_out,
                                          &output_workspace,
                                          &qkv_out,
                                          qkv_out_scale,
                                          i * qkv_out_scale_n);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step2";
#endif

      // step3. fmha
      const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr;
      Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr;

      if (time_step) {  // generation decoder stage
        // QKVBias and SrcMask are dispensable inputs, but both are
        // dereferenced for the call below.
        PADDLE_ENFORCE_NOT_NULL(
            qkv_bias,
            platform::errors::InvalidArgument(
                "Input(QKVBias) must be provided in the decode stage "
                "(when Input(TimeStep) is given)."));
        PADDLE_ENFORCE_NOT_NULL(
            src_mask,
            platform::errors::InvalidArgument(
                "Input(SrcMask) must be provided in the decode stage "
                "(when Input(TimeStep) is given)."));
        // [2, batch_size, num_head, max_seq_len, head_size]
        int max_seq_len = cache_kv->dims()[3];
        fmha<T>(dev_ctx,
                qkv_out,
                *qkv_bias,
                *src_mask,
                cache_kv_out,
                &fmha_out,
                bsz,
                max_seq_len,
                num_head,
                dim_head,
                time_step->data<int>()[0],
                1. / sqrt(dim_head));
      } else if (cache_kv_out) {  // generation context stage
        // TODO(wangxi): can remove dropout in inference
        fmha_compute.ComputeForward(qkv_out,
                                    nullptr,
                                    src_mask,
                                    &transpose_out_2,
                                    nullptr,
                                    &qk_out,
                                    nullptr,
                                    &softmax_out,
                                    &attn_dropout_mask_out,
                                    &attn_dropout_out,
                                    &qktv_out,
                                    &fmha_out);
        // Copy k and v out of transpose_out_2 into the cache.
        // [3, bsz, num_head, seq_len, head_dim]
        T *qkv_data = transpose_out_2_data;
        int64_t q_size = bsz * seq_len * num_head * dim_head;
        int64_t k_size = q_size;
        const T *q_ptr = qkv_data;
        const T *k_ptr = q_ptr + q_size;
        const T *v_ptr = k_ptr + k_size;

        // [2, bsz, num_head, max_seq_len, head_dim]
        int max_seq_len = cache_kv_out->dims()[3];
        T *cache_kv_data = cache_kv_out->data<T>();
        int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head;

        T *cache_k_ptr = cache_kv_data;
        T *cache_v_ptr = cache_kv_data + cache_k_size;

        write_cache_kv<T>(dev_ctx,
                          cache_k_ptr,
                          cache_v_ptr,
                          k_ptr,
                          v_ptr,
                          bsz,
                          num_head,
                          seq_len,
                          max_seq_len,
                          dim_head);
      } else {  // not generation
        // TODO(wangxi): can remove dropout in inference
        fmha_compute.ComputeForward(qkv_out,
                                    cache_kv,
                                    src_mask,
                                    &transpose_out_2,
                                    cache_kv_out,
                                    &qk_out,
                                    nullptr,
                                    &softmax_out,
                                    &attn_dropout_mask_out,
                                    &attn_dropout_out,
                                    &qktv_out,
                                    &fmha_out);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step3";
#endif

      // step4. out_linear
      if (pre_layer_norm) {
        out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i],
                                                 out_linear_in_scale[i],
                                                 &fmha_out,
                                                 &input_workspace,
                                                 nullptr,
                                                 &output_workspace,
                                                 nullptr,
                                                 quant_round_type,
                                                 quant_max_bound,
                                                 quant_min_bound);
        // NOTE(review): the element count is bsz * seq_len * num_head *
        // dim_head while the GEMM above is built with n = dim_embed. The op's
        // shape check only enforces num_head * dim_head == dim_embed when
        // ring_id == -1 -- confirm the count for ring_id != -1.
        AllReduce<int32_t>(output_workspace,
                           ring_id,
                           bsz * seq_len * num_head * dim_head,
                           dev_ctx);
      } else {
        out_linear_compute.ComputeForward(out_linear_weights[i],
                                          &fmha_out,
                                          &input_workspace,
                                          nullptr,
                                          buf0,
                                          &output_workspace,
                                          nullptr,
                                          out_linear_in_scale[i],
                                          out_linear_out_scale,
                                          i * out_linear_out_scale_n,
                                          quant_round_type,
                                          quant_max_bound,
                                          quant_min_bound);
        AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step4";
#endif

      // step5. ln(residual + dropout(input + bias))
      if (pre_layer_norm) {
        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();

        // inplace
        // non-inplace: buf1 -> input_workspace
        fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
            dev_ctx,
            output_workspace.data<int32_t>(),
            x_data,
            out_linear_bias_data,
            ln_scale_data,
            ln_bias_data,
            bias_dropout_residual_out_data,
            dropout_mask_out_data,
            input_workspace.data<int8_t>(),
            ln_mean_data,
            ln_var_data,
            out_linear_in_scale[i],
            out_linear_out_scale->data<float>(),
            i * out_linear_out_scale_n,
            ffn1_in_scale[i],
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
      } else {
        auto *ln_scale_data = ln_scales[i]->data<U>();
        auto *ln_bias_data = ln_biases[i]->data<U>();
        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
        auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
        fused_dropout_layernorm_helper_for_post_layernorm
            .LayernormResidualDropoutBias(dev_ctx,
                                          buf0->data<T>(),
                                          residual_data,
                                          out_linear_bias_data,
                                          ln_scale_data,
                                          ln_bias_data,
                                          buf0->data<T>(),
                                          dropout_mask_out_data,
                                          buf1->data<T>(),
                                          ln_mean_data,
                                          ln_var_data);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step5";
#endif

      // step6. ffn matmul1
      if (pre_layer_norm) {
        ffn1_linear_compute.ComputeForwardINT8ToINT8(ffn1_weights[i],
                                                     &input_workspace,
                                                     nullptr,
                                                     &output_workspace,
                                                     nullptr);
      } else {
        ffn1_linear_compute.ComputeForward(ffn1_weights[i],
                                           buf1,
                                           &input_workspace,
                                           nullptr,
                                           &ffn1_out,
                                           &output_workspace,
                                           nullptr,
                                           ffn1_in_scale[i],
                                           ffn1_out_scale,
                                           i * ffn1_out_scale_n,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step6";
#endif

      // step7. act bias
      // TODO(wangxi): remove dropout mask in inference
      // NOTE(review): the activation is hard-coded to "gelu"; the op's
      // `act_method` attribute is not read by this kernel -- confirm this is
      // intended.
      if (pre_layer_norm) {
        fused_act_dropout_helper.DropoutActBias(
            dev_ctx,
            output_workspace.data<int32_t>(),
            ffn1_biases[i]->data<T>(),
            "gelu",
            input_workspace.data<int8_t>(),
            ffn1_dropout_mask_data,
            ffn1_in_scale[i],
            ffn1_out_scale->data<float>(),
            i * ffn1_out_scale_n,
            ffn2_in_scale[i],
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
      } else {
        fused_act_dropout_helper_for_post_layernorm.DropoutActBias(
            dev_ctx,
            ffn1_out_data,
            ffn1_biases[i]->data<T>(),
            "gelu",
            ffn1_dropout_out_data,
            ffn1_dropout_mask_data);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step7";
#endif

      // step8. ffn matmul2
      if (pre_layer_norm) {
        ffn2_linear_compute.ComputeForwardINT8ToINT8(ffn2_weights[i],
                                                     &input_workspace,
                                                     nullptr,
                                                     &output_workspace,
                                                     nullptr);
      } else {
        ffn2_linear_compute.ComputeForward(ffn2_weights[i],
                                           &ffn1_dropout_out,
                                           &input_workspace,
                                           nullptr,
                                           buf0,
                                           &output_workspace,
                                           nullptr,
                                           ffn2_in_scale[i],
                                           ffn2_out_scale,
                                           i * ffn2_out_scale_n,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step8.0";
#endif

      if (pre_layer_norm) {
        // NOTE(review): same element count as after out_linear; the ffn2 GEMM
        // is built with n = dim_embed -- confirm for ring_id != -1.
        AllReduce<int32_t>(output_workspace,
                           ring_id,
                           bsz * seq_len * num_head * dim_head,
                           dev_ctx);
      } else {
        AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step8.1";
#endif

      // step9. residual bias
      if (pre_layer_norm) {
        // TODO(wangxi): remove dropout mask in inference
        if (i < layers - 1) {
          // Not the last layer: also apply the next layer's layer norm and
          // quantize with the next layer's qkv input scale.
          auto *ln_scale_data = ln_scales[i + 1]->data<U>();
          auto *ln_bias_data = ln_biases[i + 1]->data<U>();

          ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
              dev_ctx,
              output_workspace.data<int32_t>(),
              bias_dropout_residual_out_data,
              ffn2_biases[i]->data<T>(),
              ln_scale_data,
              ln_bias_data,
              buf1->data<T>(),
              dropout_mask_out_data,
              input_workspace.data<int8_t>(),
              ln_mean_data,
              ln_var_data,
              ffn2_in_scale[i],
              ffn2_out_scale->data<float>(),
              i * ffn2_out_scale_n,
              qkv_in_scale[i + 1],
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
        } else {
          // Last layer: write the result of type T into buf1 (i.e. Out).
          ffn2_fused_dropout_dequant_helper.ResidualDropoutBias(
              dev_ctx,
              output_workspace.data<int32_t>(),
              bias_dropout_residual_out_data,
              ffn2_biases[i]->data<T>(),
              buf1->data<T>(),
              dropout_mask_out_data,
              ffn2_in_scale[i],
              ffn2_out_scale->data<float>(),
              i * ffn2_out_scale_n,
              1.0);
        }
      } else {
        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
        ffn2_fused_dropout_helper_for_post_layernorm
            .LayernormResidualDropoutBias(dev_ctx,
                                          buf0->data<T>(),
                                          buf1->data<T>(),
                                          ffn2_biases[i]->data<T>(),
                                          ln_scale_data,
                                          ln_bias_data,
                                          buf0->data<T>(),
                                          dropout_mask_out_data,
                                          buf1->data<T>(),
                                          ln_mean_data,
                                          ln_var_data);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step9";
#endif
      // On the pre_layer_norm path the next layer's residual input is the
      // buffer written by step9.
      if (pre_layer_norm) {
        x_data = buf1->data<T>();
      }
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Registers the CUDA kernel of `fused_multi_transformer_int8` for the
// float16 and float instantiations of FusedMultiTransformerINT8OpKernel.
REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_int8,
ops::FusedMultiTransformerINT8OpKernel<plat::float16>,
ops::FusedMultiTransformerINT8OpKernel<float>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// for debug
// #define _DEBUG_FUSED_MULTI_TRANSFORMER
// Sum-reduces `tensor` across the ranks of communication ring `ring_id`.
//
// ring_id == -1 turns the call into a no-op. If a process group is registered
// for ring_id, its AllReduce is issued and waited on; otherwise ncclAllReduce
// is enqueued on ctx's stream using the communicator looked up for ring_id.
// In both paths `tensor` supplies the input as well as the output buffer.
// When built without NCCL/RCCL, any ring_id other than -1 raises
// Unimplemented.
template <typename T>
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) {
    return;
  }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto group_map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (!group_map->has(ring_id)) {
    // No process group for this ring: call NCCL directly.
    auto nccl_dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    const int64_t elem_count = tensor.numel();
    const void *send_ptr = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recv_ptr = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::ncclAllReduce(send_ptr,
                                         recv_ptr,
                                         elem_count,
                                         nccl_dtype,
                                         ncclSum,
                                         comm->comm(),
                                         stream));
    return;
  }

  // Process-group path: the same tensor is pushed as input and as output.
  paddle::distributed::ProcessGroup *group = group_map->get(ring_id);
  std::vector<phi::DenseTensor> inputs;
  std::vector<phi::DenseTensor> outputs;
  inputs.push_back(tensor);
  outputs.push_back(tensor);
  paddle::distributed::AllreduceOptions options;
  options.reduce_op = distributed::ReduceOp::SUM;
  auto task = group->AllReduce(inputs, outputs, options);
  task->Wait();
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}
namespace {
namespace plat = paddle::platform;
using float16 = plat::float16;
// Compile-time switches for the masked multi-head attention kernel below.
// MMHA_USE_FP32_ACUM_FOR_LOGITS: the softmax logits are read back as float
// when accumulating values, and smem_size_in_bytes reserves no separate
// T-typed logits buffer.
// MMHA_USE_FP32_ACUM_FOR_OUT: the value accumulator uses the fp32 type given
// by V_vec_acum_fp32_ instead of V_vec itself.
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
// Argument bundle passed by value to masked_multihead_attention_kernel.
// The layouts noted on the fields describe a step with seq_len == 1.
template <typename T>
struct Masked_multihead_attention_params {
// output buffer, [B, 1(seq_len), num_head * dim_head]
T *out;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const T *qkv;
// bias, [3, num_head, dim_head]
const T *qkv_bias;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const T *attn_mask;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
int batch_size;
int num_head;
// Number of steps already stored in the cache; the kernel writes the current
// step's k/v at this index.
int timestep; // cache_seq_length
int max_seq_length;
// 1.f / sqrt(Dh)
float inv_sqrt_dh;
};
// Eight floats stored as four float2 pairs. It is the fp32 counterpart of a
// uint4 holding eight packed halves (see cast_to_float(uint4) and
// V_vec_acum_fp32_<uint4>).
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// clang-format off
// Qk_vec_<T, Dh_MAX>::Type is the vector type each thread uses to load a
// slice of the current step's q/k in the attention kernel. For float16 the
// slice is carried in unsigned integer vectors (uint32_t holds two halves).
// Only the (T, Dh_MAX) pairs listed here are usable.
template <typename T, int Dh> struct Qk_vec_ {};
template <> struct Qk_vec_<float, 32> { using Type = float; };
template <> struct Qk_vec_<float, 64> { using Type = float2; };
template <> struct Qk_vec_<float, 128> { using Type = float4; };
template <> struct Qk_vec_<float, 256> { using Type = float4; };
template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
template <> struct Qk_vec_<float16, 256> { using Type = uint4; };
// K_vec_<T, THREADS_PER_KEY>::Type is the vector type used to read cached
// keys. Every listed combination makes THREADS_PER_KEY * sizeof(Type) equal
// to 16 bytes.
template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
template <> struct K_vec_<float, 4> { using Type = float; };
template <> struct K_vec_<float, 2> { using Type = float2; };
template <> struct K_vec_<float, 1> { using Type = float4; };
template <> struct K_vec_<float16, 4> { using Type = uint32_t; };
template <> struct K_vec_<float16, 2> { using Type = uint2; };
template <> struct K_vec_<float16, 1> { using Type = uint4; };
// V_vec_<T, V_VEC_SIZE>::Type is the vector type holding V_VEC_SIZE elements
// of T, used to read and write values. float16 elements are packed two per
// 32-bit word.
template <typename T, int V_VEC_SIZE> struct V_vec_ {};
template <> struct V_vec_<float, 1> { using Type = float; };
template <> struct V_vec_<float, 2> { using Type = float2; };
template <> struct V_vec_<float, 4> { using Type = float4; };
template <> struct V_vec_<float16, 2> { using Type = uint32_t; };
template <> struct V_vec_<float16, 4> { using Type = uint2; };
template <> struct V_vec_<float16, 8> { using Type = uint4; };
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// Maps a value vector type to the fp32 type used to accumulate it when
// MMHA_USE_FP32_ACUM_FOR_OUT is defined. Only float4 and uint4 are enabled;
// the commented-out lines are mappings that are currently disabled.
template <typename T> struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template <> struct V_vec_acum_fp32_<float4> { using Type = float4; };
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template <> struct V_vec_acum_fp32_<uint4> { using Type = Float8_; };
#endif
// clang-format on
// Widens one fp16 value, given as its raw 16-bit pattern, to float.
inline __device__ float half_to_float(uint16_t h) {
  float widened;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(widened) : "h"(h));
  return widened;
}
// Splits a 32-bit word holding two fp16 values and widens both to float;
// the low 16 bits become .x and the high 16 bits become .y.
inline __device__ float2 half2_to_float2(uint32_t v) {
  uint16_t low_half, high_half;
  asm volatile("mov.b32 {%0, %1}, %2;\n"
               : "=h"(low_half), "=h"(high_half)
               : "r"(v));
  float2 unpacked;
  unpacked.x = half_to_float(low_half);
  unpacked.y = half_to_float(high_half);
  return unpacked;
}
// Rounds two floats to fp16 (round-to-nearest) and packs them into one 32-bit
// word: f.x in the low 16 bits, f.y in the high 16 bits. On sm_80 and newer a
// single packed conversion is used; older targets convert each half on its
// own.
inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } packed;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
               : "=r"(packed.u32)
               : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(packed.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(packed.u16[1]) : "f"(f.y));
#endif
  return packed.u32;
}
// Element-wise addition overloads.
// float/float2/float4 operands are plain fp32 lanes. uint16_t, uint32_t,
// uint2 and uint4 operands carry 1, 2, 4 and 8 packed fp16 values and are
// added with fp16 instructions. The two mixed overloads at the end widen the
// packed-half operand to fp32 before adding.
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  return make_float2(add(a.x, b.x), add(a.y, b.y));
}

inline __device__ float4 add(float4 a, float4 b) {
  return make_float4(
      add(a.x, b.x), add(a.y, b.y), add(a.z, b.z), add(a.w, b.w));
}

inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t result;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(result) : "h"(a), "h"(b));
  return result;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t result;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(result) : "r"(a), "r"(b));
  return result;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  return make_uint2(add(a.x, b.x), add(a.y, b.y));
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  return make_uint4(
      add(a.x, b.x), add(a.y, b.y), add(a.z, b.z), add(a.w, b.w));
}

// Packed halves + fp32 pair -> fp32 pair.
inline __device__ float2 add(uint32_t a, float2 fb) {
  return add(half2_to_float2(a), fb);
}

// Eight packed halves + eight floats -> eight floats.
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ widened_sum;
  widened_sum.x = add(a.x, fb.x);
  widened_sum.y = add(a.y, fb.y);
  widened_sum.z = add(a.z, fb.z);
  widened_sum.w = add(a.w, fb.w);
  return widened_sum;
}
// Element-wise multiplication. Acc is the result type, A and B the operand
// types; only the explicit specializations below are defined. Unsigned
// integer operands carry packed fp16 values (two per 32-bit word).
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

// ---- vector * vector -------------------------------------------------------
template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  return make_float2(a.x * b.x, a.y * b.y);
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t product;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(product) : "h"(a), "h"(b));
  return product;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t product;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(product) : "r"(a), "r"(b));
  return product;
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  return make_uint2(mul<uint32_t, uint32_t, uint32_t>(a.x, b.x),
                    mul<uint32_t, uint32_t, uint32_t>(a.y, b.y));
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  return make_uint4(mul<uint32_t, uint32_t, uint32_t>(a.x, b.x),
                    mul<uint32_t, uint32_t, uint32_t>(a.y, b.y),
                    mul<uint32_t, uint32_t, uint32_t>(a.z, b.z),
                    mul<uint32_t, uint32_t, uint32_t>(a.w, b.w));
}

// ---- vector * fp32 scalar --------------------------------------------------
// Packed halves are widened to fp32, scaled, then rounded back to fp16.
template <>
inline __device__ uint32_t mul(uint32_t a, float b) {
  const float2 widened = half2_to_float2(a);
  return float2_to_half2(make_float2(widened.x * b, widened.y * b));
}

template <>
inline __device__ uint2 mul(uint2 a, float b) {
  return make_uint2(mul<uint32_t, uint32_t, float>(a.x, b),
                    mul<uint32_t, uint32_t, float>(a.y, b));
}

template <>
inline __device__ uint4 mul(uint4 a, float b) {
  return make_uint4(mul<uint32_t, uint32_t, float>(a.x, b),
                    mul<uint32_t, uint32_t, float>(a.y, b),
                    mul<uint32_t, uint32_t, float>(a.z, b),
                    mul<uint32_t, uint32_t, float>(a.w, b));
}

template <>
inline __device__ float2 mul(float2 a, float b) {
  return make_float2(a.x * b, a.y * b);
}

template <>
inline __device__ float4 mul(float4 a, float b) {
  return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
}
// Horizontal sum of all lanes of a vector, returned as float.
// For the packed-half overloads (uint32_t/uint2/uint4) the words are first
// added pairwise in fp16, and the last remaining pair is widened to fp32 and
// summed.
inline __device__ float sum(float v) { return v; }

inline __device__ float sum(float2 v) { return v.x + v.y; }

inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }

inline __device__ float sum(uint16_t v) { return half_to_float(v); }

inline __device__ float sum(uint32_t v) {
  const float2 halves = half2_to_float2(v);
  return halves.x + halves.y;
}

inline __device__ float sum(uint2 v) { return sum(add(v.x, v.y)); }

inline __device__ float sum(uint4 v) {
  uint32_t folded = add(v.x, v.y);
  folded = add(folded, v.z);
  folded = add(folded, v.w);
  return sum(folded);
}
// Dot product of two vectors of type T: element-wise product followed by a
// horizontal sum. The product is formed in T itself.
template <typename T>
inline __device__ float dot(T a, T b) {
  const T product = mul<T, T, T>(a, b);
  return sum(product);
}

// Same as above, but the element-wise product is formed in type A.
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  const A product = mul<A, T, T>(a, b);
  return sum(product);
}
// Builds a lane mask with the lowest `threads` bits set; 32 yields all bits
// set (the shift by 32 is avoided explicitly).
inline __device__ constexpr uint32_t shfl_mask(int threads) {
  return threads != 32 ? (1u << threads) - 1u : uint32_t(-1);
}
// Computes (m + n - 1) / n, i.e. m / n rounded up when m is non-negative and
// n is positive.
template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
  const auto numerator = m + n - 1;
  return numerator / n;
}
// Multiply-add overloads computing a * b + c per lane.
// float/float2/float4 work on fp32 lanes; uint32_t/uint2/uint4 hold packed
// fp16 pairs and use the fp16x2 fused instruction. The overloads taking a
// float scalar `a` broadcast it over every lane of b.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  return make_float2(fma(a.x, b.x, c.x), fma(a.y, b.y, c.y));
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  return make_float4(fma(a.x, b.x, c.x),
                     fma(a.y, b.y, c.y),
                     fma(a.z, b.z, c.z),
                     fma(a.w, b.w, c.w));
}

inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t result;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(result)
               : "r"(a), "r"(b), "r"(c));
  return result;
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  return make_uint2(fma(a.x, b.x, c.x), fma(a.y, b.y, c.y));
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  return make_uint4(fma(a.x, b.x, c.x),
                    fma(a.y, b.y, c.y),
                    fma(a.z, b.z, c.z),
                    fma(a.w, b.w, c.w));
}

// ---- fp32 scalar broadcast -------------------------------------------------
inline __device__ float2 fma(float a, float2 b, float2 c) {
  return make_float2(fma(a, b.x, c.x), fma(a, b.y, c.y));
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  return make_float4(
      fma(a, b.x, c.x), fma(a, b.y, c.y), fma(a, b.z, c.z), fma(a, b.w, c.w));
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ result;
  result.x = fma(a, b.x, c.x);
  result.y = fma(a, b.y, c.y);
  result.z = fma(a, b.z, c.z);
  result.w = fma(a, b.w, c.w);
  return result;
}
// Duplicates one 16-bit value into both halves of a 32-bit word.
inline __device__ uint32_t h0_h0(uint16_t a) {
  uint32_t broadcast;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(broadcast) : "h"(a));
  return broadcast;
}
// Multiply-add with a single fp16 scalar `a` (raw 16-bit pattern). The scalar
// is duplicated into a packed pair once and then applied to every packed word
// of b and c.
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  const uint32_t scale = h0_h0(a);
  uint2 result;
  result.x = fma(scale, b.x, c.x);
  result.y = fma(scale, b.y, c.y);
  return result;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  const uint32_t scale = h0_h0(a);
  uint4 result;
  result.x = fma(scale, b.x, c.x);
  result.y = fma(scale, b.y, c.y);
  result.z = fma(scale, b.z, c.z);
  result.w = fma(scale, b.w, c.w);
  return result;
}
// Converts a value vector to its fp32 form. fp32 inputs pass through
// unchanged; a uint4 of eight packed halves is widened to a Float8_.
inline __device__ float cast_to_float(float u) { return u; }

inline __device__ float2 cast_to_float(float2 u) { return u; }

inline __device__ float4 cast_to_float(float4 u) { return u; }

inline __device__ Float8_ cast_to_float(uint4 u) {
  Float8_ widened;
  widened.x = half2_to_float2(u.x);
  widened.y = half2_to_float2(u.y);
  widened.z = half2_to_float2(u.z);
  widened.w = half2_to_float2(u.w);
  return widened;
}
// Computes dot(q * inv_sqrt_dh, k) for one key, with the vectors split over
// THREADS_PER_KEY cooperating threads. Each thread reduces its own N vector
// chunks; the partial results are then combined with XOR shuffles so every
// participating thread returns the full sum. q is scaled before the multiply.
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N],
                                const K_vec (&k)[N],
                                float inv_sqrt_dh) {
  K_vec scaled_q = mul<K_vec, K_vec, float>(q[0], inv_sqrt_dh);
  K_vec partial = mul<K_vec, K_vec, K_vec>(scaled_q, k[0]);
#pragma unroll
  for (int chunk = 1; chunk < N; ++chunk) {
    scaled_q = mul<K_vec, K_vec, float>(q[chunk], inv_sqrt_dh);
    partial = fma(scaled_q, k[chunk], partial);
  }

  // Per-thread horizontal sum, then butterfly reduction across the threads
  // that share this key.
  float qk = sum(partial);
#pragma unroll
  for (int offset = THREADS_PER_KEY / 2; offset >= 1; offset /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, offset);
  }
  return qk;
}
// Thin dispatch wrapper around qk_dot_: fixes THREADS_PER_KEY through the
// class template so call sites only pass the q/k chunks and the scale.
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
  template <typename K_vec, int N>
  static inline __device__ float dot(const K_vec (&q)[N],
                                     const K_vec (&k)[N],
                                     float inv_sqrt_dh) {
    const float scaled_dot = qk_dot_<THREADS_PER_KEY>(q, k, inv_sqrt_dh);
    return scaled_dot;
  }
};
// Sums `sum` over the calling threads and returns the total to each of them.
// Step 1 reduces inside every warp with XOR shuffles; lane 0 of each warp
// stores its total in red_smem[warp]. Step 2, after a barrier, has the first
// WARPS_PER_BLOCK lanes reload those totals and reduce them again; the result
// held by lane 0 is then broadcast. Contains a __syncthreads(), so it must be
// reached by all threads of the block.
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Step 1: intra-warp butterfly.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }
  if (lane_id == 0) {
    red_smem[warp_id] = sum;
  }
  __syncthreads();

  // Step 2: combine the per-warp totals.
  if (lane_id < WARPS_PER_BLOCK) {
    sum = red_smem[lane_id];
  }
#pragma unroll
  for (int offset = WARPS_PER_BLOCK / 2; offset >= 1; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }
  return __shfl_sync(uint32_t(-1), sum, 0);
}
// Stores an fp32 value (or vector) `src` into `dst`, narrowing to fp16 where
// the destination type requires it.
inline __device__ void convert_from_float(float &dst,  // NOLINT
                                          float src) {
  dst = src;
}

inline __device__ void convert_from_float(float4 &dst,  // NOLINT
                                          float4 src) {
  dst = src;
}

inline __device__ void convert_from_float(plat::float16 &dst,  // NOLINT
                                          float src) {
  dst = static_cast<plat::float16>(src);
}

// Eight floats -> eight halves packed into a uint4.
inline __device__ void convert_from_float(uint4 &dst,  // NOLINT
                                          Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}
// Sets `dst` to an all-zero bit pattern.
inline __device__ void zero(uint16_t &dst) {  // NOLINT
  dst = uint16_t(0);
}

// Generic version: clears the value one 32-bit word at a time, covering
// sizeof(T) / 4 words.
template <typename T>
inline __device__ void zero(T &dst) {  // NOLINT
  constexpr int kWordCount = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[kWordCount];
  } cleared;
#pragma unroll
  for (int w = 0; w < kWordCount; ++w) {
    cleared.words[w] = 0u;
  }
  dst = cleared.raw;
}
// Masked multi-head attention for one new token per sequence.
//
// Launch layout (see MMHA_LAUNCH_KERNEL): blockIdx.x selects the head and
// blockIdx.y the batch entry, so each block of THREADS_PER_BLOCK threads
// handles one (batch, head) pair. Dh is the real head size and Dh_MAX the
// padded size used to select vector types; lanes in [Dh, Dh_MAX) are
// zero-filled or skipped.
//
// Phases:
//   1. Load q and k of the current step, add their biases, keep q in shared
//      memory, write k into the cache at index params.timestep and compute
//      the current step's q.k score.
//   2. Compute the scaled q.k against every cached key, add the attention
//      mask and track the maximum score.
//   3. Softmax over the timestep + 1 scores.
//   4. Accumulate logit * v over the cached values plus the current step's v
//      (bias added, then written into the cache).
//   5. Reduce the partial outputs across thread groups and store the result
//      in params.out.
//
// The body is only compiled when CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
// holds; otherwise the kernel executes assert(false).
template <typename T,
int Dh,
int Dh_MAX,
int THREADS_PER_KEY,
int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
// Dynamic shared memory. qk_smem, logits_smem and out_smem below all alias
// the start of this one buffer and are used in successive phases.
extern __shared__ char smem_[];
float *qk_smem = reinterpret_cast<float *>(smem_);
char *logits_smem_ = smem_;
// fp32 accum for logits
float *logits_smem = reinterpret_cast<float *>(logits_smem_);
T *out_smem = reinterpret_cast<T *>(smem_);
// Scratch for block-level reductions: the first WARPS_PER_BLOCK slots are
// used for the max reduction, the slots from WARPS_PER_BLOCK on for
// block_sum in the softmax phase.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
// q of the current step (bias already added), shared by all threads.
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
const int bi = blockIdx.y;
const int hi = blockIdx.x;
const int bhi = bi * params.num_head + hi;
const int tid = threadIdx.x;
float qk_max = -FLT_MAX;
float qk = 0;
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
const T *q_base = params.qkv;
const T *k_base = params.qkv + params.num_head * Dh;
const T *q_bias_base = params.qkv_bias;
const T *k_bias_base = params.qkv_bias + params.num_head * Dh;
// Phase 1: the first QK_VECS_PER_WARP threads each handle one Qk_vec slice of
// the current step's q and k. Slices starting at or beyond Dh stay zero.
if (tid < QK_VECS_PER_WARP) {
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q;
zero(q);
q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
: q;
Qk_vec k;
zero(k);
k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
: k;
Qk_vec q_bias;
zero(q_bias);
q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
: q_bias;
Qk_vec k_bias;
zero(k_bias);
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
: k_bias;
q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k = add(k, k_bias);
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
// Position of this slice inside the key cache layout
// [B, num_head, head_dim / x, max_seq_len, x]: co picks the 16-byte group,
// ci the element offset inside it.
int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
qk = dot<Qk_vec, Qk_vec>(q, k);
// All slices fit in one warp: finish the q.k reduction with shuffles.
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
// The slices span more than one warp: reduce the partial sums across warps.
// Threads that loaded nothing contribute their initial qk of 0.
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED =
(QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
// Thread 0 scales the current step's score and stores it as the last entry.
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
}
__syncthreads();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======q_out=======\n");
for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
printf("\n");
}
__syncthreads();
#endif
// Phase 2: scores against the cached keys. THREADS_PER_KEY threads cooperate
// on one key; ko is the key handled by this thread in the first iteration and
// ki its element offset inside a 16-byte group.
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
// Each thread keeps its own chunks of q in registers for the whole loop.
K_vec q[K_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
q[i] = *reinterpret_cast<const K_vec *>(
&q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]);
}
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
// The trip count is rounded up to a multiple of K_PER_WARP, presumably so
// that every lane of a warp reaches the shuffles inside Qk_dot::dot. For
// ti >= timestep no key is loaded and the computed qk is discarded.
int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
}
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
// One thread per key adds the mask, updates the running max and stores the
// score.
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
qk += static_cast<float>(mask);
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
// Max reduction, step 1: across the key groups inside each warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
const int warp = tid / WARP_SIZE;
const int lane = tid % WARP_SIZE;
if (lane == 0) {
red_smem[warp] = qk_max;
}
__syncthreads();
// Max reduction, step 2: across warps, then broadcast from lane 0.
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======qk_out=======\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]);
printf("qk_max=%f\n", qk_max);
}
__syncthreads();
#endif
// Phase 3: softmax. Exponentiate the max-shifted scores in place and sum
// them over the block.
float sum = 0.f;
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
// Normalise; logits_smem aliases qk_smem, so this rewrites the same slots.
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
// Phase 4: weighted sum of values. THREADS_PER_VALUE threads cover one value
// row; vo is the row handled first by this thread and vi its element offset.
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE;
int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
// The value cache starts after the key cache, i.e. after
// batch_size * num_head * max_seq_length * Dh elements.
T *v_cache = &params.cache_kv[params.batch_size * params.num_head *
params.max_seq_length * Dh +
bhi * params.max_seq_length * Dh + vi];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
using V_vec_acum = V_vec;
#endif
V_vec_acum out;
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti];
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("======logits_out=====\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]);
printf("\n");
}
__syncthreads();
#endif
// Current step's value: handled by the thread group whose row index matches
// timestep % V_PER_ITER. The bias is added, the result is written into the
// cache at index timestep and folded into the accumulator.
V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[params.timestep], cast_to_float(v), out);
#else
out = fma(logits_smem[params.timestep], v, out);
#endif
}
// Barrier before out_smem (same buffer as logits_smem) is overwritten.
__syncthreads();
// Phase 5: tree reduction over the V_PER_ITER row groups. In every round the
// upper half of the active groups stores its partial output to out_smem and
// the lower half adds it to its own accumulator.
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out =
add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
}
// Group 0 holds the final result and writes it to the output buffer.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]) = out;
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
__syncthreads();
if (bi == 0 && hi == 0 && tid == 0) {
printf("======fmha_out=====\n");
for (int i = 0; i < Dh; ++i)
printf("%f ", static_cast<float>(params.out[i]));
printf("\n");
}
#endif
#else
assert(false);
#endif
}
// Returns the dynamic shared-memory size, in bytes, to request when launching
// masked_multihead_attention_kernel. The kernel reuses one buffer for the
// softmax scores and for the final output reduction, so the result is the
// larger of the two requirements:
//   - softmax: (timestep + 1) floats, rounded up to a multiple of 16 bytes,
//     plus a T-typed logits buffer when MMHA_USE_FP32_ACUM_FOR_LOGITS is not
//     defined and T is not 4 bytes wide;
//   - reduction: half of the (threads_per_block / threads_per_value) output
//     rows, each dim_head elements of T.
template <typename T>
inline size_t smem_size_in_bytes(
    const Masked_multihead_attention_params<T> &params,
    int dim_head,
    int threads_per_value,
    int threads_per_block) {
  size_t softmax_bytes = div_up(params.timestep + 1, 4) * 16;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
  if (sizeof(T) != 4) {
    softmax_bytes += div_up(params.max_seq_length, 4) * 4 * sizeof(T);
  }
#endif

  const int row_groups = threads_per_block / threads_per_value;
  size_t reduce_bytes = row_groups * dim_head * sizeof(T) / 2;

  return max(softmax_bytes, reduce_bytes);
}
// Launches masked_multihead_attention_kernel with one block per
// (head, batch) pair, THDS_PER_BLOCK threads per block and the dynamic
// shared-memory size computed by smem_size_in_bytes.
// Requirements on the call site: a variable named `params` must be in scope,
// the invocation must be followed by a semicolon, and because the macro
// declares the locals `smem_sz` and `grid` it can be used only once per
// scope.
#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
// Chooses the kernel configuration from the current cache length and launches
// it on `stream`:
//   timestep <  32   -> 4 threads per key,  64 threads per block
//   timestep <  2048 -> 2 threads per key, 128 threads per block
//   otherwise        -> 1 thread  per key, 256 threads per block
// THREADS_PER_VALUE is fixed so that each thread covers 16 bytes of a value
// row of Dh_MAX elements.
template <typename T, int Dh, int Dh_MAX>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
                        const cudaStream_t &stream) {
  constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
  if (params.timestep >= 2048) {
    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
  } else if (params.timestep >= 32) {
    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
  } else {
    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
  }
}
// Host entry point of the masked multi-head attention step.
// Packs the tensor pointers and sizes into Masked_multihead_attention_params
// and dispatches on dim_head to a kernel instantiation whose Dh_MAX is the
// padded head size (32, 64, 128 or 256). Supported head sizes are 10, 26, 32,
// 64, 96, 128 and 192; any other value raises Unimplemented.
// cache_kv_tensor and out_tensor are written by the kernel.
template <typename T>
void fmha(const phi::GPUContext &dev_ctx,
          const Tensor &qkv_tensor,
          const Tensor &qkv_bias_tensor,
          const Tensor &src_mask_tensor,
          Tensor *cache_kv_tensor,
          Tensor *out_tensor,
          int batch_size,
          int max_seq_length,
          int num_head,
          int dim_head,
          int timestep,
          float inv_sqrt_dh) {
  Masked_multihead_attention_params<T> params;
  // Sizes and scalars.
  params.batch_size = batch_size;
  params.num_head = num_head;
  params.timestep = timestep;
  params.max_seq_length = max_seq_length;
  params.inv_sqrt_dh = inv_sqrt_dh;
  // Device buffers.
  params.out = out_tensor->data<T>();
  params.qkv = qkv_tensor.data<T>();
  params.qkv_bias = qkv_bias_tensor.data<T>();
  params.attn_mask = src_mask_tensor.data<T>();
  params.cache_kv = cache_kv_tensor->data<T>();

  switch (dim_head) {
    case 10: {
      fmha_launch_kernel<T, 10, 32>(params, dev_ctx.stream());
      break;
    }
    case 26: {
      fmha_launch_kernel<T, 26, 32>(params, dev_ctx.stream());
      break;
    }
    case 32: {
      fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
      break;
    }
    case 64: {
      fmha_launch_kernel<T, 64, 64>(params, dev_ctx.stream());
      break;
    }
    case 96: {
      fmha_launch_kernel<T, 96, 128>(params, dev_ctx.stream());
      break;
    }
    case 128: {
      fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
      break;
    }
    case 192: {
      fmha_launch_kernel<T, 192, 256>(params, dev_ctx.stream());
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Dim_head = %d is unsupport!", dim_head));
    }
  }
}
// Width in bytes of one vectorized (128-bit) load/store used by the cache
// writers below. One such vector holds VEC_16B / sizeof(T) elements: 4 for
// float, 8 for float16.
constexpr int VEC_16B = 16;
// Copies keys into the key cache while transposing them, 16 bytes at a time.
//   source k      : [bsz, num_head, seq_len, dim_head/x, x]
//   destination   : [bsz, num_head, dim_head/x, max_seq_len, x]
// where x = VEC_16B / sizeof(T). blockIdx.y selects the batch entry and
// blockIdx.z the head; the x dimension of the grid enumerates destination
// vectors of one head. Destination slots with a sequence position at or
// beyond seq_len are left untouched.
template <typename T>
__global__ void write_cache_k_kernel(T *cache_k,
                                     const T *k,
                                     const int num_head,
                                     const int dim_head,
                                     const int seq_len,
                                     const int max_seq_len) {
  constexpr int X_ELEMS = VEC_16B / sizeof(T);
  const int batch_id = blockIdx.y;
  const int head_id = blockIdx.z;

  // Number of 16-byte vectors in one head row.
  const int vecs_per_row = dim_head / X_ELEMS;

  const int dst_vec = blockIdx.x * blockDim.x + threadIdx.x;
  // FIXME(wangxi): num_head is not need?
  // if (dst_vec >= num_head * vecs_per_row * max_seq_len) return;
  if (dst_vec >= vecs_per_row * max_seq_len) {
    return;
  }

  // Decompose the destination index into (vector id, sequence position).
  const int seq_pos = dst_vec % max_seq_len;
  const int vec_id = (dst_vec / max_seq_len) % vecs_per_row;
  if (seq_pos >= seq_len) {
    return;
  }

  const uint4 *src = reinterpret_cast<const uint4 *>(
      k + batch_id * num_head * seq_len * dim_head +
      head_id * seq_len * dim_head);
  uint4 *dst = reinterpret_cast<uint4 *>(
      cache_k + batch_id * num_head * max_seq_len * dim_head +
      head_id * max_seq_len * dim_head);
  dst[dst_vec] = src[seq_pos * vecs_per_row + vec_id];
}
// Copies values into the value cache, 16 bytes at a time, without changing
// their order inside a head.
//   source v      : [bsz, num_head, seq_len, dim_head/x, x]
//   destination   : [bsz, num_head, max_seq_len, dim_head/x, x]
// where x = VEC_16B / sizeof(T). blockIdx.y selects the batch entry and
// blockIdx.z the head; only the first seq_len rows of each head are written.
template <typename T>
__global__ void write_cache_v_kernel(T *cache_v,
                                     const T *v,
                                     const int num_head,
                                     const int dim_head,
                                     const int seq_len,
                                     const int max_seq_len) {
  constexpr int X_ELEMS = VEC_16B / sizeof(T);
  const int batch_id = blockIdx.y;
  const int head_id = blockIdx.z;

  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int vecs_per_row = dim_head / X_ELEMS;
  if (vec_idx >= vecs_per_row * seq_len) {
    return;
  }

  const uint4 *src = reinterpret_cast<const uint4 *>(
      v + batch_id * num_head * seq_len * dim_head +
      head_id * seq_len * dim_head);
  uint4 *dst = reinterpret_cast<uint4 *>(
      cache_v + batch_id * num_head * max_seq_len * dim_head +
      head_id * max_seq_len * dim_head);
  dst[vec_idx] = src[vec_idx];
}
// Writes the keys and values of `seq_len` positions into the caches.
//
// cache_k / cache_v : destination caches (written on dev_ctx's stream).
// k / v             : source buffers, [bsz, num_head, seq_len, dim_head].
// Keys are transposed into [bsz, num_head, dim_head/x, max_seq_len, x] and
// values are copied into [bsz, num_head, max_seq_len, dim_head/x, x], where
// x = VEC_16B / sizeof(T) is the number of elements per 16-byte vector.
//
// Raises PreconditionNotMet when dim_head is not a multiple of x. The
// divisibility check is done only through PADDLE_ENFORCE_EQ: a preceding
// assert() would abort debug builds before this descriptive error could be
// reported and adds nothing in release builds.
template <typename T>
void write_cache_kv(const phi::GPUContext &dev_ctx,
                    T *cache_k,
                    T *cache_v,
                    const T *k,
                    const T *v,
                    const int bsz,
                    const int num_head,
                    const int seq_len,
                    const int max_seq_len,
                    const int dim_head) {
  constexpr int block_sz = 128;
  constexpr int x = VEC_16B / sizeof(T);

  PADDLE_ENFORCE_EQ(
      dim_head % x,
      0,
      platform::errors::PreconditionNotMet(
          "dim_head=%d must be divisible by vec_size=%d", dim_head, x));

  // Number of 16-byte vectors per head in the cache and in the source.
  int max_size = max_seq_len * dim_head / x;
  int size = seq_len * dim_head / x;
  dim3 grid(div_up(max_size, block_sz), bsz, num_head);
  dim3 grid_v(div_up(size, block_sz), bsz, num_head);

  // transpose [bsz, num_head, seq_len, dim_head/x, x]->
  // [bsz, num_head, dim_head/x, max_seq_len, x]
  write_cache_k_kernel<<<grid, block_sz, 0, dev_ctx.stream()>>>(
      cache_k, k, num_head, dim_head, seq_len, max_seq_len);

  // copy [bsz, num_head, seq_len, dim_head/x, x]->
  // [bsz, num_head, max_seq_len, dim_head/x, x]
  write_cache_v_kernel<<<grid_v, block_sz, 0, dev_ctx.stream()>>>(
      cache_v, v, num_head, dim_head, seq_len, max_seq_len);
}
} // namespace
template <typename T>
class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
public:
......@@ -1480,11 +338,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
if (pre_layer_norm) {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
} else {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
AllReduce<T>(*buf0, ring_id, dev_ctx);
AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
......@@ -1563,9 +421,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
#endif
if (pre_layer_norm) {
AllReduce<T>(*buf1, ring_id, dev_ctx);
AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
} else {
AllReduce<T>(*buf0, ring_id, dev_ctx);
AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.1";
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// for debug
// #define _DEBUG_FUSED_MULTI_TRANSFORMER
// Sums `tensor` in place across the ranks of communication group `ring_id`.
//
//   tensor  - buffer that is both the input and the output of the reduction.
//   ring_id - communication group id; -1 means "not distributed" and makes the
//             call a no-op.
//   count   - number of elements handed to ncclAllReduce. It is only read on
//             the raw NCCL communicator path; the ProcessGroup path reduces
//             the tensor object as a whole.
//   ctx     - GPU context supplying the place and the stream for the NCCL call.
//
// Throws Unimplemented when Paddle was built without NCCL/RCCL and a real
// ring id is passed.
template <typename T>
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const int count,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (map->has(ring_id)) {
    // A ProcessGroup is registered for this id: let it run the collective and
    // block until the task has finished.
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(tensor);
    out_tensor.push_back(tensor);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
    task->Wait();
  } else {
    // Otherwise issue the collective directly on the NCCL communicator that
    // is registered for (ring_id, place), using the context's stream.
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = tensor.mutable_data<T>(place);
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}
namespace { // NOLINT
namespace plat = paddle::platform;
using float16 = plat::float16;
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
// Argument bundle passed by value to masked_multihead_attention_kernel. It
// describes one decoding step (a single new token) for every (batch, head)
// pair; fmha() fills every field before launching.
template <typename T>
struct Masked_multihead_attention_params {
// output buffer, [B, 1(seq_len), num_head * dim_head]
T *out;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const T *qkv;
// bias, [3, num_head, dim_head]
const T *qkv_bias;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const T *attn_mask;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
// Launch geometry: the kernel grid is (num_head, batch_size).
int batch_size;
int num_head;
// Number of tokens already stored in the cache; the new token's k/v are
// written at this position.
int timestep; // cache_seq_length
// Capacity of the cache along the sequence axis.
int max_seq_length;
// 1.f / sqrt(Dh)
float inv_sqrt_dh;
};
// Eight floats stored as four float2 pairs. It is the FP32 counterpart of a
// uint4 holding eight packed halves (see V_vec_acum_fp32_<uint4> and
// cast_to_float(uint4)).
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// clang-format off
// Vector-type lookup tables used by masked_multihead_attention_kernel.
// float16 data is carried in raw unsigned words: uint32_t holds 2 halves,
// uint2 holds 4 and uint4 holds 8.
//
// Qk_vec_<T, Dh_MAX>: vector each thread uses to load the current step's q/k.
template <typename T, int Dh> struct Qk_vec_ {};
template <> struct Qk_vec_<float, 32> { using Type = float; };
template <> struct Qk_vec_<float, 64> { using Type = float2; };
template <> struct Qk_vec_<float, 128> { using Type = float4; };
template <> struct Qk_vec_<float, 256> { using Type = float4; };
template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
template <> struct Qk_vec_<float16, 256> { using Type = uint4; };
// K_vec_<T, THREADS_PER_KEY>: vector used when reading cached keys. Every
// listed combination is 16 / THREADS_PER_KEY bytes wide.
template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
template <> struct K_vec_<float, 4> { using Type = float; };
template <> struct K_vec_<float, 2> { using Type = float2; };
template <> struct K_vec_<float, 1> { using Type = float4; };
template <> struct K_vec_<float16, 4> { using Type = uint32_t; };
template <> struct K_vec_<float16, 2> { using Type = uint2; };
template <> struct K_vec_<float16, 1> { using Type = uint4; };
// V_vec_<T, V_VEC_SIZE>: vector holding V_VEC_SIZE elements of a value row.
template <typename T, int V_VEC_SIZE> struct V_vec_ {};
template <> struct V_vec_<float, 1> { using Type = float; };
template <> struct V_vec_<float, 2> { using Type = float2; };
template <> struct V_vec_<float, 4> { using Type = float4; };
template <> struct V_vec_<float16, 2> { using Type = uint32_t; };
template <> struct V_vec_<float16, 4> { using Type = uint2; };
template <> struct V_vec_<float16, 8> { using Type = uint4; };
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// V_vec_acum_fp32_<V_vec>: FP32 accumulator type for a value vector. Only the
// 16-byte vectors (float4 and uint4) are enabled; fmha_launch_kernel picks
// THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16, so each thread always handles
// exactly 16 bytes of a value row and the other mappings stay commented out.
template <typename T> struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template <> struct V_vec_acum_fp32_<float4> { using Type = float4; };
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template <> struct V_vec_acum_fp32_<uint4> { using Type = Float8_; };
#endif
// clang-format on
// Widens one half-precision value, given as its raw 16-bit pattern, to float
// using the PTX cvt.f32.f16 instruction.
inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
// Unpacks a 32-bit word holding two halves into a float2. The first 16-bit
// piece produced by the mov.b32 unpack becomes .x and the second becomes .y.
inline __device__ float2 half2_to_float2(uint32_t v) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
// Rounds a float2 to two halves (round-to-nearest) and packs them into one
// 32-bit word: f.x lands in u16[0] and f.y in u16[1].
inline __device__ uint32_t float2_to_half2(float2 f) {
// The union lets the two 16-bit results be written separately and read back
// as a single 32-bit word.
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// sm_80 and newer: a single packed conversion; note the operand order
// (f.y first, then f.x).
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else
// Older architectures: convert each component on its own.
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
// Element-wise addition overloads.
//  - float / float2 / float4: plain FP32 addition per component.
//  - uint16_t / uint32_t / uint2 / uint4: the words hold 1 / 2 / 4 / 8 packed
//    halves and are added with the PTX f16 / f16x2 instructions.
//  - (uint32_t, float2) and (uint4, Float8_): the packed-half operand is
//    widened to FP32 first and the result stays in FP32.
// Vector overloads update their by-value first argument in place and return
// it, delegating each component to the matching narrower overload.
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  a.x = add(a.x, b.x);
  a.y = add(a.y, b.y);
  return a;
}

inline __device__ float4 add(float4 a, float4 b) {
  a.x = add(a.x, b.x);
  a.y = add(a.y, b.y);
  a.z = add(a.z, b.z);
  a.w = add(a.w, b.w);
  return a;
}

inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t result;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(result) : "h"(a), "h"(b));
  return result;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t result;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(result) : "r"(a), "r"(b));
  return result;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  a.x = add(a.x, b.x);
  a.y = add(a.y, b.y);
  return a;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  a.x = add(a.x, b.x);
  a.y = add(a.y, b.y);
  a.z = add(a.z, b.z);
  a.w = add(a.w, b.w);
  return a;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  return add(half2_to_float2(a), fb);
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  fb.x = add(a.x, fb.x);
  fb.y = add(a.y, fb.y);
  fb.z = add(a.z, fb.z);
  fb.w = add(a.w, fb.w);
  return fb;
}
// Element-wise multiplication. `Acc` is the result type, `A` and `B` are the
// operand types; only the explicit specializations below are defined.
//  - (vec, vec): component-wise product. Packed-half words use the PTX
//    mul.f16 / mul.f16x2 instructions.
//  - (vec, float): every component is scaled by the scalar. Packed halves are
//    widened to FP32, scaled, and rounded back to half.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  a.x = a.x * b.x;
  a.y = a.y * b.y;
  return a;
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  a.x = a.x * b.x;
  a.y = a.y * b.y;
  a.z = a.z * b.z;
  a.w = a.w * b.w;
  return a;
}

template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t product;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(product) : "h"(a), "h"(b));
  return product;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t product;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(product) : "r"(a), "r"(b));
  return product;
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  a.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  a.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return a;
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  a.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  a.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  a.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  a.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return a;
}

template <>
inline __device__ uint32_t mul(uint32_t a, float b) {
  // Scale in FP32, then round the pair back to packed halves.
  float2 scaled = half2_to_float2(a);
  scaled.x = scaled.x * b;
  scaled.y = scaled.y * b;
  return float2_to_half2(scaled);
}

template <>
inline __device__ uint2 mul(uint2 a, float b) {
  a.x = mul<uint32_t, uint32_t, float>(a.x, b);
  a.y = mul<uint32_t, uint32_t, float>(a.y, b);
  return a;
}

template <>
inline __device__ uint4 mul(uint4 a, float b) {
  a.x = mul<uint32_t, uint32_t, float>(a.x, b);
  a.y = mul<uint32_t, uint32_t, float>(a.y, b);
  a.z = mul<uint32_t, uint32_t, float>(a.z, b);
  a.w = mul<uint32_t, uint32_t, float>(a.w, b);
  return a;
}

template <>
inline __device__ float2 mul(float2 a, float b) {
  a.x = a.x * b;
  a.y = a.y * b;
  return a;
}

template <>
inline __device__ float4 mul(float4 a, float b) {
  a.x = a.x * b;
  a.y = a.y * b;
  a.z = a.z * b;
  a.w = a.w * b;
  return a;
}
// Horizontal sum of all components of a vector, returned as float.
// Packed-half vectors (uint2 / uint4) are first folded into a single half2
// word with f16x2 additions; only the final pair is widened to FP32.
inline __device__ float sum(float v) { return v; }

inline __device__ float sum(float2 v) { return v.x + v.y; }

inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }

inline __device__ float sum(uint16_t v) { return half_to_float(v); }

inline __device__ float sum(uint32_t v) {
  const float2 pair = half2_to_float2(v);
  return pair.x + pair.y;
}

inline __device__ float sum(uint2 v) {
  const uint32_t folded = add(v.x, v.y);
  return sum(folded);
}

inline __device__ float sum(uint4 v) {
  // Same left-to-right fold order as ((x + y) + z) + w.
  const uint32_t folded = add(add(add(v.x, v.y), v.z), v.w);
  return sum(folded);
}
// Dot product of two vectors of type T: multiply component-wise, then reduce
// the products with sum(). The products are kept in type T.
template <typename T>
inline __device__ float dot(T a, T b) {
  const T products = mul<T, T, T>(a, b);
  return sum(products);
}

// Same as above, but the component-wise products are formed in type A.
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  const A products = mul<A, T, T>(a, b);
  return sum(products);
}
// Builds a lane mask with the lowest `threads` bits set, for use as the
// participation mask of a warp shuffle.
//
// A warp mask has only 32 bits, so any request for 32 or more lanes returns
// the full mask. Comparing with `>=` (rather than `== 32`) keeps the shift
// `1u << threads` from ever being evaluated with a count of 32 or more, which
// would be undefined behavior. The kernel does instantiate this helper with
// values above 32 (e.g. 64 vectors per head) inside a branch that is skipped
// at run time.
inline __device__ constexpr uint32_t shfl_mask(int threads) {
  return threads >= 32 ? uint32_t(-1) : (1u << threads) - 1u;
}
// Integer division of m by n, rounded up. Callers in this file pass positive
// operands.
template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
  const auto padded = m + n - 1;
  return padded / n;
}
// Multiply-add overloads computing a * b + c per component.
//  - (vec, vec, vec): all three operands share the same vector type. Packed
//    halves use the PTX fma.rn.f16x2 instruction.
//  - (float, vec, vec): a scalar multiplier applied to every component of b.
// Vector overloads accumulate into their by-value `c` argument and return it.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  c.x = fma(a.x, b.x, c.x);
  c.y = fma(a.y, b.y, c.y);
  return c;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  c.x = fma(a.x, b.x, c.x);
  c.y = fma(a.y, b.y, c.y);
  c.z = fma(a.z, b.z, c.z);
  c.w = fma(a.w, b.w, c.w);
  return c;
}

inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t result;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(result)
               : "r"(a), "r"(b), "r"(c));
  return result;
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  c.x = fma(a.x, b.x, c.x);
  c.y = fma(a.y, b.y, c.y);
  return c;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  c.x = fma(a.x, b.x, c.x);
  c.y = fma(a.y, b.y, c.y);
  c.z = fma(a.z, b.z, c.z);
  c.w = fma(a.w, b.w, c.w);
  return c;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  c.x = fma(a, b.x, c.x);
  c.y = fma(a, b.y, c.y);
  return c;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  c.x = fma(a, b.x, c.x);
  c.y = fma(a, b.y, c.y);
  c.z = fma(a, b.z, c.z);
  c.w = fma(a, b.w, c.w);
  return c;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  c.x = fma(a, b.x, c.x);
  c.y = fma(a, b.y, c.y);
  c.z = fma(a, b.z, c.z);
  c.w = fma(a, b.w, c.w);
  return c;
}
// Duplicates one 16-bit half into both halves of a 32-bit word (PTX mov.b32
// pack), producing a packed pair that can feed the f16x2 instructions.
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
// Multiply-add with a single half-precision multiplier `a` (raw 16 bits) and
// packed-half vectors b, c. The scalar is duplicated into a half pair once and
// then combined with every 32-bit word of b and c.
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  const uint32_t a_pair = h0_h0(a);
  return fma(a_pair, b, c);
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  const uint32_t a_pair = h0_h0(a);
  c.x = fma(a_pair, b.x, c.x);
  c.y = fma(a_pair, b.y, c.y);
  return c;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  const uint32_t a_pair = h0_h0(a);
  c.x = fma(a_pair, b.x, c.x);
  c.y = fma(a_pair, b.y, c.y);
  c.z = fma(a_pair, b.z, c.z);
  c.w = fma(a_pair, b.w, c.w);
  return c;
}
// Widens a value vector to its FP32 form. FP32 inputs pass through unchanged;
// a uint4 of eight packed halves becomes a Float8_.
inline __device__ float cast_to_float(float u) { return u; }

inline __device__ float2 cast_to_float(float2 u) { return u; }

inline __device__ float4 cast_to_float(float4 u) { return u; }

inline __device__ Float8_ cast_to_float(uint4 u) {
  Float8_ widened;
  widened.w = half2_to_float2(u.w);
  widened.z = half2_to_float2(u.z);
  widened.y = half2_to_float2(u.y);
  widened.x = half2_to_float2(u.x);
  return widened;
}
// Computes dot(q * inv_sqrt_dh, k) for one key, where the head dimension is
// split over THREADS_PER_KEY neighbouring lanes and each lane holds N vectors.
// q is scaled before it is multiplied with k. After the per-lane partial sum,
// an XOR-shuffle exchange adds up the partials of the cooperating lanes so
// each of them returns the full dot product.
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N],
                                const K_vec (&k)[N],
                                float inv_sqrt_dh) {
  K_vec scaled_q = mul<K_vec, K_vec, float>(q[0], inv_sqrt_dh);
  K_vec acc = mul<K_vec, K_vec, K_vec>(scaled_q, k[0]);
#pragma unroll
  for (int i = 1; i < N; ++i) {
    scaled_q = mul<K_vec, K_vec, float>(q[i], inv_sqrt_dh);
    acc = fma(scaled_q, k[i], acc);
  }
  float partial = sum(acc);
#pragma unroll
  for (int lane_mask = THREADS_PER_KEY / 2; lane_mask > 0; lane_mask /= 2) {
    partial += __shfl_xor_sync(uint32_t(-1), partial, lane_mask);
  }
  return partial;
}
// Thin wrapper that forwards to qk_dot_<THREADS_PER_KEY>. The element type T
// is part of the template signature but is not used by the generic version
// shown here.
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
return qk_dot_<THREADS_PER_KEY>(q, k, inv_sqrt_dh);
}
};
// Block-wide sum of `sum` over WARPS_PER_BLOCK warps.
//  1. Each warp reduces its 32 lanes with XOR shuffles.
//  2. Lane 0 of each warp stores the warp total in red_smem[warp].
//  3. After a block barrier, the first WARPS_PER_BLOCK lanes of each warp pick
//     up the warp totals and reduce them with a second round of shuffles.
//  4. The value held by lane 0 is broadcast to the whole warp and returned.
// red_smem must provide at least WARPS_PER_BLOCK floats of shared memory.
// Contains __syncthreads(), so every thread of the block has to call it.
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }

  if (lane_id == 0) {
    red_smem[warp_id] = sum;
  }
  __syncthreads();

  if (lane_id < WARPS_PER_BLOCK) {
    sum = red_smem[lane_id];
  }

#pragma unroll
  for (int offset = WARPS_PER_BLOCK / 2; offset > 0; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }

  return __shfl_sync(uint32_t(-1), sum, 0);
}
// Stores an FP32 value (or FP32 vector) `src` into `dst`, narrowing to half
// precision when `dst` is a half type. FP32 destinations are plain copies.
inline __device__ void convert_from_float(float &dst, float src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT
float src) {
dst = static_cast<plat::float16>(src);
}
// Eight floats -> eight packed halves, one float2 pair per 32-bit word.
inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
// Sets `dst` to all-zero bits.
inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT
// Generic version: fills T with zeros one 32-bit word at a time through a
// union. WORDS is sizeof(T) / 4, so this relies on sizeof(T) being a non-zero
// multiple of 4 bytes.
template <typename T>
inline __device__ void zero(T &dst) { // NOLINT
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
// Masked multi-head attention for a single decoding step.
//
// One thread block handles one (batch, head) pair: blockIdx.x is the head and
// blockIdx.y is the batch index. The block
//   1. loads q and k of the new token, adds their biases, stores q in shared
//      memory and appends k to the key cache at position `timestep`;
//   2. computes the scaled q.k logit for every cached key and for the new key;
//   3. turns the logits into softmax weights (max-subtracted, FP32);
//   4. accumulates the weighted sum of cached values plus the new token's
//      value (which is also appended to the value cache);
//   5. reduces the partial sums across thread groups and writes the result to
//      params.out.
//
// Template parameters:
//   Dh                - actual head size.
//   Dh_MAX            - padded head size used for the thread layout; threads
//                       whose elements fall at or beyond Dh are masked off.
//   THREADS_PER_KEY   - lanes cooperating on one key in step 2.
//   THREADS_PER_VALUE - lanes cooperating on one value row in step 4.
//   THREADS_PER_BLOCK - block size used at launch.
//
// The dynamic shared memory buffer `smem_` is reused for the logits
// (qk_smem / logits_smem) and later for the cross-group output reduction
// (out_smem); its size is computed by smem_size_in_bytes().
template <typename T,
int Dh,
int Dh_MAX,
int THREADS_PER_KEY,
int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
// All three views below alias the start of the same dynamic shared buffer.
extern __shared__ char smem_[];
float *qk_smem = reinterpret_cast<float *>(smem_);
char *logits_smem_ = smem_;
// fp32 accum for logits
float *logits_smem = reinterpret_cast<float *>(logits_smem_);
T *out_smem = reinterpret_cast<T *>(smem_);
// Scratch for block-level reductions; the second half is handed to
// block_sum().
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
// q (with bias) of the current token, shared by all threads of the block.
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
const int bi = blockIdx.y;
const int hi = blockIdx.x;
const int bhi = bi * params.num_head + hi;
const int tid = threadIdx.x;
float qk_max = -FLT_MAX;
float qk = 0;
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
const T *q_base = params.qkv;
const T *k_base = params.qkv + params.num_head * Dh;
const T *q_bias_base = params.qkv_bias;
const T *k_bias_base = params.qkv_bias + params.num_head * Dh;
// Step 1: the first QK_VECS_PER_WARP threads each load one vector of q and k
// for the new token. Vectors beyond Dh (padding up to Dh_MAX) stay zero.
if (tid < QK_VECS_PER_WARP) {
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q;
zero(q);
q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
: q;
Qk_vec k;
zero(k);
k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
: k;
Qk_vec q_bias;
zero(q_bias);
q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
: q_bias;
Qk_vec k_bias;
zero(k_bias);
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
: k_bias;
q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k = add(k, k_bias);
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
// Locate this vector inside the key cache layout
// [.., head_dim / x, max_seq_len, x]: `co` selects the 16-byte chunk of the
// head, `ci` the element offset inside that chunk.
int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
// Append the new key at position `timestep` (padding chunks are skipped).
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
// Partial q.k for the new token; reduced inside the warp when all
// participating threads fit into one warp.
qk = dot<Qk_vec, Qk_vec>(q, k);
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
// More vectors than lanes in a warp: fall back to a block-level reduction.
// Threads that did not load anything contribute their initial qk of 0.
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED =
(QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
// Thread 0 owns the logit of the new token: scale it and store it as the
// last entry of the logits buffer.
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
}
// Make q_smem (and the new logit) visible to the whole block.
__syncthreads();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======q_out=======\n");
for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
printf("\n");
}
__syncthreads();
#endif
// Step 2: logits against the cached keys. THREADS_PER_KEY consecutive threads
// share one key; `ko` is the key handled by this thread in the first
// iteration and `ki` the element offset of its slice.
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
// Each thread keeps its slice of q in registers for the whole key loop.
K_vec q[K_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
q[i] = *reinterpret_cast<const K_vec *>(
&q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]);
}
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
// The loop bound is rounded up to a multiple of K_PER_WARP so that all lanes
// of a warp run the same number of iterations (the dot product below uses
// full-warp shuffles).
int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
// Only real timesteps are loaded; for ti >= timestep `k` is left unset and
// the resulting qk is discarded by the guard further down.
if (ti < params.timestep) {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
}
// NOTE(liyurui): We should multiply q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
// The first thread of each key group adds the attention mask, tracks the
// running maximum and stores the logit.
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
qk += static_cast<float>(mask);
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
// Step 3a: maximum logit. First across the key groups of a warp ...
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
const int warp = tid / WARP_SIZE;
const int lane = tid % WARP_SIZE;
// ... then across warps through red_smem, and finally broadcast from lane 0.
if (lane == 0) {
red_smem[warp] = qk_max;
}
__syncthreads();
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======qk_out=======\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]);
printf("qk_max=%f\n", qk_max);
}
__syncthreads();
#endif
// Step 3b: softmax over timestep + 1 logits (cached keys plus the new one).
// Exponentiate with the maximum subtracted and accumulate the denominator.
float sum = 0.f;
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
// Normalize in place; logits_smem and qk_smem are the same FP32 buffer.
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
// Step 4: weighted sum of values. THREADS_PER_VALUE threads share one value
// row; `vo` is the row handled in the first iteration and `vi` the element
// offset of this thread's slice.
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE;
int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
// The value cache follows the key cache inside cache_kv.
T *v_cache = &params.cache_kv[params.batch_size * params.num_head *
params.max_seq_length * Dh +
bhi * params.max_seq_length * Dh + vi];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
using V_vec_acum = V_vec;
#endif
V_vec_acum out;
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
// Accumulate the cached rows 0 .. timestep - 1.
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti];
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("======logits_out=====\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]);
printf("\n");
}
__syncthreads();
#endif
// The value of the new token is not in the cache yet: the group whose turn it
// would be for row `timestep` loads it from qkv, adds the bias, appends it to
// the cache and folds it into the accumulator.
V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[params.timestep], cast_to_float(v), out);
#else
out = fma(logits_smem[params.timestep], v, out);
#endif
}
// All reads of the logits are done; the shared buffer is reused as out_smem
// from here on.
__syncthreads();
// Step 5: combine the partial outputs of the V_PER_ITER row groups by halving
// the number of active groups each round: the upper half writes to shared
// memory, the lower half adds it to its own accumulator.
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out =
add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
}
// Group 0 holds the final result and writes it to the output buffer.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]) = out;
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
__syncthreads();
if (bi == 0 && hi == 0 && tid == 0) {
printf("======fmha_out=====\n");
for (int i = 0; i < Dh; ++i)
printf("%f ", static_cast<float>(params.out[i]));
printf("\n");
}
#endif
#else
// Built for an architecture without the required FP16 support: the kernel
// body is compiled out and any launch trips this assertion.
assert(false);
#endif
}
// Returns the number of bytes of dynamic shared memory that
// masked_multihead_attention_kernel needs for `params`.
//
// The same buffer is used first for the softmax logits and later for the
// cross-group output reduction, so the result is the larger of the two needs:
//  - softmax: one float per position 0..timestep, rounded up to a multiple of
//    16 bytes (plus a T-typed logits area when FP32 logit accumulation is
//    disabled and T is not 4 bytes wide);
//  - reduction: half of the value groups each store dim_head elements of T.
template <typename T>
inline size_t smem_size_in_bytes(
    const Masked_multihead_attention_params<T> &params,
    int dim_head,
    int threads_per_value,
    int threads_per_block) {
  size_t softmax_bytes = div_up(params.timestep + 1, 4) * 16;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS  // NOLINT
  if (sizeof(T) != 4) {
    softmax_bytes += div_up(params.max_seq_length, 4) * 4 * sizeof(T);
  }
#endif  // NOLINT

  const int value_groups = threads_per_block / threads_per_value;
  const size_t reduce_bytes = value_groups * dim_head * sizeof(T) / 2;

  return max(softmax_bytes, reduce_bytes);
}
// Launches masked_multihead_attention_kernel for one configuration.
// Expects a variable named `params` in the enclosing scope. The expansion
// declares two locals (`smem_sz`, `grid`), so it must be used at most once per
// scope. The grid is (num_head, batch_size), i.e. one block per
// (head, batch) pair.
#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
// Picks the thread layout for a given head size and launches the kernel on
// `stream`.
// THREADS_PER_VALUE is chosen so that each thread covers 16 bytes of a value
// row. THREADS_PER_KEY and the block size depend on how many timesteps are
// already cached: 4 / 64 below 32 steps, 2 / 128 below 2048 steps, and
// 1 / 256 otherwise.
template <typename T, int Dh, int Dh_MAX>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream) {
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
} else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
}
}
// Host entry point of the single-step masked multi-head attention.
//
// Gathers the raw device pointers and scalar settings into a parameter bundle
// and dispatches to the kernel instantiation that matches `dim_head`. Head
// sizes that are not a supported power-of-two width are run on the next
// larger padded width (e.g. 96 -> 128).
//
//   qkv_tensor / qkv_bias_tensor - fused q, k, v of the new token and biases.
//   src_mask_tensor              - additive attention mask.
//   cache_kv_tensor              - key/value cache, read and appended to.
//   out_tensor                   - receives the attention output.
//
// Throws Unimplemented for any other `dim_head`.
template <typename T>
void fmha(const phi::GPUContext &dev_ctx,
          const Tensor &qkv_tensor,
          const Tensor &qkv_bias_tensor,
          const Tensor &src_mask_tensor,
          Tensor *cache_kv_tensor,
          Tensor *out_tensor,
          int batch_size,
          int max_seq_length,
          int num_head,
          int dim_head,
          int timestep,
          float inv_sqrt_dh) {
  Masked_multihead_attention_params<T> args;

  // Device buffers.
  args.out = out_tensor->data<T>();
  args.qkv = qkv_tensor.data<T>();
  args.qkv_bias = qkv_bias_tensor.data<T>();
  args.attn_mask = src_mask_tensor.data<T>();
  args.cache_kv = cache_kv_tensor->data<T>();

  // Problem geometry and scaling factor.
  args.batch_size = batch_size;
  args.num_head = num_head;
  args.timestep = timestep;
  args.max_seq_length = max_seq_length;
  args.inv_sqrt_dh = inv_sqrt_dh;

  const auto stream = dev_ctx.stream();

  // <actual head size, padded head size> pairs with a compiled kernel.
  switch (dim_head) {
    case 10:
      fmha_launch_kernel<T, 10, 32>(args, stream);
      break;
    case 26:
      fmha_launch_kernel<T, 26, 32>(args, stream);
      break;
    case 32:
      fmha_launch_kernel<T, 32, 32>(args, stream);
      break;
    case 64:
      fmha_launch_kernel<T, 64, 64>(args, stream);
      break;
    case 96:
      fmha_launch_kernel<T, 96, 128>(args, stream);
      break;
    case 128:
      fmha_launch_kernel<T, 128, 128>(args, stream);
      break;
    case 192:
      fmha_launch_kernel<T, 192, 256>(args, stream);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Dim_head = %d is unsupport!", dim_head));
  }
}
// NOTE: cache copies move data in 16-byte (128-bit) vectors, i.e. 4 floats or
// 8 float16 values per vector.
constexpr int VEC_16B = 16;
// Copies freshly computed keys into the key cache while transposing them into
// the cache layout. Data is moved in 16-byte vectors (x elements each):
//   source k : [bsz, num_head, seq_len, dim_head/x, x]
//   cache_k  : [bsz, num_head, dim_head/x, max_seq_len, x]
// blockIdx.y selects the batch entry and blockIdx.z the head; the x dimension
// of the grid enumerates destination vectors of one head. Destination slots
// whose sequence position is >= seq_len are left untouched.
template <typename T>
__global__ void write_cache_k_kernel(T *cache_k,
                                     const T *k,
                                     const int num_head,
                                     const int dim_head,
                                     const int seq_len,
                                     const int max_seq_len) {
  constexpr int X_ELEMS = VEC_16B / sizeof(T);
  const int batch_id = blockIdx.y;
  const int head_id = blockIdx.z;

  // Number of 16-byte vectors per head row.
  const int vecs_per_row = dim_head / X_ELEMS;

  // Flat index of the destination vector inside this head's cache slice.
  const int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // FIXME(wangxi): num_head is not need?
  // if (dst_idx >= num_head * vecs_per_row * max_seq_len) return;
  if (dst_idx >= vecs_per_row * max_seq_len) return;

  // Decompose the destination index: sequence position is the fastest
  // varying coordinate, the vector id within the row the next one.
  const int seq_pos = dst_idx % max_seq_len;
  const int vec_id = (dst_idx / max_seq_len) % vecs_per_row;
  if (seq_pos >= seq_len) return;

  // [bsz, num_head, seq_len, dim_head/x, x]
  const uint4 *src = reinterpret_cast<const uint4 *>(
      k + batch_id * num_head * seq_len * dim_head +
      head_id * seq_len * dim_head);
  // [bsz, num_head, dim_head/x, max_seq_len, x]
  uint4 *dst = reinterpret_cast<uint4 *>(
      cache_k + batch_id * num_head * max_seq_len * dim_head +
      head_id * max_seq_len * dim_head);

  dst[dst_idx] = src[seq_pos * vecs_per_row + vec_id];
}
// Copies freshly computed values into the value cache. Source and cache use
// the same row layout, only the sequence capacity differs, so the first
// seq_len rows of each head are copied straight across in 16-byte vectors:
//   source v : [bsz, num_head, seq_len, dim_head/x, x]
//   cache_v  : [bsz, num_head, max_seq_len, dim_head/x, x]
// blockIdx.y selects the batch entry and blockIdx.z the head.
template <typename T>
__global__ void write_cache_v_kernel(T *cache_v,
                                     const T *v,
                                     const int num_head,
                                     const int dim_head,
                                     const int seq_len,
                                     const int max_seq_len) {
  constexpr int X_ELEMS = VEC_16B / sizeof(T);
  const int batch_id = blockIdx.y;
  const int head_id = blockIdx.z;

  // One thread per 16-byte vector of the source slice.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int vecs_per_row = dim_head / X_ELEMS;
  if (vec_idx >= vecs_per_row * seq_len) return;

  // [bsz, num_head, seq_len, dim_head/x, x]
  const uint4 *src = reinterpret_cast<const uint4 *>(
      v + batch_id * num_head * seq_len * dim_head +
      head_id * seq_len * dim_head);
  // [bsz, num_head, max_seq_len, dim_head/x, x]
  uint4 *dst = reinterpret_cast<uint4 *>(
      cache_v + batch_id * num_head * max_seq_len * dim_head +
      head_id * max_seq_len * dim_head);

  dst[vec_idx] = src[vec_idx];
}
// Writes the keys and values of `seq_len` tokens into the key/value caches.
//
//   cache_k / cache_v - destination caches with room for max_seq_len tokens.
//   k / v             - source tensors, [bsz, num_head, seq_len, dim_head].
//
// Data is moved in 16-byte vectors, so dim_head must be a multiple of
// 16 / sizeof(T); otherwise PreconditionNotMet is raised. The check is done
// only through PADDLE_ENFORCE_EQ so that debug builds report the descriptive
// error instead of aborting on a bare assert.
// Both kernels are enqueued on dev_ctx.stream(); the call does not wait for
// them to finish.
template <typename T>
void write_cache_kv(const phi::GPUContext &dev_ctx,
                    T *cache_k,
                    T *cache_v,
                    const T *k,
                    const T *v,
                    const int bsz,
                    const int num_head,
                    const int seq_len,
                    const int max_seq_len,
                    const int dim_head) {
  constexpr int block_sz = 128;
  // Elements of T per 16-byte vector.
  constexpr int x = VEC_16B / sizeof(T);

  PADDLE_ENFORCE_EQ(
      dim_head % x,
      0,
      platform::errors::PreconditionNotMet(
          "dim_head=%d must be divisible by vec_size=%d", dim_head, x));

  // The key kernel iterates over destination vectors (cache capacity), the
  // value kernel over source vectors (tokens actually present).
  int max_size = max_seq_len * dim_head / x;
  int size = seq_len * dim_head / x;
  dim3 grid(div_up(max_size, block_sz), bsz, num_head);
  dim3 grid_v(div_up(size, block_sz), bsz, num_head);

  // transpose [bsz, num_head, seq_len, dim_head/x, x]->
  // [bsz, num_head, dim_head/x, max_seq_len, x]
  write_cache_k_kernel<<<grid, block_sz, 0, dev_ctx.stream()>>>(
      cache_k, k, num_head, dim_head, seq_len, max_seq_len);

  // copy [bsz, num_head, seq_len, dim_head/x, x]->
  // [bsz, num_head, max_seq_len, dim_head/x, x]
  write_cache_v_kernel<<<grid_v, block_sz, 0, dev_ctx.stream()>>>(
      cache_v, v, num_head, dim_head, seq_len, max_seq_len);
}
} // namespace
} // namespace operators
} // namespace paddle
......@@ -28,7 +28,9 @@ template <typename T,
int VecSize,
bool ComputeLayerNorm,
bool Activation,
typename Functor>
typename Functor,
typename InType = T,
typename OutType = T>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id,
const int col_id,
......@@ -36,30 +38,45 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
curandStatePhilox4_32_10_t *state,
const float dropout_prob,
const T factor,
const T *__restrict__ src,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
T *dst,
OutType *dst,
MaskType *mask,
const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val,
Functor act_func) {
Functor act_func,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
using U = typename details::MPTypeTrait<T>::Type;
LoadT src_vec;
LoadInType src_vec;
LoadT residual_vec;
LoadT bias_vec;
LoadFloat quant_out_scale_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
residual_vec[ii] = static_cast<T>(0);
}
// vectorize load data from global
phi::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<float, VecSize>(
&dequant_out_scale_data[quant_out_scale_offset + col_id],
&quant_out_scale_vec);
if (residual) {
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
}
......@@ -84,10 +101,18 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
StoreT dest_vec;
StoreOutType dest_vec_out_type;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp = src_vec[ii] + bias_vec[ii];
T tmp;
if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_last_in_scale / quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii];
} else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
}
if (Activation) {
tmp = act_func(tmp);
}
......@@ -98,10 +123,23 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
*mean_val += tmp;
*var_val += (tmp * tmp);
}
if (std::is_same<OutType, int8_t>::value) {
dest_vec_out_type[ii] = quant_helper(dest_vec[ii],
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
// store result to global
phi::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
if (std::is_same<OutType, int8_t>::value) {
phi::Store<OutType, VecSize>(dest_vec_out_type,
&dst[row_id * cols + col_id]);
} else {
phi::Store<T, VecSize>(dest_vec,
reinterpret_cast<T *>(&dst[row_id * cols + col_id]));
}
if (!is_test) {
phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
}
......@@ -114,19 +152,28 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
* is_test: only used in inference
* mask: can be null if is_test=true
*/
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutBias(const size_t rows,
template <typename T,
typename MaskType,
int VecSize,
typename InType = T,
typename OutType = T>
__global__ void FusedResidualDropoutBias(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const T *__restrict__ src,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
MaskType *mask,
T *dst,
OutType *dst,
uint64_t increment,
const bool is_test) {
const bool is_test,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
......@@ -142,8 +189,9 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
VecSize,
false,
false,
phi::funcs::ReluFunctor<T>>(
r,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(r,
i,
cols,
&state,
......@@ -157,7 +205,11 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
is_test,
nullptr,
nullptr,
relu);
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
}
......@@ -165,7 +217,10 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
/**
* @brief dst = residual + dropout(src + bias);
*/
template <typename T, typename MaskType>
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
void LaunchResidualDropoutBias(const uint32_t rows,
const uint32_t cols,
const int increment,
......@@ -173,14 +228,19 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const float dropout_prob,
const bool is_test,
bool is_upscale_in_train,
const T *src,
const InType *src,
const T *residual,
const T *bias,
MaskType *mask_data,
T *dst,
const phi::GPUContext &ctx) {
OutType *dst,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
if (residual == dst) return;
if (residual) {
memory::Copy(ctx.GetPlace(),
......@@ -202,7 +262,7 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedResidualDropoutBias<T, uint8_t, VecSize>
FusedResidualDropoutBias<T, uint8_t, VecSize, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
......@@ -215,9 +275,13 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data,
dst,
increment,
is_test);
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
} else {
FusedResidualDropoutBias<T, uint8_t, 1>
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
......@@ -230,7 +294,11 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data,
dst,
increment,
is_test);
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
/**
 * @brief Quantize one value to int8.
 *
 * Computes max_bound * inverse(scale) * input in float, rounds it, clamps the
 * result to [min_bound, max_bound] and converts it to int8_t.
 *
 * @param input       value to quantize (converted to float first)
 * @param scale       quantization scale; its inverse is applied
 * @param round_type  0 selects roundWithTiesToEven, any other value round()
 * @param max_bound   upper clamp, also the multiplier of the scaled value
 * @param min_bound   lower clamp
 */
template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input,
                                               const float scale,
                                               const int round_type,
                                               const float max_bound,
                                               const float min_bound) {
  const float scaled =
      max_bound * inverse(scale) * static_cast<float>(input);

  float rounded = (round_type == 0)
                      ? static_cast<float>(roundWithTiesToEven(scaled))
                      : static_cast<float>(round(scaled));

  // Saturate: upper bound first, then lower bound.
  if (rounded > max_bound) {
    rounded = max_bound;
  }
  if (rounded < min_bound) {
    rounded = min_bound;
  }
  return static_cast<int8_t>(rounded);
}
/**
 * @brief Quantize a row-major m x n matrix to int8, four columns per thread.
 *
 * Each thread handles columns [col, col + 3] of one row and writes them as a
 * single char4. Only the first column is bounds-checked and the output is
 * indexed as (row * n + col) / 4, so n has to be a multiple of 4 for the
 * reads and the packed store to stay in range.
 */
template <typename T>
__global__ void quantize_kernel(const T* input,
                                char4* output,
                                const float scale,
                                const int m,
                                const int n,
                                const int round_type,
                                const float max_bound,
                                const float min_bound) {
  const int col = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (row >= m || col >= n) {
    return;
  }

  // Linear offset of the first of the four elements owned by this thread.
  const int base = row * n + col;

  char4 packed;
  packed.x =
      quant_helper(input[base], scale, round_type, max_bound, min_bound);
  packed.y =
      quant_helper(input[base + 1], scale, round_type, max_bound, min_bound);
  packed.z =
      quant_helper(input[base + 2], scale, round_type, max_bound, min_bound);
  packed.w =
      quant_helper(input[base + 3], scale, round_type, max_bound, min_bound);

  output[base >> 2] = packed;
}
template <typename T>
void quantize_kernel_launcher(const T* input,
int8_t* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound,
gpuStream_t stream) {
// TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1
dim3 grid((n + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT
scale,
m,
n,
round_type,
max_bound,
min_bound);
}
/**
 * @brief Dequantize int32 values using the input scale and per-column
 *        output scales.
 *
 * The data is addressed as n rows (batch) of m columns (hidden). For each
 * element:
 *   output = input * quant_in_scale /
 *            dequant_out_scale_data[quant_out_scale_offset + column]
 * One thread handles one element; threads outside the m x n range do nothing.
 */
template <typename T>
__global__ void dequantize_kernel(T* output,
                                  const int32_t* input,
                                  const int m,  // hidden
                                  const int n,  // batch size
                                  const float quant_in_scale,
                                  const float* dequant_out_scale_data,
                                  const int quant_out_scale_offset) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;  // hidden
  const int row = blockIdx.y * blockDim.y + threadIdx.y;  // batch size
  if (col >= m || row >= n) {
    return;
  }

  const int idx = row * m + col;
  const float out_scale =
      dequant_out_scale_data[quant_out_scale_offset + col];
  output[idx] = static_cast<T>(static_cast<float>(input[idx]) *
                               quant_in_scale / out_scale);
}
template <typename T>
void dequantize_kernel_launcher(const int32_t* input,
T* output,
const int batch_size, // m
const int hidden_units, // n
gpuStream_t stream,
const float quant_in_scale,
const float* dequant_out_scale_data,
const int quant_out_scale_offset) {
dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32);
dim3 block(32, 32);
dequantize_kernel<<<grid, block, 0, stream>>>(output,
input,
hidden_units,
batch_size,
quant_in_scale,
dequant_out_scale_data,
quant_out_scale_offset);
}
} // namespace operators
} // namespace paddle
......@@ -24,6 +24,7 @@ namespace cub = hipcub;
#include <iostream>
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
......@@ -338,16 +339,24 @@ using LayerNormScaleBiasT =
template <typename T,
typename U,
int BlockDim,
bool ScaleBiasWithSameTypeX = false>
bool ScaleBiasWithSameTypeX = false,
typename InType = T,
typename OutType = T>
__global__ void LayerNormForward(
const T *x,
const InType *x,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias,
T *y,
OutType *y,
U *mean,
U *var,
float epsilon,
int64_t feature_size) {
int64_t feature_size,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32]; // threadIdx.x / warpSize <= kMaxBlockDim /
......@@ -387,28 +396,72 @@ __global__ void LayerNormForward(
if (bias != nullptr) {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
if (std::is_same<OutType, int8_t>::value) {
y[i] = quant_helper(
static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar +
static_cast<U>(bias[j])),
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
y[i] = static_cast<OutType>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) *
invvar +
static_cast<U>(bias[j]));
}
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(static_cast<U>(scale[j]) *
if (std::is_same<OutType, int8_t>::value) {
y[i] = quant_helper(
static_cast<T>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar),
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
y[i] =
static_cast<OutType>(static_cast<U>(scale[j]) *
(static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
} else { // scale == nullptr
if (bias != nullptr) {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
if (std::is_same<OutType, int8_t>::value) {
y[i] = quant_helper(
static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
static_cast<U>(bias[j])),
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
y[i] =
static_cast<OutType>((static_cast<U>(x[i]) - mean_val) * invvar +
static_cast<U>(bias[j]));
}
}
} else {
for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
if (std::is_same<OutType, int8_t>::value) {
y[i] = quant_helper(
static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar),
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
y[i] =
static_cast<OutType>((static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
}
......
......@@ -59,7 +59,9 @@ namespace dynload {
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
__macro(cublasLtMatrixTransformDescSetAttribute); \
__macro(cublasLtMatmulAlgoInit); \
__macro(cublasLtMatmulAlgoConfigSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
// #endif
......
......@@ -71,6 +71,12 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"FFN1Bias",
"FFN2Weight",
"FFN2Bias"}},
{"fused_multi_transformer_int8",
{"X", "LnScale", "LnBias", "QKVW",
"QKVBias", "CacheKV", "TimeStep", "SrcMask",
"OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
"FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias",
"QKVOutScale", "OutLinearOutScale", "FFN1OutScale", "FFN2OutScale"}},
{"fused_bias_dropout_residual_layer_norm",
{"X", "Residual", "Bias", "LnScale", "LnBias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
......@@ -329,6 +335,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut",
"MasterParamOut"}},
{"fused_multi_transformer", {"CacheKVOut", "Out"}},
{"fused_multi_transformer_int8", {"CacheKVOut", "Out"}},
{"resnet_basic_block",
{"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out",
"Var1Out", "Conv2", "SavedMean2", "SavedInvstd2", "Mean2Out",
......@@ -433,6 +440,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"split", {"Out"}},
{"concat", {"Out"}},
{"fused_multi_transformer", {"CacheKVOut"}},
{"fused_multi_transformer_int8", {"CacheKVOut"}},
{"group_norm", {"Mean", "Variance"}},
{"resnet_basic_block",
{"Mean1Out", "Var1Out", "Mean2Out", "Var2Out", "Mean3Out", "Var3Out"}},
......
......@@ -73,7 +73,9 @@ extern void *cublasLt_dso_handle;
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
__macro(cublasLtMatrixTransformDescSetAttribute); \
__macro(cublasLtMatmulAlgoInit); \
__macro(cublasLtMatmulAlgoConfigSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
// #endif
......
......@@ -326,7 +326,7 @@ void* GetCublasDsoHandle() {
void* GetCublasLtDsoHandle() {
// APIs available after CUDA 10.1
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10100
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10010
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so");
#else
std::string warning_msg(
......
......@@ -72,6 +72,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_attention_op)
list(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
......@@ -141,6 +142,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_complex_matmul)
list(REMOVE_ITEM TEST_OPS test_ops_nms)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......@@ -1202,6 +1204,10 @@ endif()
if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120)
endif()
if(WITH_GPU AND NOT WIN32)
set_tests_properties(test_fused_multi_transformer_int8_op PROPERTIES TIMEOUT
60)
endif()
set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120)
set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400)
set_tests_properties(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import Constant
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program
from paddle import _legacy_C_ops
# Fix the random seed of the default main program and of NumPy's global RNG
# so that the randomly generated test data is the same on every run.
default_main_program().random_seed = 42
np.random.seed(0)
def fused_multi_transformer_int8(
    x,
    ln_scales,
    ln_biases,
    qkv_weights,
    qkv_biases,
    linear_weights,
    linear_biases,
    ffn_ln_scales,
    ffn_ln_biases,
    ffn1_weights,
    ffn1_biases,
    ffn2_weights,
    ffn2_biases,
    pre_layer_norm=True,
    epsilon=1e-05,
    cache_kvs=None,
    time_step=None,
    attn_mask=None,
    dropout_rate=0.0,
    activation="gelu",
    training=False,
    mode='upscale_in_train',
    trans_qkvw=True,
    ring_id=-1,
    name=None,
    qkv_out_scales=None,
    out_linear_out_scales=None,
    ffn1_out_scales=None,
    ffn2_out_scales=None,
    num_head=0,
    dim_head=0,
    dim_ffn=0,
    qkv_in_scale=None,
    out_linear_in_scale=None,
    ffn1_in_scale=None,
    ffn2_in_scale=None,
):
    """Call the ``fused_multi_transformer_int8`` operator in dygraph mode.

    Thin wrapper around ``_legacy_C_ops.fused_multi_transformer_int8``: the
    tensor arguments are forwarded positionally and the remaining arguments
    are forwarded as ``'attr_name', value`` pairs.

    Notes:
        * ``mode='downscale_in_infer'`` is translated to
          ``'downgrade_in_infer'`` before being passed on as
          ``dropout_implementation``.
        * ``is_test`` is passed as ``not training``.
        * ``cache_kvs`` is passed twice: once among the inputs and once after
          the out-scale tensors (presumably as the in-place ``CacheKVOut``).
        * ``name`` is accepted but not used by this wrapper.
        * The four ``*_in_scale`` arguments default to ``None``, which is
          replaced by a fresh empty list on every call, so no mutable default
          object is shared between calls.

    Returns:
        ``final_out`` when ``cache_kvs`` is ``None``, otherwise the tuple
        ``(final_out, cache_kv_out)``.
    """
    # semantic transfer
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode

    # Materialize per-call defaults instead of sharing one list object.
    qkv_in_scale = [] if qkv_in_scale is None else qkv_in_scale
    out_linear_in_scale = ([] if out_linear_in_scale is None else
                           out_linear_in_scale)
    ffn1_in_scale = [] if ffn1_in_scale is None else ffn1_in_scale
    ffn2_in_scale = [] if ffn2_in_scale is None else ffn2_in_scale

    cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_int8(
        x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs, time_step,
        attn_mask, linear_weights, linear_biases, ffn_ln_scales, ffn_ln_biases,
        ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, qkv_out_scales,
        out_linear_out_scales, ffn1_out_scales, ffn2_out_scales, cache_kvs,
        'num_head', num_head, 'dim_head', dim_head, 'dim_ffn', dim_ffn,
        'qkv_in_scale', qkv_in_scale, 'out_linear_in_scale',
        out_linear_in_scale, 'ffn1_in_scale', ffn1_in_scale, 'ffn2_in_scale',
        ffn2_in_scale, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
        'dropout_rate', dropout_rate, 'is_test', not training,
        'dropout_implementation', mode, 'act_method', activation, 'trans_qkvw',
        trans_qkvw, 'ring_id', ring_id)

    if cache_kvs is not None:
        return final_out, cache_kv_out
    return final_out
class TestFusedMultiTransformerInt8Op(unittest.TestCase):
def setUp(self):
    """Prepare configuration, input data, tolerances and reference layers."""
    self.config()
    self.generate_input_data()

    self.rtol = 1e-5
    # FIXME(wangxi): Because there is a problem with the test precision
    #  on A100, atol is temporarily set to 1e-2, and it will be
    #  changed back after the precision problem is solved.
    tolerance = 1e-2
    # Tighter tolerance on V100 to make sure of local development precision.
    if "V100" in paddle.device.cuda.get_device_name():
        tolerance = 1e-4
    # float16 inputs override both values above with a looser tolerance.
    if self.x_type is np.float16:
        tolerance = 1e-1
    self.atol = tolerance

    paddle.set_default_dtype(self.x_type)
    test_cls = self.__class__
    test_cls.op_type = "fused_multi_transformer_int8"
    # use autograd to check grad in this unittest.
    test_cls.no_need_check_grad = True

    def make_layer_norm():
        # LayerNorm without learnable scale or bias.
        return LayerNorm(self.embed_dim, weight_attr=False, bias_attr=False)

    # The two LayerNorm layers are created while float32 is the default
    # dtype; the default is switched back to x_type afterwards.
    paddle.set_default_dtype(np.float32)
    self.norm = make_layer_norm()
    self.ffn_norm = make_layer_norm()
    paddle.set_default_dtype(self.x_type)

    self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
    self.activation = getattr(F, self.act_method)
def config(self):
    """Populate the default switches and problem sizes of this test case."""
    # for debug
    self.debug = False

    # Element types of the activations and of the attention mask.
    self.x_type = np.float32
    self.attn_mask_type = np.float64
    # self.attn_mask_type = np.bool

    self.pre_layer_norm = True
    self.has_attn_mask = True
    self.training = False

    # has_cache_kv, gen_cache_kv, stage
    # False,        False,        not generation
    # True,         True,         generation context stage
    # True,         False,        generation decoder stage
    self.has_cache_kv = False
    self.gen_cache_kv = False

    # Problem size.
    self.layers = 3
    self.batch_size = 1
    self.query_length = 1
    self.cache_length = 1
    self.num_heads = 16
    self.head_dim = 64
    self.embed_dim = self.head_dim * self.num_heads
    self.kdim = self.vdim = self.embed_dim
    self.key_length = self.value_length = self.query_length

    # Dropout and activation settings.
    self.dropout_prob = 0.0
    self.attn_dropout_prob = 0.0
    self.act_method = 'gelu'

    self.weight_attr = None
    self.bias_attr = None
def generate_input_data(self):
    """Create the random test data.

    Generates the query, integer-valued projection weights (stored as
    float64), biases, the optional KV cache and the optional attention mask,
    and resets the per-layer quantization scale lists to empty lists.
    NumPy's global RNG is consumed in a fixed order.
    """

    def rand_int_weight(rows, cols):
        # Integers drawn from [-64, 64), converted to float64.
        return np.random.randint(-64, 64, [rows, cols],
                                 np.int32).astype('float64')

    def rand_bias(size):
        return np.random.rand(size).astype(self.x_type)

    dim = self.embed_dim

    self.query = np.random.rand(self.batch_size, self.query_length,
                                dim).astype(self.x_type)

    # Attention projection weights.
    q_weight = rand_int_weight(dim, dim)
    k_weight = rand_int_weight(self.kdim, dim)
    v_weight = rand_int_weight(self.vdim, dim)
    self.q_weight_tensor = paddle.to_tensor(q_weight)
    self.k_weight_tensor = paddle.to_tensor(k_weight)
    self.v_weight_tensor = paddle.to_tensor(v_weight)

    # Output projection and feed-forward weights.
    out_weight = rand_int_weight(dim, dim)
    ffn1_weight = rand_int_weight(dim, 4 * dim)
    ffn2_weight = rand_int_weight(4 * dim, dim)
    self.out_weight_tensor = paddle.to_tensor(out_weight)
    self.ffn1_weight_tensor = paddle.to_tensor(ffn1_weight)
    self.ffn2_weight_tensor = paddle.to_tensor(ffn2_weight)

    # Biases.
    q_proj_bias = rand_bias(dim)
    k_proj_bias = rand_bias(dim)
    v_proj_bias = rand_bias(dim)
    self.q_proj_bias_tensor = paddle.to_tensor(q_proj_bias)
    self.k_proj_bias_tensor = paddle.to_tensor(k_proj_bias)
    self.v_proj_bias_tensor = paddle.to_tensor(v_proj_bias)

    out_linear_proj_bias = rand_bias(dim)
    ffn1_proj_bias = rand_bias(4 * dim)
    ffn2_proj_bias = rand_bias(dim)
    self.out_linear_proj_bias_tensor = paddle.to_tensor(out_linear_proj_bias)
    self.ffn1_proj_bias_tensor = paddle.to_tensor(ffn1_proj_bias)
    self.ffn2_proj_bias_tensor = paddle.to_tensor(ffn2_proj_bias)

    out_seq_len = self.key_length

    # Start every per-layer scale list empty.
    for prefix in ('qkv', 'out_linear', 'ffn1', 'ffn2'):
        setattr(self, prefix + '_in_scales', [])
        setattr(self, prefix + '_out_scales', [])

    if not self.has_cache_kv:
        self.cache_kv = None
    else:
        self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
                                       self.cache_length,
                                       self.head_dim).astype(self.x_type)
        if self.gen_cache_kv:
            # Context stage: the cache is zeroed.
            self.cache_kv[:] = 0
        else:
            # Decoder stage: attention also spans the cached positions.
            out_seq_len += self.cache_length

    if not self.has_attn_mask:
        self.attn_mask = None
        return

    decoder_stage = self.has_cache_kv and not self.gen_cache_kv
    mask_type = self.attn_mask_type

    # [B, n_head, seq_len, out_seq_len]
    self.attn_mask = np.ones(
        (self.batch_size, 1, self.query_length, out_seq_len),
        dtype=mask_type)

    if mask_type == np.int64:
        self.attn_mask = np.tril(self.attn_mask)
    elif mask_type == np.float64:
        if decoder_stage:
            # Decoder stage: position -2 becomes -1e4 after the shift below,
            # while the last position (-1) stays at 0.
            self.attn_mask[:, :, :, -2] = 0.0
            self.attn_mask = (self.attn_mask - 1.0) * 1e4
        else:
            self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e4
    elif mask_type == np.bool_:
        if decoder_stage:
            self.attn_mask[:, :, :, -2] = 0
        else:
            self.attn_mask = np.tril(self.attn_mask)
    else:
        raise ValueError(
            "'attn_mask_type' should be 'int64' or 'float64'.")
def fake_quant(self, input, scale):
    """Emulate int8 quantization of ``input``.

    The input is cast to float32, multiplied by ``127.0 * (1.0 / scale)`` and
    rounded; the rounded values are returned as float64.
    """
    as_float32 = paddle.cast(input, 'float32')
    rounded = paddle.round(127.0 * (1.0 / scale) * as_float32)
    # No need to clip here because scale is the max value
    return paddle.cast(rounded, 'float64')
def GetBaselineOut(self):
    """Compute the reference output with unfused Paddle ops.

    Every linear layer is wrapped in an emulated int8 round trip: the input
    is fake-quantized with its abs-max as scale, multiplied with the
    integer-valued weight, rescaled by ``in_scale / out_scale`` and cast back
    to ``x_type`` before the bias is added.

    Side effect: one entry per layer is appended to each ``*_in_scales`` and
    ``*_out_scales`` list.

    Returns:
        ``(final_out, cache_kvs)`` when both ``has_cache_kv`` and
        ``gen_cache_kv`` are set, otherwise ``final_out``.
    """
    paddle.disable_static(place=paddle.CUDAPlace(0))

    def abs_max(t):
        # Largest absolute value of `t`, computed in float32.
        return paddle.max(paddle.abs(paddle.cast(t, 'float32')))[0]

    def calibrate(in_scales, out_scales, t, layer):
        # Record this layer's scales and return the entries at index `layer`.
        in_scales.append(abs_max(t))
        out_scales.append(127.0 * 127.0)
        return in_scales[layer], out_scales[layer]

    def quant_linear(quantized, weight, bias, in_scale, out_scale):
        # Linear on the quantized input, de-quantize, then add the bias.
        acc = paddle.nn.functional.linear(quantized, weight)
        dequantized = paddle.cast(
            paddle.cast(acc, 'float32') * in_scale / out_scale, self.x_type)
        return dequantized + bias

    def split_heads(t):
        # [B, seq_len, embed_dim] -> [B, n_head, seq_len, head_dim]
        t = tensor.reshape(x=t, shape=[0, 0, self.num_heads, self.head_dim])
        return tensor.transpose(x=t, perm=[0, 2, 1, 3])

    tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

    cache_kvs = []
    cache_kv = (paddle.to_tensor(self.cache_kv, stop_gradient=False)
                if self.has_cache_kv else None)
    attn_mask = (paddle.to_tensor(self.attn_mask, stop_gradient=False)
                 if self.has_attn_mask else None)

    for i in range(self.layers):
        residual = tensor_query
        ln1_out = (self.norm(tensor_query)
                   if self.pre_layer_norm else tensor_query)

        # ---- QKV projections -------------------------------------------
        qkv_in, qkv_out = calibrate(self.qkv_in_scales, self.qkv_out_scales,
                                    ln1_out, i)
        # quant ln1_out
        ln1_out = self.fake_quant(ln1_out, qkv_in)

        q_out = split_heads(
            quant_linear(ln1_out, self.q_weight_tensor,
                         self.q_proj_bias_tensor, qkv_in, qkv_out))
        k_out = split_heads(
            quant_linear(ln1_out, self.k_weight_tensor,
                         self.k_proj_bias_tensor, qkv_in, qkv_out))
        v_out = split_heads(
            quant_linear(ln1_out, self.v_weight_tensor,
                         self.v_proj_bias_tensor, qkv_in, qkv_out))

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = (paddle.squeeze(part, axis=0)
                                for part in paddle.split(cache_kv, 2))

            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            if self.debug:
                print('q out is')
                print(q_out[0, 0, :, :])
                print('cache k out seq=128')
                print(k_out[0, 0, :, :])
            if self.gen_cache_kv:
                cache_kvs.append((k_out, v_out))
            else:
                k_out = paddle.concat([cache_k, k_out], axis=-2)
                v_out = paddle.concat([cache_v, v_out], axis=-2)

        # ---- Attention --------------------------------------------------
        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(x=q_out,
                               y=k_out,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)
        if self.debug:
            print('qk out is')
            print(qk_out[0][0][0])

        if attn_mask is None:
            softmax_out = F.softmax(qk_out)
        else:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            if self.debug:
                print('attn mask out is')
                print(attn_mask_out[0][0][0])
            softmax_out = F.softmax(attn_mask_out)

        if self.debug:
            print('softmax out is')
            print(softmax_out[0][0][0])

        attn_probs = softmax_out
        if self.dropout_prob:
            attn_probs = F.dropout(softmax_out,
                                   self.dropout_prob,
                                   training=self.training,
                                   mode="upscale_in_train")
        # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, head_dim]
        qktv_out = tensor.matmul(attn_probs, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        if self.debug:
            print('fmha out is')
            print(fmha_out[0][0][0])

        # ---- Output projection + residual -------------------------------
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        lin_in, lin_out = calibrate(self.out_linear_in_scales,
                                    self.out_linear_out_scales,
                                    out_linear_in, i)
        out_linear_in = self.fake_quant(out_linear_in, lin_in)
        out = quant_linear(out_linear_in, self.out_weight_tensor,
                           self.out_linear_proj_bias_tensor, lin_in, lin_out)

        residual_out = residual + self.dropout(out)
        attn_out = (residual_out
                    if self.pre_layer_norm else self.norm(residual_out))

        # ---- Feed-forward network ---------------------------------------
        ffn_ln_out = (self.ffn_norm(attn_out)
                      if self.pre_layer_norm else attn_out)
        ffn1_in, ffn1_out_scale = calibrate(self.ffn1_in_scales,
                                            self.ffn1_out_scales, ffn_ln_out,
                                            i)
        ffn_ln_out = self.fake_quant(ffn_ln_out, ffn1_in)
        ffn1_out = quant_linear(ffn_ln_out, self.ffn1_weight_tensor,
                                self.ffn1_proj_bias_tensor, ffn1_in,
                                ffn1_out_scale)
        ffn1_out = self.dropout(self.activation(ffn1_out))

        ffn2_in, ffn2_out_scale = calibrate(self.ffn2_in_scales,
                                            self.ffn2_out_scales, ffn1_out,
                                            i)
        ffn1_out = self.fake_quant(ffn1_out, ffn2_in)
        ffn2_out = quant_linear(ffn1_out, self.ffn2_weight_tensor,
                                self.ffn2_proj_bias_tensor, ffn2_in,
                                ffn2_out_scale)

        residual_out = attn_out + self.dropout(ffn2_out)
        final_out = (residual_out
                     if self.pre_layer_norm else self.ffn_norm(residual_out))

        # The output of this layer feeds the next one.
        tensor_query = final_out

    if self.has_cache_kv and self.gen_cache_kv:
        return final_out, cache_kvs
    return final_out
def GetFusedMultiTransformerOut(self):
    """Run ``fused_multi_transformer_int8`` on this test's inputs.

    The q/k/v weights and biases are packed into single tensors, every
    weight is transposed and cast to ``int8``, and the same tensors are
    passed for each of the ``self.layers`` layers.  The per-layer input
    scales are read from ``self.*_in_scales`` and the output scales from
    ``self.*_out_scales``.

    Returns:
        The op result.  When ``self.has_cache_kv`` is set, the first two
        elements of the result are returned as a tuple instead.
    """
    paddle.disable_static(place=paddle.CUDAPlace(0))

    def _transposed_int8(weight_tensor):
        # Swap the two weight axes, then cast to the int8 dtype fed to the op.
        transposed = weight_tensor.numpy().transpose((1, 0))
        return paddle.cast(paddle.to_tensor(transposed), 'int8')

    ln_scale = paddle.ones([self.embed_dim], 'float32')
    ln_bias = paddle.zeros([self.embed_dim], 'float32')
    # The FFN layer norm reuses the attention layer norm parameters.
    ffn_ln_scale, ffn_ln_bias = ln_scale, ln_bias

    # Pack the transposed q/k/v weights into one
    # [3, num_heads, head_dim, embed_dim] tensor.
    qkv_weight = np.concatenate([
        weight.numpy().transpose((1, 0))
        for weight in (self.q_weight_tensor, self.k_weight_tensor,
                       self.v_weight_tensor)
    ])
    qkv_weight = qkv_weight.reshape(
        (3, self.num_heads, self.head_dim, self.embed_dim))
    qkv_weight_tensor = paddle.cast(paddle.to_tensor(qkv_weight), 'int8')

    out_weight_tensor = _transposed_int8(self.out_weight_tensor)
    ffn1_weight_tensor = _transposed_int8(self.ffn1_weight_tensor)
    ffn2_weight_tensor = _transposed_int8(self.ffn2_weight_tensor)

    # Pack the q/k/v biases into one [3, num_heads, head_dim] tensor.
    qkv_bias = np.concatenate(
        (self.q_proj_bias_tensor.numpy(), self.k_proj_bias_tensor.numpy(),
         self.v_proj_bias_tensor.numpy()))
    qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
    qkv_bias_tensor = paddle.to_tensor(qkv_bias)

    x = paddle.to_tensor(self.query, stop_gradient=True)

    cache_kvs, cache_kv = None, None
    time_step = None
    if self.has_cache_kv:
        # Cache length rounded up past the next multiple of 128.
        max_seq_length = (self.cache_length + 128) // 128 * 128
        cache_kv = np.zeros([
            2, self.batch_size, self.num_heads, max_seq_length,
            self.head_dim
        ],
                            dtype=self.x_type)

        # head_dim is split into groups of `elems` values (8 for fp16,
        # otherwise 4).
        elems = 8 if self.x_type is np.float16 else 4
        assert self.head_dim % elems == 0
        v_elems = self.head_dim // elems

        # [batch, num_heads, cache_length, v_elems, elems]
        #   -> [batch, num_heads, v_elems, cache_length, elems]
        cache_k_tmp = self.cache_kv[0].reshape([
            self.batch_size, self.num_heads, self.cache_length, v_elems,
            elems
        ]).transpose([0, 1, 3, 2, 4])

        # The assignment below relies on reshape returning a view of
        # cache_kv[0], so the packed keys land in cache_kv itself.
        packed_k = cache_kv[0, :].reshape([
            self.batch_size, self.num_heads, v_elems, max_seq_length, elems
        ])
        packed_k[:, :, :, :self.cache_length, :] = cache_k_tmp
        cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1]

        if self.gen_cache_kv:
            # The op generates the cache itself, so it starts from zeros.
            assert self.query_length == self.cache_length
            cache_kv[:] = 0
        else:
            time_step = paddle.to_tensor([self.cache_length],
                                         dtype='int32',
                                         place=paddle.CPUPlace())

        # One independent cache tensor per layer.
        cache_kvs = [
            paddle.to_tensor(cache_kv, stop_gradient=True)
            for _ in range(self.layers)
        ]

    attn_mask = None
    if self.has_attn_mask:
        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
        if self.attn_mask_type != np.bool_:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)

    epsilon = 1e-05
    num_layers = self.layers
    dim_ffn = 4 * self.embed_dim

    # Every layer is fed the same parameter tensors.
    qkv_weights = [qkv_weight_tensor] * num_layers
    qkv_biases = [qkv_bias_tensor] * num_layers
    out_weights = [out_weight_tensor] * num_layers
    out_biases = [self.out_linear_proj_bias_tensor] * num_layers
    ln_scales = [ln_scale] * num_layers
    ln_biases = [ln_bias] * num_layers
    ffn1_weights = [ffn1_weight_tensor] * num_layers
    ffn1_biases = [self.ffn1_proj_bias_tensor] * num_layers
    ffn2_weights = [ffn2_weight_tensor] * num_layers
    ffn2_biases = [self.ffn2_proj_bias_tensor] * num_layers
    ffn_ln_scales = [ffn_ln_scale] * num_layers
    ffn_ln_biases = [ffn_ln_bias] * num_layers

    # Per-layer input scales, one entry per layer.
    qkv_in_scale = [self.qkv_in_scales[i] for i in range(num_layers)]
    out_linear_in_scale = [
        self.out_linear_in_scales[i] for i in range(num_layers)
    ]
    ffn1_in_scale = [self.ffn1_in_scales[i] for i in range(num_layers)]
    ffn2_in_scale = [self.ffn2_in_scales[i] for i in range(num_layers)]

    # Output scales: one row per layer, filled with that layer's scale.
    qkv_out_scales_tensor = paddle.ones([num_layers, 3 * self.embed_dim],
                                        'float32')
    out_linear_out_scales_tensor = paddle.ones(
        [num_layers, self.embed_dim], 'float32')
    ffn1_out_scales_tensor = paddle.ones([num_layers, dim_ffn], 'float32')
    ffn2_out_scales_tensor = paddle.ones([num_layers, self.embed_dim],
                                         'float32')
    for i in range(num_layers):
        qkv_out_scales_tensor[i, :] *= self.qkv_out_scales[i]
        out_linear_out_scales_tensor[i, :] *= self.out_linear_out_scales[i]
        ffn1_out_scales_tensor[i, :] *= self.ffn1_out_scales[i]
        ffn2_out_scales_tensor[i, :] *= self.ffn2_out_scales[i]

    final_out = fused_multi_transformer_int8(
        x,
        ln_scales,
        ln_biases,
        qkv_weights,
        qkv_biases,
        out_weights,
        out_biases,
        ffn_ln_scales,
        ffn_ln_biases,
        ffn1_weights,
        ffn1_biases,
        ffn2_weights,
        ffn2_biases,
        pre_layer_norm=self.pre_layer_norm,
        epsilon=epsilon,
        cache_kvs=cache_kvs,
        time_step=time_step,
        attn_mask=attn_mask,
        dropout_rate=self.dropout_prob,
        training=self.training,
        mode='upscale_in_train',
        trans_qkvw=True,
        ring_id=-1,
        name=None,
        qkv_out_scales=qkv_out_scales_tensor,
        out_linear_out_scales=out_linear_out_scales_tensor,
        ffn1_out_scales=ffn1_out_scales_tensor,
        ffn2_out_scales=ffn2_out_scales_tensor,
        num_head=self.num_heads,
        dim_head=self.head_dim,
        dim_ffn=dim_ffn,
        qkv_in_scale=qkv_in_scale,
        out_linear_in_scale=out_linear_in_scale,
        ffn1_in_scale=ffn1_in_scale,
        ffn2_in_scale=ffn2_in_scale)

    if self.has_cache_kv:
        return final_out[0], final_out[1]
    return final_out
def test_fused_multi_transformer_op(self):
    """Check the fused INT8 op output against the baseline output.

    The final outputs are always compared.  When the cache is generated
    by the op (``has_cache_kv`` and ``gen_cache_kv``), the key/value cache
    of the first layer is additionally unpacked and compared with the
    baseline cache.
    """
    expected_out = self.GetBaselineOut()
    fused_out = self.GetFusedMultiTransformerOut()

    if self.has_cache_kv:
        fused_out, cache_kv_out = fused_out
        # Dimensions 1..4 of the cache tensor shape.
        bsz, num_head, max_seq_len, head_dim = cache_kv_out[0].shape[1:5]

        # head_dim is split into groups of `elems` values in the key cache.
        if self.x_type is np.float16:
            elems = 8
        else:
            elems = 4
        v_elems = head_dim // elems

        if self.debug:
            packed_cache = cache_kv_out[0].reshape(
                [2, bsz, num_head, v_elems, max_seq_len, elems])
            print("cache_k out timestep=128")
            print(packed_cache[0, 0, 0, :, self.cache_length, :])

            print("cache_v out timestep=128")
            print(cache_kv_out[0][1, 0, 0, self.cache_length, :])

        if self.gen_cache_kv:
            expected_out, expected_caches = expected_out
            # Only the first layer's cache is checked.
            for layer in range(min(self.layers, 1)):
                expected_k = expected_caches[layer][0]
                expected_v = expected_caches[layer][1]

                # Undo the packed key layout:
                # [bsz, num_head, v_elems, max_seq_len, elems]
                #   -> [bsz, num_head, cache_length, head_dim]
                packed_k = cache_kv_out[layer][0, :].reshape(
                    [bsz, num_head, v_elems, max_seq_len, elems])
                valid_k = packed_k[:, :, :, :self.cache_length, :]
                actual_k = valid_k.transpose([0, 1, 3, 2, 4]).reshape(
                    [bsz, num_head, self.cache_length, head_dim])

                actual_v = cache_kv_out[layer][1, :, :, :self.cache_length, :]

                np.testing.assert_allclose(expected_k,
                                           actual_k,
                                           rtol=self.rtol,
                                           atol=self.atol)
                np.testing.assert_allclose(expected_v,
                                           actual_v,
                                           rtol=self.rtol,
                                           atol=self.atol)

    np.testing.assert_allclose(expected_out,
                               fused_out,
                               rtol=self.rtol,
                               atol=self.atol)
class TestFusedMultiTransformerInt8OpFp16(TestFusedMultiTransformerInt8Op):
    """Base INT8 test case run with float16 inputs and three layers."""

    def config(self):
        """Apply the base configuration, then override dtype and depth."""
        super().config()
        self.layers = 3  # odd layers
        self.x_type = np.float16
class TestFusedMultiTransformerInt8OpCacheKV(TestFusedMultiTransformerInt8Op):
    """INT8 test case that feeds an existing key/value cache (three layers)."""

    def config(self):
        """Apply the base configuration, then enable the cache path."""
        super().config()
        # NOTE(review): input data is generated here, before the overrides
        # below are applied -- confirm this is intentional.
        super().generate_input_data()
        self.layers = 3  # odd layers
        self.has_cache_kv = True
        # Single-token query, key and value.
        self.query_length = 1
        self.key_length = 1
        self.value_length = 1
class TestFusedMultiTransformerInt8OpCacheKVFp16(
        TestFusedMultiTransformerInt8Op):
    """INT8 test case that feeds an existing key/value cache in float16."""

    def config(self):
        """Apply the base configuration, then enable the fp16 cache path."""
        super().config()
        self.x_type = np.float16
        self.has_cache_kv = True
        # Single-token query, key and value.
        self.query_length = 1
        self.key_length = 1
        self.value_length = 1
class TestFusedMultiTransformerInt8OpGenCacheKV(
        TestFusedMultiTransformerInt8Op):
    """INT8 test case in which the key/value cache is generated."""

    def config(self):
        """Apply the base configuration, then enable cache generation."""
        super().config()
        self.gen_cache_kv = True
        self.has_cache_kv = True
class TestFusedMultiTransformerInt8OpGenCacheKVFp16(
        TestFusedMultiTransformerInt8Op):
    """INT8 cache-generation test case with float16 inputs and three layers."""

    def config(self):
        """Apply the base configuration, then enable fp16 cache generation."""
        super().config()
        self.layers = 3  # odd layers
        self.x_type = np.float16
        self.gen_cache_kv = True
        self.has_cache_kv = True
class TestFusedMultiTransformerInt8OpPostLayerNormFp16(
        TestFusedMultiTransformerInt8Op):
    """INT8 test case with ``pre_layer_norm`` disabled, float16, three layers."""

    def config(self):
        """Apply the base configuration, then switch off pre layer norm."""
        super().config()
        self.pre_layer_norm = False
        self.layers = 3  # odd layers
        self.x_type = np.float16
class TestFusedMultiTransformerInt8OpCacheKVPostLayerNorm(
        TestFusedMultiTransformerInt8Op):
    """INT8 existing-cache test case with ``pre_layer_norm`` disabled."""

    def config(self):
        """Apply the base configuration, then set cache and norm options."""
        super().config()
        self.pre_layer_norm = False
        self.layers = 3  # odd layers
        self.has_cache_kv = True
        # Single-token query, key and value.
        self.query_length = 1
        self.key_length = 1
        self.value_length = 1
class TestFusedMultiTransformerInt8OpCacheKVPostLayerNormFp16(
        TestFusedMultiTransformerInt8Op):
    """INT8 existing-cache test case, float16, ``pre_layer_norm`` disabled."""

    def config(self):
        """Apply the base configuration, then set cache, dtype and norm."""
        super().config()
        self.pre_layer_norm = False
        self.x_type = np.float16
        self.has_cache_kv = True
        # Single-token query, key and value.
        self.query_length = 1
        self.key_length = 1
        self.value_length = 1
class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNorm(
        TestFusedMultiTransformerInt8Op):
    """INT8 cache-generation test case with ``pre_layer_norm`` disabled."""

    def config(self):
        """Apply the base configuration, then set cache and norm options."""
        super().config()
        self.pre_layer_norm = False
        self.gen_cache_kv = True
        self.has_cache_kv = True
class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNormFp16(
        TestFusedMultiTransformerInt8Op):
    """INT8 cache-generation case: float16, three layers, no pre layer norm."""

    def config(self):
        """Apply the base configuration, then set cache, dtype and norm."""
        super().config()
        self.pre_layer_norm = False
        self.layers = 3  # odd layers
        self.x_type = np.float16
        self.gen_cache_kv = True
        self.has_cache_kv = True
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册