未验证 提交 3d7e2118 编写于 作者: R RichardWooSJTU 提交者: GitHub

Add INT8 support for fused_multi_transformer_op (#45284)

上级 7f346a76
......@@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
if (var_data_type == paddle::framework::proto::VarType::FP16) {
if (var_data_type == paddle::framework::proto::VarType::FP16 &&
t->dtype() != paddle::experimental::DataType::FLOAT16) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
......
......@@ -23,6 +23,7 @@ register_operators(
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
......@@ -119,6 +120,7 @@ if(WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_multi_transformer_int8_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
......
......@@ -19,7 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
......@@ -33,17 +34,28 @@ class AttnLayerNorm {
~AttnLayerNorm() {}
void ComputeForward(const T* x_data,
void ComputeForward(const InType* x_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* bias_data,
T* y_data,
OutType* y_data,
LayerNormParamType<T>* mean_data,
LayerNormParamType<T>* var_data) {
LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();
switch (GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, LayerNormParamType<T>, kBlockDim>
LayerNormForward<T,
LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
......@@ -51,7 +63,13 @@ class AttnLayerNorm {
mean_data,
var_data,
epsilon_,
feature_size_));
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Feature_size must be larger than 1"));
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/operators/fused/cublaslt.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class AttnMatmulINT8 {
public:
AttnMatmulINT8(
const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias)
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
}
~AttnMatmulINT8() {}
// This function is used to execute GEMM, with input and output's types are
// both T.
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const float quant_in_scale,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());
helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());
dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}
// This function is used to execute GEMM, with input and output's types are
// both INT8.
void ComputeForwardINT8ToINT8(const framework::Tensor* weight,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}
// This function is used to execute GEMM, with input and output's types are
// INT8 and T.
void ComputeForwardINT8ToT(const framework::Tensor* weight,
const float quant_in_scale,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());
dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}
// This function is used to execute GEMM, with input and output's types are T
// and INT8.
void ComputeForwardTToINT8(const framework::Tensor* weight,
const float quant_in_scale,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());
helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}
private:
const phi::GPUContext& dev_ctx_;
int m_; // m
int n_; // n
int k_; // k
int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
namespace dyl = paddle::platform::dynload;
namespace paddle {
namespace operators {
class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
status = dyl::cublasLtMatmulDescCreate(
&matmul_desc_, cudaComputeType, CUDA_R_32I);
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmulDescCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
cublasOperation_t op_transpose = CUBLAS_OP_T;
status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op_transpose,
sizeof(op_transpose));
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmulDescSetAttribute execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
// matrix desc
status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
~CublasLtHelper() {
if (handle_) dyl::cublasLtDestroy(handle_);
if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_);
if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_);
if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_);
if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_);
}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream) {
cublasStatus_t status;
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
cublasLtMatmulAlgo_t algo;
int algoId = 21;
int swizzle = 0;
int customOption = 0;
int tile = 15;
int splitK_val = 0;
int reductionScheme = 0;
#if CUDA_VERSION >= 11000
int stages = 23;
#endif
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType,
CUDA_R_32I,
CUDA_R_8I,
CUDA_R_8I,
CUDA_R_32I,
CUDA_R_32I,
algoId,
&algo);
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption),
sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val),
sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme),
sizeof(int));
#if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif
#endif
status = dyl::cublasLtMatmul(handle_,
matmul_desc_,
&alpha_,
B_dev,
B_desc_,
A_dev,
A_desc_,
&beta_,
C_dev,
C_desc_,
C_dev,
C_desc_,
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
&algo,
#else
nullptr,
#endif
nullptr,
0,
stream);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
platform::errors::External(
"cublasLtMatmul execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
private:
cublasLtHandle_t handle_;
cublasLtMatmulDesc_t matmul_desc_;
cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_;
int32_t alpha_;
int32_t beta_;
int m_;
int k_;
int n_;
};
} // namespace operators
} // namespace paddle
......@@ -60,19 +60,32 @@ struct GeluGradFunctor {
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActBias(Functor act,
const uint64_t seed,
const uint64_t rows,
const uint64_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const T *__restrict__ src,
const T *__restrict__ bias,
T *dst,
MaskType *mask) {
template <typename T,
typename MaskType,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedDropoutActBias(
Functor act,
const uint64_t seed,
const uint64_t rows,
const uint64_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst,
MaskType *mask,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
......@@ -90,7 +103,9 @@ __global__ void FusedDropoutActBias(Functor act,
VecSize,
false,
true,
Functor>(r,
Functor,
InType,
OutType>(r,
i,
cols,
&state,
......@@ -104,7 +119,14 @@ __global__ void FusedDropoutActBias(Functor act,
is_test,
nullptr,
nullptr,
act);
act,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
}
......@@ -112,7 +134,11 @@ __global__ void FusedDropoutActBias(Functor act,
/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T, typename MaskType, typename Functor>
template <typename T,
typename MaskType,
typename Functor,
typename InType = T,
typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
const uint64_t seed,
const uint32_t rows,
......@@ -121,14 +147,21 @@ void LaunchDropoutActBias(Functor act_functor,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const T *src,
const InType *src,
const T *bias,
T *dst,
OutType *dst,
MaskType *mask_data,
const phi::GPUContext &ctx) {
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
SetZero<T>(ctx, dst, rows * cols);
SetZero<T>(ctx, reinterpret_cast<T *>(dst), rows * cols);
SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}
......@@ -137,7 +170,7 @@ void LaunchDropoutActBias(Functor act_functor,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedDropoutActBias<T, MaskType, VecSize, Functor>
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
......@@ -150,9 +183,13 @@ void LaunchDropoutActBias(Functor act_functor,
src,
bias,
dst,
mask_data);
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
} else {
FusedDropoutActBias<T, MaskType, 1, Functor>
FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
......@@ -165,7 +202,11 @@ void LaunchDropoutActBias(Functor act_functor,
src,
bias,
dst,
mask_data);
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
......
......@@ -109,7 +109,10 @@ struct DropoutParam {
}
};
template <typename T, typename MaskType>
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutHelper {
private:
int GetIncrement(const phi::GPUContext& ctx) {
......@@ -140,25 +143,34 @@ class FusedDropoutHelper {
// out = residual + dropout( src + bias )
void ResidualDropoutBias(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const T* residual,
const T* bias,
T* out,
MaskType* mask) {
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType>(rows_,
cols_,
increment,
dropout_param_.seed,
dropout_param_.dropout_prob,
dropout_param_.is_test,
dropout_param_.is_upscale_in_train,
src,
residual,
bias,
mask,
out,
ctx);
LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
rows_,
cols_,
increment,
dropout_param_.seed,
dropout_param_.dropout_prob,
dropout_param_.is_test,
dropout_param_.is_upscale_in_train,
src,
residual,
bias,
mask,
out,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
......@@ -189,15 +201,22 @@ class FusedDropoutHelper {
// out = dropout(activation(src + bias))
void DropoutActBias(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const T* bias,
const std::string& act_method,
T* out,
MaskType* mask) {
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
GeluFunctor<T> gelu;
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>>(
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>, InType, OutType>(
gelu,
dropout_param_.seed,
rows_,
......@@ -210,23 +229,40 @@ class FusedDropoutHelper {
bias,
out,
mask,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else if (act_method == "relu") {
phi::funcs::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, phi::funcs::ReluFunctor<T>>(
relu,
dropout_param_.seed,
rows_,
cols_,
increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx);
LaunchDropoutActBias<T,
MaskType,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(relu,
dropout_param_.seed,
rows_,
cols_,
increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
......@@ -283,8 +319,12 @@ class FusedDropoutHelper {
DropoutParam dropout_param_;
};
template <typename T, typename MaskType>
class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutLayerNormHelper
: public FusedDropoutHelper<T, MaskType, InType, OutType> {
public:
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows,
......@@ -301,23 +341,24 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
const int cols,
const DropoutParam& dropout_param,
const float epsilon)
: FusedDropoutHelper<T, MaskType>(ctx, rows, cols, dropout_param) {
: FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) {
using U = LayerNormParamType<T>;
epsilon_ = epsilon;
}
// call layer_norm
void LayerNorm(const phi::GPUContext& ctx,
const T* src,
const InType* src,
const LayerNormParamType<T>* gamma,
const LayerNormParamType<T>* beta,
T* out,
OutType* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance) {
using U = LayerNormParamType<T>;
switch (GetDesiredBlockDim(this->cols_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U, kBlockDim>
LayerNormForward<T, U, kBlockDim, false, InType, OutType>
<<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
}
......@@ -349,17 +390,25 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
// out = layernorm(residual + dropout(src + bias))
template <typename P = LayerNormParamType<T>, bool is_same_type = false>
void LayernormResidualDropoutBias(const phi::GPUContext& ctx,
const T* src,
const T* residual,
const T* bias,
const P* gamma,
const P* beta,
T* dropout_out,
MaskType* mask,
T* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance) {
void LayernormResidualDropoutBias(
const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
const P* gamma,
const P* beta,
T* dropout_out,
MaskType* mask,
OutType* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using U = LayerNormParamType<T>;
int vec_size = MAX_CACHE_BYTES / sizeof(T);
if (this->cols_ % vec_size != 0) {
......@@ -368,7 +417,12 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
int threads = GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T, MaskType, U, is_same_type>(
LaunchLayernormResidualDropoutBias<T,
MaskType,
U,
is_same_type,
InType,
OutType>(
this->rows_,
this->cols_,
increment,
......@@ -387,7 +441,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
out,
mean,
variance,
ctx);
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
template <typename P = LayerNormParamType<T>, bool is_same_type = false>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel {
private:
static constexpr const char *OpName = "FusedMultiTransformerINT8Op";
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
#define CHECK_INPUT(name) \
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
#define CHECK_INPUTS(name) \
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
#define CHECK_OUTPUT(name) \
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
#define CHECK_OUTPUTS(name) \
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
CHECK_INPUT(X);
// attention
CHECK_INPUTS(QKVW);
CHECK_INPUTS(OutLinearW);
if (ctx->HasInput("TimeStep")) {
CHECK_INPUTS(CacheKV);
}
if (ctx->HasInputs("CacheKV")) {
CHECK_OUTPUTS(CacheKVOut);
}
// ffn
CHECK_INPUTS(FFN1Weight);
CHECK_INPUTS(FFN2Weight);
CHECK_OUTPUT(Out);
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputsDim("QKVW")[0];
bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw");
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
platform::errors::InvalidArgument("The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(y_dim.size(),
4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(
x_dim[2],
trans_qkvw ? y_dim[3] : y_dim[0],
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
if (trans_qkvw) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
} else {
PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
y_dim[0],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(dim_embed, 3, num_head, dim_head),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
}
if (ctx->HasInputs("CacheKV")) {
// [2, batch_size, num_head, max_seq_len, head_size]
const auto &c_dims = ctx->GetInputsDim("CacheKV");
const auto &c_dim = c_dims[0];
PADDLE_ENFORCE_EQ(
c_dim.size(),
5,
paddle::platform::errors::InvalidArgument(
"The CacheKV must be 5 dims, but got %d", c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0],
2,
paddle::platform::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1],
x_dim[0],
paddle::platform::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
trans_qkvw ? y_dim[1] : y_dim[2],
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_GT(
c_dim[3],
0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "TimeStep") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
class FusedMultiTransformerINT8OpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("LnScale",
"Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("LnBias",
"Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
AddInput("OutLinearBias", "The out_linear bias tensor.")
.AsDispensable()
.AsDuplicable();
AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
.AsDuplicable();
AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddInput("QKVOutScale",
"QKVOutScale is used to dequantize qkv output tensor."
"In order to keep consistent with the PTQ/QAT calculation logic,"
"QKVOutScale should be max_bound * max_bound / max_range."
"Here max_range is per-channel weight scale."
"The shape of QKVOutScale is [num_layers, num_channels]")
.AsDispensable();
AddInput("OutLinearOutScale",
"OutLinearOutScale is used to dequantize out_linear output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
AddInput("FFN1OutScale",
"FFN1OutScale is used to dequantize ffn1 output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
AddInput("FFN2OutScale",
"FFN2OutScale is used to dequantize ffn2 output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "Result after multi .");
AddAttr<bool>("pre_layer_norm",
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
"[default true].")
.SetDefault(true);
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
true,
platform::errors::InvalidArgument(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s].",
epsilon));
});
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f,
true,
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"The meaning is the same as 'attn_dropout_implementation'.")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train",
true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<bool>(
"trans_qkvw",
"Whether the weights of qkv should be transposed. If true,"
"the shape eights of qkv should be [3, num_head, dim_head, dim_embed]."
"Otherwise the shape of weights of qkv should be"
"[dim_embed, 3, num_head, dim_head]")
.SetDefault(true);
AddAttr<int>(
"ring_id",
"ring id for tensor model parallel. distributed training and inference")
.SetDefault(-1);
AddAttr<int>("num_head", "num_head").SetDefault(0);
AddAttr<int>("dim_head", "dim_head").SetDefault(0);
AddAttr<int>("dim_ffn", "dim_ffn").SetDefault(0);
AddAttr<std::vector<float>>(
"qkv_in_scale",
"qkv_in_scale is used to quantize qkv input tensor."
"in_scale is generated by PTQ or QAT, which represents valid max range "
"of this tensor."
"the size of qkv_in_scale should be num_layers, which is equal to "
"QKVW.dims()[0]")
.SetDefault({});
AddAttr<std::vector<float>>(
"out_linear_in_scale",
"out_linear_in_scale is used to quantize out_linear input tensor."
"the size of out_linear_in_scale is the same as qkv_in_scale")
.SetDefault({});
AddAttr<std::vector<float>>(
"ffn1_in_scale",
"ffn1_in_scale is used to quantize ffn1 input tensor."
"the size of ffn1_in_scale is the same as qkv_in_scale")
.SetDefault({});
AddAttr<std::vector<float>>(
"ffn2_in_scale",
"ffn2_in_scale is used to quantize ffn2 input tensor."
"the size of ffn2_in_scale is the same as qkv_in_scale")
.SetDefault({});
AddAttr<int>(
"quant_round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(-2.5)=-3")
.SetDefault(1);
AddAttr<float>(
"quant_max_bound",
"(float, default 127.0) the max bound of float type to int type")
.SetDefault(127.0);
AddAttr<float>(
"quant_min_bound",
"(float, default -127.0) the min bound of float type to int type")
.SetDefault(-127.0);
AddComment(R"DOC(fused multi transformer layers op)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_multi_transformer_int8,
ops::FusedMultiTransformerINT8Op,
ops::FusedMultiTransformerINT8OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -28,7 +28,9 @@ template <typename T,
int VecSize,
bool ComputeLayerNorm,
bool Activation,
typename Functor>
typename Functor,
typename InType = T,
typename OutType = T>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id,
const int col_id,
......@@ -36,30 +38,45 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
curandStatePhilox4_32_10_t *state,
const float dropout_prob,
const T factor,
const T *__restrict__ src,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
T *dst,
OutType *dst,
MaskType *mask,
const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val,
Functor act_func) {
Functor act_func,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
using U = typename details::MPTypeTrait<T>::Type;
LoadT src_vec;
LoadInType src_vec;
LoadT residual_vec;
LoadT bias_vec;
LoadFloat quant_out_scale_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
residual_vec[ii] = static_cast<T>(0);
}
// vectorize load data from global
phi::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<float, VecSize>(
&dequant_out_scale_data[quant_out_scale_offset + col_id],
&quant_out_scale_vec);
if (residual) {
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
}
......@@ -84,10 +101,18 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
StoreT dest_vec;
StoreOutType dest_vec_out_type;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp = src_vec[ii] + bias_vec[ii];
T tmp;
if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_last_in_scale / quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii];
} else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
}
if (Activation) {
tmp = act_func(tmp);
}
......@@ -98,10 +123,23 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
*mean_val += tmp;
*var_val += (tmp * tmp);
}
if (std::is_same<OutType, int8_t>::value) {
dest_vec_out_type[ii] = quant_helper(dest_vec[ii],
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
// store result to global
phi::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
if (std::is_same<OutType, int8_t>::value) {
phi::Store<OutType, VecSize>(dest_vec_out_type,
&dst[row_id * cols + col_id]);
} else {
phi::Store<T, VecSize>(dest_vec,
reinterpret_cast<T *>(&dst[row_id * cols + col_id]));
}
if (!is_test) {
phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
}
......@@ -114,19 +152,28 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
* is_test: only used in inference
* mask: can be null if is_test=true
*/
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutBias(const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const T *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
MaskType *mask,
T *dst,
uint64_t increment,
const bool is_test) {
template <typename T,
typename MaskType,
int VecSize,
typename InType = T,
typename OutType = T>
__global__ void FusedResidualDropoutBias(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
MaskType *mask,
OutType *dst,
uint64_t increment,
const bool is_test,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
......@@ -142,22 +189,27 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
VecSize,
false,
false,
phi::funcs::ReluFunctor<T>>(
r,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
relu);
phi::funcs::ReluFunctor<T>,
InType,
OutType>(r,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
}
......@@ -165,7 +217,10 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
/**
* @brief dst = residual + dropout(src + bias);
*/
template <typename T, typename MaskType>
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
void LaunchResidualDropoutBias(const uint32_t rows,
const uint32_t cols,
const int increment,
......@@ -173,14 +228,19 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const float dropout_prob,
const bool is_test,
bool is_upscale_in_train,
const T *src,
const InType *src,
const T *residual,
const T *bias,
MaskType *mask_data,
T *dst,
const phi::GPUContext &ctx) {
OutType *dst,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
if (residual == dst) return;
if (residual) {
memory::Copy(ctx.GetPlace(),
......@@ -202,7 +262,7 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedResidualDropoutBias<T, uint8_t, VecSize>
FusedResidualDropoutBias<T, uint8_t, VecSize, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
......@@ -215,9 +275,13 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data,
dst,
increment,
is_test);
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
} else {
FusedResidualDropoutBias<T, uint8_t, 1>
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
......@@ -230,7 +294,11 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data,
dst,
increment,
is_test);
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
此差异已折叠。
......@@ -40,26 +40,28 @@ namespace dynload {
// APIs available after CUDA 10.1
// #if CUDA_VERSION >= 10100
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute); \
__macro(cublasLtMatmulAlgoInit); \
__macro(cublasLtMatmulAlgoConfigSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
// #endif
......
......@@ -71,6 +71,12 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"FFN1Bias",
"FFN2Weight",
"FFN2Bias"}},
{"fused_multi_transformer_int8",
{"X", "LnScale", "LnBias", "QKVW",
"QKVBias", "CacheKV", "TimeStep", "SrcMask",
"OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
"FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias",
"QKVOutScale", "OutLinearOutScale", "FFN1OutScale", "FFN2OutScale"}},
{"fused_bias_dropout_residual_layer_norm",
{"X", "Residual", "Bias", "LnScale", "LnBias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
......@@ -329,6 +335,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut",
"MasterParamOut"}},
{"fused_multi_transformer", {"CacheKVOut", "Out"}},
{"fused_multi_transformer_int8", {"CacheKVOut", "Out"}},
{"resnet_basic_block",
{"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out",
"Var1Out", "Conv2", "SavedMean2", "SavedInvstd2", "Mean2Out",
......@@ -433,6 +440,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"split", {"Out"}},
{"concat", {"Out"}},
{"fused_multi_transformer", {"CacheKVOut"}},
{"fused_multi_transformer_int8", {"CacheKVOut"}},
{"group_norm", {"Mean", "Variance"}},
{"resnet_basic_block",
{"Mean1Out", "Var1Out", "Mean2Out", "Var2Out", "Mean3Out", "Var3Out"}},
......
......@@ -326,7 +326,7 @@ void* GetCublasDsoHandle() {
void* GetCublasLtDsoHandle() {
// APIs available after CUDA 10.1
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10100
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10010
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so");
#else
std::string warning_msg(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册