Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3d7e2118
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d7e2118
编写于
9月 18, 2022
作者:
R
RichardWooSJTU
提交者:
GitHub
9月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add INT8 support for fused_multi_transformer_op (#45284)
上级
7f346a76
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
4168 addition
and
1428 deletion
+4168
-1428
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
...ence/analysis/passes/ir_params_sync_among_devices_pass.cc
+2
-1
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+2
-0
paddle/fluid/operators/fused/attention_layer_norm.h
paddle/fluid/operators/fused/attention_layer_norm.h
+24
-6
paddle/fluid/operators/fused/attn_gemm_int8.h
paddle/fluid/operators/fused/attn_gemm_int8.h
+189
-0
paddle/fluid/operators/fused/cublaslt.h
paddle/fluid/operators/fused/cublaslt.h
+211
-0
paddle/fluid/operators/fused/fused_dropout_act_bias.h
paddle/fluid/operators/fused/fused_dropout_act_bias.h
+65
-24
paddle/fluid/operators/fused/fused_dropout_common.h
paddle/fluid/operators/fused/fused_dropout_common.h
+1
-0
paddle/fluid/operators/fused/fused_dropout_helper.h
paddle/fluid/operators/fused/fused_dropout_helper.h
+116
-55
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
...d/operators/fused/fused_layernorm_residual_dropout_bias.h
+188
-97
paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc
.../fluid/operators/fused/fused_multi_transformer_int8_op.cc
+369
-0
paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
.../fluid/operators/fused/fused_multi_transformer_int8_op.cu
+670
-0
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+5
-1147
paddle/fluid/operators/fused/fused_multi_transformer_op.h
paddle/fluid/operators/fused/fused_multi_transformer_op.h
+1161
-0
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
+113
-45
paddle/fluid/operators/fused/quant_dequant_kernel.h
paddle/fluid/operators/fused/quant_dequant_kernel.h
+136
-0
paddle/fluid/operators/layer_norm_kernel.cu.h
paddle/fluid/operators/layer_norm_kernel.cu.h
+65
-12
paddle/fluid/platform/dynload/cublasLt.h
paddle/fluid/platform/dynload/cublasLt.h
+22
-20
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+8
-0
paddle/phi/backends/dynload/cublasLt.h
paddle/phi/backends/dynload/cublasLt.h
+22
-20
paddle/phi/backends/dynload/dynamic_loader.cc
paddle/phi/backends/dynload/dynamic_loader.cc
+1
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+6
-0
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py
...d/tests/unittests/test_fused_multi_transformer_int8_op.py
+792
-0
未找到文件。
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
浏览文件 @
3d7e2118
...
...
@@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto
var_data_type
=
var_node
->
Var
()
->
GetDataType
();
VLOG
(
5
)
<<
"var_name is "
<<
var_name
<<
", data type is "
<<
var_data_type
;
if
(
var_data_type
==
paddle
::
framework
::
proto
::
VarType
::
FP16
)
{
if
(
var_data_type
==
paddle
::
framework
::
proto
::
VarType
::
FP16
&&
t
->
dtype
()
!=
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
framework
::
Tensor
half_tensor
;
half_tensor
.
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
half_tensor
.
Resize
(
t
->
dims
());
...
...
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
3d7e2118
...
...
@@ -23,6 +23,7 @@ register_operators(
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
...
...
@@ -119,6 +120,7 @@ if(WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library
(
fused_attention_op
)
op_library
(
fused_multi_transformer_op
)
op_library
(
fused_multi_transformer_int8_op
)
op_library
(
fused_bias_dropout_residual_layer_norm_op
)
endif
()
# resnet_unit needs cudnn 8.0 above
...
...
paddle/fluid/operators/fused/attention_layer_norm.h
浏览文件 @
3d7e2118
...
...
@@ -19,7 +19,8 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
// NOTE: T must be the same as OutType in ComputeBackward
template
<
typename
T
,
typename
InType
=
T
,
typename
OutType
=
T
>
class
AttnLayerNorm
{
public:
AttnLayerNorm
(
const
phi
::
GPUContext
&
dev_ctx
,
...
...
@@ -33,17 +34,28 @@ class AttnLayerNorm {
~
AttnLayerNorm
()
{}
void
ComputeForward
(
const
T
*
x_data
,
void
ComputeForward
(
const
InType
*
x_data
,
const
LayerNormParamType
<
T
>*
scale_data
,
const
LayerNormParamType
<
T
>*
bias_data
,
T
*
y_data
,
OutType
*
y_data
,
LayerNormParamType
<
T
>*
mean_data
,
LayerNormParamType
<
T
>*
var_data
)
{
LayerNormParamType
<
T
>*
var_data
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
auto
stream
=
dev_ctx_
.
stream
();
switch
(
GetDesiredBlockDim
(
feature_size_
))
{
FIXED_BLOCK_DIM_CASE
(
LayerNormForward
<
T
,
LayerNormParamType
<
T
>
,
kBlockDim
>
LayerNormForward
<
T
,
LayerNormParamType
<
T
>
,
kBlockDim
,
false
,
InType
,
OutType
>
<<<
batch_size_
,
kBlockDim
,
0
,
stream
>>>
(
x_data
,
scale_data
,
bias_data
,
...
...
@@ -51,7 +63,13 @@ class AttnLayerNorm {
mean_data
,
var_data
,
epsilon_
,
feature_size_
));
feature_size_
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
));
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Feature_size must be larger than 1"
));
...
...
paddle/fluid/operators/fused/attn_gemm_int8.h
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/operators/fused/cublaslt.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
AttnMatmulINT8
{
public:
AttnMatmulINT8
(
const
phi
::
GPUContext
&
dev_ctx
,
int
m
,
int
n
,
int
k
,
bool
compute_bias
)
:
dev_ctx_
(
dev_ctx
),
m_
(
m
),
n_
(
n
),
k_
(
k
),
compute_bias_
(
compute_bias
)
{
auto
helper
=
std
::
make_shared
<
CublasLtHelper
>
(
m
,
k
,
n
);
helpers_
.
emplace_back
(
helper
);
}
~
AttnMatmulINT8
()
{}
// This function is used to execute GEMM, with input and output's types are
// both T.
void
ComputeForward
(
const
framework
::
Tensor
*
weight
,
const
framework
::
Tensor
*
input
,
framework
::
Tensor
*
input_tmp
,
const
framework
::
Tensor
*
bias
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
output_tmp
,
framework
::
Tensor
*
bias_out
,
const
float
quant_in_scale
,
const
framework
::
Tensor
*
dequant_out_scale
,
const
int
quant_out_scale_offset
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
quantize_kernel_launcher
<
T
>
(
input
->
data
<
T
>
(),
input_tmp
->
data
<
int8_t
>
(),
quant_in_scale
,
m_
,
k_
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
,
dev_ctx_
.
stream
());
helpers_
[
0
]
->
GEMM
(
input_tmp
->
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
output_tmp
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
dequantize_kernel_launcher
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
output
->
data
<
T
>
(),
m_
,
n_
,
dev_ctx_
.
stream
(),
quant_in_scale
,
dequant_out_scale
->
data
<
float
>
(),
quant_out_scale_offset
);
if
(
compute_bias_
)
{
// bias_out = output + bias
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
output
,
bias
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
bias_out
};
phi
::
funcs
::
BroadcastKernel
<
phi
::
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
-
1
,
phi
::
funcs
::
AddFunctor
<
T
>
());
PADDLE_ENFORCE_EQ
(
cudaGetLastError
(),
cudaSuccess
,
platform
::
errors
::
Fatal
(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"
));
}
}
// This function is used to execute GEMM, with input and output's types are
// both INT8.
void
ComputeForwardINT8ToINT8
(
const
framework
::
Tensor
*
weight
,
framework
::
Tensor
*
input
,
const
framework
::
Tensor
*
bias
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
bias_out
)
{
helpers_
[
0
]
->
GEMM
(
input
->
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
output
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
}
// This function is used to execute GEMM, with input and output's types are
// INT8 and T.
void
ComputeForwardINT8ToT
(
const
framework
::
Tensor
*
weight
,
const
float
quant_in_scale
,
framework
::
Tensor
*
input
,
const
framework
::
Tensor
*
bias
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
output_tmp
,
framework
::
Tensor
*
bias_out
,
const
framework
::
Tensor
*
dequant_out_scale
,
const
int
quant_out_scale_offset
)
{
helpers_
[
0
]
->
GEMM
(
input
->
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
output_tmp
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
dequantize_kernel_launcher
<
T
>
(
output_tmp
->
data
<
int32_t
>
(),
output
->
data
<
T
>
(),
m_
,
n_
,
dev_ctx_
.
stream
(),
quant_in_scale
,
dequant_out_scale
->
data
<
float
>
(),
quant_out_scale_offset
);
if
(
compute_bias_
)
{
// bias_out = output + bias
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
output
,
bias
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
bias_out
};
phi
::
funcs
::
BroadcastKernel
<
phi
::
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
-
1
,
phi
::
funcs
::
AddFunctor
<
T
>
());
PADDLE_ENFORCE_EQ
(
cudaGetLastError
(),
cudaSuccess
,
platform
::
errors
::
Fatal
(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"
));
}
}
// This function is used to execute GEMM, with input and output's types are T
// and INT8.
void
ComputeForwardTToINT8
(
const
framework
::
Tensor
*
weight
,
const
float
quant_in_scale
,
const
framework
::
Tensor
*
input
,
framework
::
Tensor
*
input_tmp
,
const
framework
::
Tensor
*
bias
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
bias_out
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
quantize_kernel_launcher
<
T
>
(
input
->
data
<
T
>
(),
input_tmp
->
data
<
int8_t
>
(),
quant_in_scale
,
m_
,
k_
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
,
dev_ctx_
.
stream
());
helpers_
[
0
]
->
GEMM
(
input_tmp
->
data
<
int8_t
>
(),
weight
->
data
<
int8_t
>
(),
output
->
data
<
int32_t
>
(),
dev_ctx_
.
stream
());
}
private:
const
phi
::
GPUContext
&
dev_ctx_
;
int
m_
;
// m
int
n_
;
// n
int
k_
;
// k
int
compute_bias_
;
std
::
vector
<
std
::
shared_ptr
<
CublasLtHelper
>>
helpers_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/cublaslt.h
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
namespace
dyl
=
paddle
::
platform
::
dynload
;
namespace
paddle
{
namespace
operators
{
class
CublasLtHelper
{
public:
CublasLtHelper
(
int
m
,
int
k
,
int
n
)
:
alpha_
(
1
),
beta_
(
0
),
m_
(
m
),
k_
(
k
),
n_
(
n
)
{
cublasStatus_t
status
;
// handle and matmul desc
status
=
dyl
::
cublasLtCreate
(
&
handle_
);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t
cudaComputeType
=
CUDA_R_32I
;
#else
cublasComputeType_t
cudaComputeType
=
CUBLAS_COMPUTE_32I
;
#endif
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
#if CUBLAS_VER_MAJOR < 11
status
=
dyl
::
cublasLtMatmulDescCreate
(
&
matmul_desc_
,
cudaComputeType
);
#else
status
=
dyl
::
cublasLtMatmulDescCreate
(
&
matmul_desc_
,
cudaComputeType
,
CUDA_R_32I
);
#endif
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatmulDescCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
cublasOperation_t
op_transpose
=
CUBLAS_OP_T
;
status
=
dyl
::
cublasLtMatmulDescSetAttribute
(
matmul_desc_
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
op_transpose
,
sizeof
(
op_transpose
));
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatmulDescSetAttribute execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
// matrix desc
status
=
dyl
::
cublasLtMatrixLayoutCreate
(
&
B_desc_
,
CUDA_R_8I
,
k
,
n
,
k
);
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
status
=
dyl
::
cublasLtMatrixLayoutCreate
(
&
A_desc_
,
CUDA_R_8I
,
k
,
m
,
k
);
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
status
=
dyl
::
cublasLtMatrixLayoutCreate
(
&
C_desc_
,
CUDA_R_32I
,
n
,
m
,
n
);
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
}
~
CublasLtHelper
()
{
if
(
handle_
)
dyl
::
cublasLtDestroy
(
handle_
);
if
(
matmul_desc_
)
dyl
::
cublasLtMatmulDescDestroy
(
matmul_desc_
);
if
(
A_desc_
)
dyl
::
cublasLtMatrixLayoutDestroy
(
A_desc_
);
if
(
B_desc_
)
dyl
::
cublasLtMatrixLayoutDestroy
(
B_desc_
);
if
(
C_desc_
)
dyl
::
cublasLtMatrixLayoutDestroy
(
C_desc_
);
}
void
GEMM
(
int8_t
*
A_dev
,
const
int8_t
*
B_dev
,
int32_t
*
C_dev
,
cudaStream_t
stream
)
{
cublasStatus_t
status
;
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
cublasLtMatmulAlgo_t
algo
;
int
algoId
=
21
;
int
swizzle
=
0
;
int
customOption
=
0
;
int
tile
=
15
;
int
splitK_val
=
0
;
int
reductionScheme
=
0
;
#if CUDA_VERSION >= 11000
int
stages
=
23
;
#endif
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t
cudaComputeType
=
CUDA_R_32I
;
#else
cublasComputeType_t
cudaComputeType
=
CUBLAS_COMPUTE_32I
;
#endif
dyl
::
cublasLtMatmulAlgoInit
(
handle_
,
cudaComputeType
,
CUDA_R_32I
,
CUDA_R_8I
,
CUDA_R_8I
,
CUDA_R_32I
,
CUDA_R_32I
,
algoId
,
&
algo
);
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
,
&
(
customOption
),
sizeof
(
customOption
));
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_TILE_ID
,
&
(
tile
),
sizeof
(
tile
));
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM
,
&
(
splitK_val
),
sizeof
(
splitK_val
));
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
,
&
(
swizzle
),
sizeof
(
swizzle
));
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
,
&
(
reductionScheme
),
sizeof
(
int
));
#if CUDA_VERSION >= 11000
dyl
::
cublasLtMatmulAlgoConfigSetAttribute
(
&
algo
,
CUBLASLT_ALGO_CONFIG_STAGES_ID
,
&
(
stages
),
sizeof
(
stages
));
#endif
#endif
status
=
dyl
::
cublasLtMatmul
(
handle_
,
matmul_desc_
,
&
alpha_
,
B_dev
,
B_desc_
,
A_dev
,
A_desc_
,
&
beta_
,
C_dev
,
C_desc_
,
C_dev
,
C_desc_
,
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
&
algo
,
#else
nullptr
,
#endif
nullptr
,
0
,
stream
);
PADDLE_ENFORCE_EQ
(
status
,
CUBLAS_STATUS_SUCCESS
,
platform
::
errors
::
External
(
"cublasLtMatmul execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"
));
}
private:
cublasLtHandle_t
handle_
;
cublasLtMatmulDesc_t
matmul_desc_
;
cublasLtMatrixLayout_t
A_desc_
;
cublasLtMatrixLayout_t
B_desc_
;
cublasLtMatrixLayout_t
C_desc_
;
int32_t
alpha_
;
int32_t
beta_
;
int
m_
;
int
k_
;
int
n_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fused_dropout_act_bias.h
浏览文件 @
3d7e2118
...
...
@@ -60,19 +60,32 @@ struct GeluGradFunctor {
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
typename
Functor
>
__global__
void
FusedDropoutActBias
(
Functor
act
,
const
uint64_t
seed
,
const
uint64_t
rows
,
const
uint64_t
cols
,
const
int
increment
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
bias
,
T
*
dst
,
MaskType
*
mask
)
{
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
typename
Functor
,
typename
InType
=
T
,
typename
OutType
=
T
>
__global__
void
FusedDropoutActBias
(
Functor
act
,
const
uint64_t
seed
,
const
uint64_t
rows
,
const
uint64_t
cols
,
const
int
increment
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
InType
*
__restrict__
src
,
const
T
*
__restrict__
bias
,
OutType
*
dst
,
MaskType
*
mask
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
int
col_id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row_id
=
blockIdx
.
y
;
int
idx
=
row_id
*
cols
+
col_id
;
...
...
@@ -90,7 +103,9 @@ __global__ void FusedDropoutActBias(Functor act,
VecSize
,
false
,
true
,
Functor
>
(
r
,
Functor
,
InType
,
OutType
>
(
r
,
i
,
cols
,
&
state
,
...
...
@@ -104,7 +119,14 @@ __global__ void FusedDropoutActBias(Functor act,
is_test
,
nullptr
,
nullptr
,
act
);
act
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
}
}
...
...
@@ -112,7 +134,11 @@ __global__ void FusedDropoutActBias(Functor act,
/**
* @brief dst = dropout(activation(src + bias));
*/
template
<
typename
T
,
typename
MaskType
,
typename
Functor
>
template
<
typename
T
,
typename
MaskType
,
typename
Functor
,
typename
InType
=
T
,
typename
OutType
=
T
>
void
LaunchDropoutActBias
(
Functor
act_functor
,
const
uint64_t
seed
,
const
uint32_t
rows
,
...
...
@@ -121,14 +147,21 @@ void LaunchDropoutActBias(Functor act_functor,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
T
*
src
,
const
InType
*
src
,
const
T
*
bias
,
T
*
dst
,
OutType
*
dst
,
MaskType
*
mask_data
,
const
phi
::
GPUContext
&
ctx
)
{
const
phi
::
GPUContext
&
ctx
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
// dropout_prob == 1.0f
if
(
std
::
abs
(
dropout_prob
-
1.0
f
)
<
1e-5
)
{
SetZero
<
T
>
(
ctx
,
dst
,
rows
*
cols
);
SetZero
<
T
>
(
ctx
,
reinterpret_cast
<
T
*>
(
dst
)
,
rows
*
cols
);
SetZero
<
MaskType
>
(
ctx
,
mask_data
,
rows
*
cols
);
return
;
}
...
...
@@ -137,7 +170,7 @@ void LaunchDropoutActBias(Functor act_functor,
const
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
const
auto
config
=
Get1DBlocksAnd2DGrids
(
ctx
,
rows
,
cols
,
real_vec_size
);
if
(
cols
%
VecSize
==
0
)
{
FusedDropoutActBias
<
T
,
MaskType
,
VecSize
,
Functor
>
FusedDropoutActBias
<
T
,
MaskType
,
VecSize
,
Functor
,
InType
,
OutType
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
seed
,
...
...
@@ -150,9 +183,13 @@ void LaunchDropoutActBias(Functor act_functor,
src
,
bias
,
dst
,
mask_data
);
mask_data
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
else
{
FusedDropoutActBias
<
T
,
MaskType
,
1
,
Functor
>
FusedDropoutActBias
<
T
,
MaskType
,
1
,
Functor
,
InType
,
OutType
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
seed
,
...
...
@@ -165,7 +202,11 @@ void LaunchDropoutActBias(Functor act_functor,
src
,
bias
,
dst
,
mask_data
);
mask_data
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
}
...
...
paddle/fluid/operators/fused/fused_dropout_common.h
浏览文件 @
3d7e2118
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
...
...
paddle/fluid/operators/fused/fused_dropout_helper.h
浏览文件 @
3d7e2118
...
...
@@ -109,7 +109,10 @@ struct DropoutParam {
}
};
template
<
typename
T
,
typename
MaskType
>
template
<
typename
T
,
typename
MaskType
,
typename
InType
=
T
,
typename
OutType
=
T
>
class
FusedDropoutHelper
{
private:
int
GetIncrement
(
const
phi
::
GPUContext
&
ctx
)
{
...
...
@@ -140,25 +143,34 @@ class FusedDropoutHelper {
// out = residual + dropout( src + bias )
void
ResidualDropoutBias
(
const
phi
::
GPUContext
&
ctx
,
const
T
*
src
,
const
InType
*
src
,
const
T
*
residual
,
const
T
*
bias
,
T
*
out
,
MaskType
*
mask
)
{
OutType
*
out
,
MaskType
*
mask
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
)
{
auto
increment
=
GetIncrement
(
ctx
);
LaunchResidualDropoutBias
<
T
,
MaskType
>
(
rows_
,
cols_
,
increment
,
dropout_param_
.
seed
,
dropout_param_
.
dropout_prob
,
dropout_param_
.
is_test
,
dropout_param_
.
is_upscale_in_train
,
src
,
residual
,
bias
,
mask
,
out
,
ctx
);
LaunchResidualDropoutBias
<
T
,
MaskType
,
InType
,
OutType
>
(
rows_
,
cols_
,
increment
,
dropout_param_
.
seed
,
dropout_param_
.
dropout_prob
,
dropout_param_
.
is_test
,
dropout_param_
.
is_upscale_in_train
,
src
,
residual
,
bias
,
mask
,
out
,
ctx
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
void
ResidualDropoutBiasGrad
(
const
phi
::
GPUContext
&
ctx
,
...
...
@@ -189,15 +201,22 @@ class FusedDropoutHelper {
// out = dropout(activation(src + bias))
void
DropoutActBias
(
const
phi
::
GPUContext
&
ctx
,
const
T
*
src
,
const
InType
*
src
,
const
T
*
bias
,
const
std
::
string
&
act_method
,
T
*
out
,
MaskType
*
mask
)
{
OutType
*
out
,
MaskType
*
mask
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
auto
increment
=
GetIncrement
(
ctx
);
if
(
act_method
==
"gelu"
)
{
GeluFunctor
<
T
>
gelu
;
LaunchDropoutActBias
<
T
,
MaskType
,
GeluFunctor
<
T
>>
(
LaunchDropoutActBias
<
T
,
MaskType
,
GeluFunctor
<
T
>
,
InType
,
OutType
>
(
gelu
,
dropout_param_
.
seed
,
rows_
,
...
...
@@ -210,23 +229,40 @@ class FusedDropoutHelper {
bias
,
out
,
mask
,
ctx
);
ctx
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
if
(
act_method
==
"relu"
)
{
phi
::
funcs
::
ReluFunctor
<
T
>
relu
;
LaunchDropoutActBias
<
T
,
MaskType
,
phi
::
funcs
::
ReluFunctor
<
T
>>
(
relu
,
dropout_param_
.
seed
,
rows_
,
cols_
,
increment
,
dropout_param_
.
dropout_prob
,
dropout_param_
.
is_upscale_in_train
,
dropout_param_
.
is_test
,
src
,
bias
,
out
,
mask
,
ctx
);
LaunchDropoutActBias
<
T
,
MaskType
,
phi
::
funcs
::
ReluFunctor
<
T
>
,
InType
,
OutType
>
(
relu
,
dropout_param_
.
seed
,
rows_
,
cols_
,
increment
,
dropout_param_
.
dropout_prob
,
dropout_param_
.
is_upscale_in_train
,
dropout_param_
.
is_test
,
src
,
bias
,
out
,
mask
,
ctx
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Currently only supports gelu or relu activation functions!"
));
...
...
@@ -283,8 +319,12 @@ class FusedDropoutHelper {
DropoutParam
dropout_param_
;
};
template
<
typename
T
,
typename
MaskType
>
class
FusedDropoutLayerNormHelper
:
public
FusedDropoutHelper
<
T
,
MaskType
>
{
template
<
typename
T
,
typename
MaskType
,
typename
InType
=
T
,
typename
OutType
=
T
>
class
FusedDropoutLayerNormHelper
:
public
FusedDropoutHelper
<
T
,
MaskType
,
InType
,
OutType
>
{
public:
FusedDropoutLayerNormHelper
()
{}
FusedDropoutLayerNormHelper
(
const
int
rows
,
...
...
@@ -301,23 +341,24 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
const
int
cols
,
const
DropoutParam
&
dropout_param
,
const
float
epsilon
)
:
FusedDropoutHelper
<
T
,
MaskType
>
(
ctx
,
rows
,
cols
,
dropout_param
)
{
:
FusedDropoutHelper
<
T
,
MaskType
,
InType
,
OutType
>
(
ctx
,
rows
,
cols
,
dropout_param
)
{
using
U
=
LayerNormParamType
<
T
>
;
epsilon_
=
epsilon
;
}
// call layer_norm
void
LayerNorm
(
const
phi
::
GPUContext
&
ctx
,
const
T
*
src
,
const
InType
*
src
,
const
LayerNormParamType
<
T
>*
gamma
,
const
LayerNormParamType
<
T
>*
beta
,
T
*
out
,
OutType
*
out
,
LayerNormParamType
<
T
>*
mean
,
LayerNormParamType
<
T
>*
variance
)
{
using
U
=
LayerNormParamType
<
T
>
;
switch
(
GetDesiredBlockDim
(
this
->
cols_
))
{
FIXED_BLOCK_DIM_CASE
(
LayerNormForward
<
T
,
U
,
kBlockDim
>
LayerNormForward
<
T
,
U
,
kBlockDim
,
false
,
InType
,
OutType
>
<<<
this
->
rows_
,
kBlockDim
,
0
,
ctx
.
stream
()
>>>
(
src
,
gamma
,
beta
,
out
,
mean
,
variance
,
epsilon_
,
this
->
cols_
));
}
...
...
@@ -349,17 +390,25 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
// out = layernorm(residual + dropout(src + bias))
template
<
typename
P
=
LayerNormParamType
<
T
>,
bool
is_same_type
=
false
>
void
LayernormResidualDropoutBias
(
const
phi
::
GPUContext
&
ctx
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
const
P
*
gamma
,
const
P
*
beta
,
T
*
dropout_out
,
MaskType
*
mask
,
T
*
out
,
LayerNormParamType
<
T
>*
mean
,
LayerNormParamType
<
T
>*
variance
)
{
void
LayernormResidualDropoutBias
(
const
phi
::
GPUContext
&
ctx
,
const
InType
*
src
,
const
T
*
residual
,
const
T
*
bias
,
const
P
*
gamma
,
const
P
*
beta
,
T
*
dropout_out
,
MaskType
*
mask
,
OutType
*
out
,
LayerNormParamType
<
T
>*
mean
,
LayerNormParamType
<
T
>*
variance
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
using
U
=
LayerNormParamType
<
T
>
;
int
vec_size
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
if
(
this
->
cols_
%
vec_size
!=
0
)
{
...
...
@@ -368,7 +417,12 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
int
threads
=
GetDesiredBlockDim
(
this
->
cols_
/
vec_size
);
int
increment
=
((
this
->
cols_
-
1
)
/
(
threads
*
vec_size
)
+
1
)
*
vec_size
;
increment
=
this
->
dropout_param_
.
UpdateSeedAndIncrement
(
ctx
,
increment
);
LaunchLayernormResidualDropoutBias
<
T
,
MaskType
,
U
,
is_same_type
>
(
LaunchLayernormResidualDropoutBias
<
T
,
MaskType
,
U
,
is_same_type
,
InType
,
OutType
>
(
this
->
rows_
,
this
->
cols_
,
increment
,
...
...
@@ -387,7 +441,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
out
,
mean
,
variance
,
ctx
);
ctx
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
template
<
typename
P
=
LayerNormParamType
<
T
>,
bool
is_same_type
=
false
>
...
...
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
浏览文件 @
3d7e2118
...
...
@@ -418,7 +418,9 @@ template <typename T,
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
,
typename
InType
=
T
,
typename
OutType
=
T
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_fast_ln_fwd_kernel
(
int
rows
,
int
cols
,
...
...
@@ -428,7 +430,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
const
bool
is_test
,
const
uint64_t
increment
,
const
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
InType
*
__restrict__
x_ptr
,
const
T
*
__restrict__
residual_ptr
,
const
T
*
__restrict__
bias_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
...
...
@@ -437,10 +439,20 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
U
*
__restrict__
mean_out_ptr
,
U
*
__restrict__
var_out_ptr
,
T
*
__restrict__
residual_out_ptr
,
T
*
__restrict__
y_ptr
)
{
OutType
*
__restrict__
y_ptr
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
__restrict__
quant_out_scale_ptr
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
__shared__
U
smem
[
WARPS_M
*
WARPS_N
];
using
Vec
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
Vec_scale
=
phi
::
AlignedVector
<
ScaleT
,
VecSize
>
;
using
Vec_in_type
=
phi
::
AlignedVector
<
InType
,
VecSize
>
;
using
Vec_out_type
=
phi
::
AlignedVector
<
OutType
,
VecSize
>
;
using
Vec_float
=
phi
::
AlignedVector
<
float
,
VecSize
>
;
using
MaskStoreT
=
phi
::
AlignedVector
<
MaskType
,
VecSize
>
;
const
int
tidx
=
threadIdx
.
x
;
...
...
@@ -481,12 +493,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
constexpr
U
rn
=
1.
f
/
U
(
ELTS_PER_ROW
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
Vec_in_type
x_input
[
LDGS
];
Vec
residual
[
LDGS
];
Vec_float
dequant_out_scale
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
x
[
it
]);
phi
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
residual
[
it
]);
phi
::
Load
<
InType
,
VecSize
>
(
x_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
x_input
[
it
]);
if
(
quant_out_scale_ptr
!=
nullptr
)
{
phi
::
Load
<
float
,
VecSize
>
(
quant_out_scale_ptr
+
quant_out_scale_offset
+
col
*
VecSize
,
&
dequant_out_scale
[
it
]);
}
col
+=
THREADS_PER_ROW
;
}
...
...
@@ -520,10 +541,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
(
x
[
it
][
jt
]
+
bias
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
if
(
std
::
is_same
<
InType
,
int32_t
>::
value
)
{
T
tmp
=
(
static_cast
<
T
>
(
static_cast
<
float
>
(
x_input
[
it
][
jt
])
*
quant_last_in_scale
/
dequant_out_scale
[
it
][
jt
])
+
bias
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
x
[
it
][
jt
]
=
tmp
;
xf
[
it
*
VecSize
+
jt
]
=
U
(
tmp
);
}
else
{
x
[
it
][
jt
]
=
(
static_cast
<
T
>
(
x_input
[
it
][
jt
])
+
bias
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
}
}
}
}
else
{
...
...
@@ -532,8 +564,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
x
[
it
][
jt
]
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
if
(
std
::
is_same
<
InType
,
int32_t
>::
value
)
{
// for int32 input, we need to dequantize.
T
tmp
=
static_cast
<
T
>
(
static_cast
<
float
>
(
x_input
[
it
][
jt
])
*
quant_last_in_scale
/
dequant_out_scale
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
x
[
it
][
jt
]
=
tmp
;
}
else
{
x
[
it
][
jt
]
=
static_cast
<
T
>
(
x_input
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
}
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
}
}
...
...
@@ -626,6 +669,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
var_out_ptr
[
row
]
=
var_local
*
rn
;
}
Vec_out_type
x_output
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
...
...
@@ -638,12 +683,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
U
tmp
=
rsigma
*
(
static_cast
<
U
>
(
xf
[
it
*
VecSize
+
jt
])
-
mu_local
);
x
[
it
][
jt
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
gamma
[
it
][
jt
])
*
tmp
+
static_cast
<
U
>
(
beta
[
it
][
jt
]));
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
x_output
[
it
][
jt
]
=
quant_helper
(
x
[
it
][
jt
],
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
}
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Store
<
T
,
VecSize
>
(
x
[
it
],
y_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
phi
::
Store
<
OutType
,
VecSize
>
(
x_output
[
it
],
y_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
}
else
{
phi
::
Store
<
T
,
VecSize
>
(
x
[
it
],
reinterpret_cast
<
T
*>
(
y_ptr
)
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
}
col
+=
THREADS_PER_ROW
;
}
}
...
...
@@ -668,7 +727,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
template
<
typename
T
,
typename
MaskType
,
typename
U
,
bool
ScaleBiasWithSameTypeX
=
false
>
bool
ScaleBiasWithSameTypeX
=
false
,
typename
InType
=
T
,
typename
OutType
=
T
>
void
LaunchLayernormResidualDropoutBias
(
const
uint32_t
rows
,
const
uint32_t
cols
,
...
...
@@ -678,18 +739,26 @@ void LaunchLayernormResidualDropoutBias(
const
float
epsilon
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
T
*
src
,
const
InType
*
src
,
const
T
*
residual
,
const
T
*
bias
,
const
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
*
scale
,
const
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
*
layernorm_bias
,
MaskType
*
mask_data
,
T
*
dst
,
T
*
layernorm_dst
,
OutType
*
layernorm_dst
,
LayerNormParamType
<
T
>
*
mean
,
LayerNormParamType
<
T
>
*
var
,
const
phi
::
GPUContext
&
ctx
)
{
const
phi
::
GPUContext
&
ctx
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
// dropout_prob == 1.0f
// NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0
if
(
std
::
abs
(
dropout_prob
-
1.0
f
)
<
1e-5
)
{
auto
cuda_place
=
ctx
.
GetPlace
();
memory
::
Copy
(
cuda_place
,
...
...
@@ -705,14 +774,15 @@ void LaunchLayernormResidualDropoutBias(
switch
(
GetDesiredBlockDim
(
cols
))
{
FIXED_BLOCK_DIM_CASE
(
LayerNormForward
<
T
,
U
,
kBlockDim
,
ScaleBiasWithSameTypeX
>
<<<
rows
,
kBlockDim
,
0
,
ctx
.
stream
()
>>>
(
dst
,
scale
,
layernorm_bias
,
layernorm_dst
,
mean
,
var
,
epsilon
,
cols
));
<<<
rows
,
kBlockDim
,
0
,
ctx
.
stream
()
>>>
(
dst
,
scale
,
layernorm_bias
,
reinterpret_cast
<
T
*>
(
layernorm_dst
),
mean
,
var
,
epsilon
,
cols
));
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Product from begin_norm_axis to end must be larger than 1"
));
...
...
@@ -722,44 +792,63 @@ void LaunchLayernormResidualDropoutBias(
return
;
}
#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \
case (cols): { \
constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
fused_fast_ln_fwd_kernel< \
T, \
U, \
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, \
uint8_t, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
is_test, \
increment, \
epsilon, \
src, \
residual, \
bias, \
scale, \
layernorm_bias, \
mask_data, \
mean, \
var, \
dst, \
layernorm_dst); \
#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \
case (cols): { \
constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \
const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \
const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
fused_fast_ln_fwd_kernel< \
T, \
U, \
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, \
uint8_t, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
cols, \
THREADS_PER_WARP, \
THREADS_PER_ROW, \
THREADS_PER_CTA, \
ROWS_PER_CTA, \
ELTS_PER_ROW_PER_CTA, \
LDGS, \
InType, \
OutType> \
<<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
is_test, \
increment, \
epsilon, \
src, \
residual, \
bias, \
scale, \
layernorm_bias, \
mask_data, \
mean, \
var, \
dst, \
layernorm_dst, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_out_scale_offset, \
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
} break
#define LAUNCH_FUSED_FAST_LN_KERNEL \
...
...
@@ -784,24 +873,25 @@ void LaunchLayernormResidualDropoutBias(
if
(
cols
%
VecSize
!=
0
)
{
int
blockDim
=
GetDesiredBlockDim
(
cols
);
FusedLayernormResidualDropoutBias
<
T
,
uint8_t
,
1
,
U
,
ScaleBiasWithSameTypeX
>
<<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
<<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
reinterpret_cast
<
const
T
*>
(
src
),
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
reinterpret_cast
<
T
*>
(
layernorm_dst
),
mean
,
var
);
}
else
{
if
(
can_call_fast_ln_kernel
)
{
switch
(
cols
)
{
...
...
@@ -819,24 +909,25 @@ void LaunchLayernormResidualDropoutBias(
VecSize
,
U
,
ScaleBiasWithSameTypeX
>
<<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
<<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
reinterpret_cast
<
const
T
*>
(
src
),
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
reinterpret_cast
<
T
*>
(
layernorm_dst
),
mean
,
var
);
}
}
}
...
...
paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
FusedMultiTransformerINT8Op
:
public
framework
::
OperatorWithKernel
{
private:
static
constexpr
const
char
*
OpName
=
"FusedMultiTransformerINT8Op"
;
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
#define CHECK_INPUT(name) \
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
#define CHECK_INPUTS(name) \
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
#define CHECK_OUTPUT(name) \
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
#define CHECK_OUTPUTS(name) \
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
CHECK_INPUT
(
X
);
// attention
CHECK_INPUTS
(
QKVW
);
CHECK_INPUTS
(
OutLinearW
);
if
(
ctx
->
HasInput
(
"TimeStep"
))
{
CHECK_INPUTS
(
CacheKV
);
}
if
(
ctx
->
HasInputs
(
"CacheKV"
))
{
CHECK_OUTPUTS
(
CacheKVOut
);
}
// ffn
CHECK_INPUTS
(
FFN1Weight
);
CHECK_INPUTS
(
FFN2Weight
);
CHECK_OUTPUT
(
Out
);
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto
x_dim
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dim
=
ctx
->
GetInputsDim
(
"QKVW"
)[
0
];
bool
trans_qkvw
=
ctx
->
Attrs
().
Get
<
bool
>
(
"trans_qkvw"
);
PADDLE_ENFORCE_EQ
(
x_dim
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]"
,
x_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
y_dim
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]"
,
y_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
x_dim
[
2
],
trans_qkvw
?
y_dim
[
3
]
:
y_dim
[
0
],
platform
::
errors
::
InvalidArgument
(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
if
(
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
)
==
-
1
)
{
if
(
trans_qkvw
)
{
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
}
else
{
PADDLE_ENFORCE_EQ
(
y_dim
[
2
]
*
y_dim
[
3
],
y_dim
[
0
],
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(dim_embed, 3, num_head, dim_head),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
}
}
if
(
ctx
->
HasInputs
(
"CacheKV"
))
{
// [2, batch_size, num_head, max_seq_len, head_size]
const
auto
&
c_dims
=
ctx
->
GetInputsDim
(
"CacheKV"
);
const
auto
&
c_dim
=
c_dims
[
0
];
PADDLE_ENFORCE_EQ
(
c_dim
.
size
(),
5
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The CacheKV must be 5 dims, but got %d"
,
c_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
c_dim
[
0
],
2
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The first dim of CacheKV must be 2, but got %d"
,
c_dim
[
0
]));
// 2
PADDLE_ENFORCE_EQ
(
c_dim
[
1
],
x_dim
[
0
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d"
,
x_dim
[
0
],
c_dim
[
1
]));
// batch_size
PADDLE_ENFORCE_EQ
(
c_dim
[
2
],
trans_qkvw
?
y_dim
[
1
]
:
y_dim
[
2
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d"
,
trans_qkvw
?
y_dim
[
1
]
:
y_dim
[
2
],
c_dim
[
2
]));
// num_head
PADDLE_ENFORCE_GT
(
c_dim
[
3
],
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The forth dim of CacheKV must be greater than 0, but got %d"
,
c_dim
[
3
]));
// cache_seq_len
PADDLE_ENFORCE_EQ
(
c_dim
[
4
],
trans_qkvw
?
y_dim
[
2
]
:
y_dim
[
3
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d"
,
trans_qkvw
?
y_dim
[
2
]
:
y_dim
[
3
],
c_dim
[
4
]));
// head_size
}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
if
(
var_name
==
"TimeStep"
)
{
VLOG
(
10
)
<<
"var_name:"
<<
var_name
<<
" need not to transform"
;
return
expected_kernel_type
;
}
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
class
FusedMultiTransformerINT8OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor."
);
AddInput
(
"LnScale"
,
"Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor."
)
.
AsDuplicable
();
AddInput
(
"LnBias"
,
"Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor."
)
.
AsDuplicable
();
AddInput
(
"QKVW"
,
"The qkv weight tensor."
).
AsDuplicable
();
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
).
AsDispensable
().
AsDuplicable
();
AddInput
(
"CacheKV"
,
"(optional) The cached KV for generation inference."
)
.
AsDispensable
()
.
AsDuplicable
();
AddInput
(
"TimeStep"
,
"(optional, int) The time step for generation inference."
)
.
AsDispensable
();
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
.
AsDispensable
();
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
).
AsDuplicable
();
AddInput
(
"OutLinearBias"
,
"The out_linear bias tensor."
)
.
AsDispensable
()
.
AsDuplicable
();
AddInput
(
"FFNLnScale"
,
"The layer_norm scale of FusedFeedForward op"
)
.
AsDuplicable
();
AddInput
(
"FFNLnBias"
,
"The layer_norm bias of FusedFeedForward op"
)
.
AsDuplicable
();
AddInput
(
"FFN1Weight"
,
"The linear1 weight of FusedFeedForward op"
)
.
AsDuplicable
();
AddInput
(
"FFN1Bias"
,
"The linear1 bias of FusedFeedForward op"
)
.
AsDispensable
()
.
AsDuplicable
();
AddInput
(
"FFN2Weight"
,
"The linear2 weight of FusedFeedForward op"
)
.
AsDuplicable
();
AddInput
(
"FFN2Bias"
,
"The linear2 bias input of FusedFeedForward op"
)
.
AsDispensable
()
.
AsDuplicable
();
AddInput
(
"QKVOutScale"
,
"QKVOutScale is used to dequantize qkv output tensor."
"In order to keep consistent with the PTQ/QAT calculation logic,"
"QKVOutScale should be max_bound * max_bound / max_range."
"Here max_range is per-channel weight scale."
"The shape of QKVOutScale is [num_layers, num_channels]"
)
.
AsDispensable
();
AddInput
(
"OutLinearOutScale"
,
"OutLinearOutScale is used to dequantize out_linear output tensor."
"The definition and shape is the same as QKVOutScale"
)
.
AsDispensable
();
AddInput
(
"FFN1OutScale"
,
"FFN1OutScale is used to dequantize ffn1 output tensor."
"The definition and shape is the same as QKVOutScale"
)
.
AsDispensable
();
AddInput
(
"FFN2OutScale"
,
"FFN2OutScale is used to dequantize ffn2 output tensor."
"The definition and shape is the same as QKVOutScale"
)
.
AsDispensable
();
AddOutput
(
"CacheKVOut"
,
"The updated cache KV. Inplace with CacheKV"
)
.
AsDispensable
()
.
AsDuplicable
();
AddOutput
(
"Out"
,
"Result after multi ."
);
AddAttr
<
bool
>
(
"pre_layer_norm"
,
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
"[default true]."
)
.
SetDefault
(
true
);
AddAttr
<
float
>
(
"epsilon"
,
"Constant for numerical stability [default 1e-5]."
)
.
SetDefault
(
1e-5
)
.
AddCustomChecker
([](
const
float
&
epsilon
)
{
PADDLE_ENFORCE_EQ
(
epsilon
>=
0.0
f
&&
epsilon
<=
0.001
f
,
true
,
platform
::
errors
::
InvalidArgument
(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s]."
,
epsilon
));
});
AddAttr
<
float
>
(
"dropout_rate"
,
"Probability of setting units to zero."
)
.
SetDefault
(
.5
f
)
.
AddCustomChecker
([](
const
float
&
drop_p
)
{
PADDLE_ENFORCE_EQ
(
drop_p
>=
0.0
f
&&
drop_p
<=
1.0
f
,
true
,
platform
::
errors
::
InvalidArgument
(
"'dropout_rate' must be between 0.0 and 1.0."
));
});
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"dropout_implementation"
,
"[
\"
downgrade_in_infer
\"
|
\"
upscale_in_train
\"
]"
"The meaning is the same as 'attn_dropout_implementation'."
)
.
SetDefault
(
"downgrade_in_infer"
)
.
AddCustomChecker
([](
const
std
::
string
&
type
)
{
PADDLE_ENFORCE_EQ
(
type
==
"downgrade_in_infer"
||
type
==
"upscale_in_train"
,
true
,
platform
::
errors
::
InvalidArgument
(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"
));
});
AddAttr
<
std
::
string
>
(
"act_method"
,
"act_method"
).
SetDefault
(
"gelu"
);
AddAttr
<
bool
>
(
"trans_qkvw"
,
"Whether the weights of qkv should be transposed. If true,"
"the shape eights of qkv should be [3, num_head, dim_head, dim_embed]."
"Otherwise the shape of weights of qkv should be"
"[dim_embed, 3, num_head, dim_head]"
)
.
SetDefault
(
true
);
AddAttr
<
int
>
(
"ring_id"
,
"ring id for tensor model parallel. distributed training and inference"
)
.
SetDefault
(
-
1
);
AddAttr
<
int
>
(
"num_head"
,
"num_head"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"dim_head"
,
"dim_head"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"dim_ffn"
,
"dim_ffn"
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
float
>>
(
"qkv_in_scale"
,
"qkv_in_scale is used to quantize qkv input tensor."
"in_scale is generated by PTQ or QAT, which represents valid max range "
"of this tensor."
"the size of qkv_in_scale should be num_layers, which is equal to "
"QKVW.dims()[0]"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
float
>>
(
"out_linear_in_scale"
,
"out_linear_in_scale is used to quantize out_linear input tensor."
"the size of out_linear_in_scale is the same as qkv_in_scale"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
float
>>
(
"ffn1_in_scale"
,
"ffn1_in_scale is used to quantize ffn1 input tensor."
"the size of ffn1_in_scale is the same as qkv_in_scale"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
float
>>
(
"ffn2_in_scale"
,
"ffn2_in_scale is used to quantize ffn2 input tensor."
"the size of ffn2_in_scale is the same as qkv_in_scale"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"quant_round_type"
,
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(-2.5)=-3"
)
.
SetDefault
(
1
);
AddAttr
<
float
>
(
"quant_max_bound"
,
"(float, default 127.0) the max bound of float type to int type"
)
.
SetDefault
(
127.0
);
AddAttr
<
float
>
(
"quant_min_bound"
,
"(float, default -127.0) the min bound of float type to int type"
)
.
SetDefault
(
-
127.0
);
AddComment
(
R"DOC(fused multi transformer layers op)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fused_multi_transformer_int8
,
ops
::
FusedMultiTransformerINT8Op
,
ops
::
FusedMultiTransformerINT8OpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm_int8.h"
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
FusedMultiTransformerINT8OpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
U
=
LayerNormParamType
<
T
>
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
auto
*
time_step
=
ctx
.
Input
<
Tensor
>
(
"TimeStep"
);
// 0. input
auto
*
input_x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
input_x_dims
=
input_x
->
dims
();
int
bsz
=
input_x_dims
[
0
];
int
seq_len
=
input_x_dims
[
1
];
int
dim_embed
=
input_x_dims
[
2
];
int
bsz_seq
=
bsz
*
seq_len
;
// quant input scales, vector, size = num_layers
auto
qkv_in_scale
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"qkv_in_scale"
);
auto
out_linear_in_scale
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"out_linear_in_scale"
);
auto
ffn1_in_scale
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"ffn1_in_scale"
);
auto
ffn2_in_scale
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"ffn2_in_scale"
);
// quant round type and bound
auto
quant_round_type
=
ctx
.
Attr
<
int
>
(
"quant_round_type"
);
auto
quant_max_bound
=
ctx
.
Attr
<
float
>
(
"quant_max_bound"
);
auto
quant_min_bound
=
ctx
.
Attr
<
float
>
(
"quant_min_bound"
);
// dequant output scales, tensor, size = [num_layers, n], n is gemm output
// size
auto
*
qkv_out_scale
=
ctx
.
Input
<
Tensor
>
(
"QKVOutScale"
);
auto
*
out_linear_out_scale
=
ctx
.
Input
<
Tensor
>
(
"OutLinearOutScale"
);
auto
*
ffn1_out_scale
=
ctx
.
Input
<
Tensor
>
(
"FFN1OutScale"
);
auto
*
ffn2_out_scale
=
ctx
.
Input
<
Tensor
>
(
"FFN2OutScale"
);
int
qkv_out_scale_n
=
qkv_out_scale
->
dims
()[
1
];
int
out_linear_out_scale_n
=
out_linear_out_scale
->
dims
()[
1
];
int
ffn1_out_scale_n
=
ffn1_out_scale
->
dims
()[
1
];
int
ffn2_out_scale_n
=
ffn2_out_scale
->
dims
()[
1
];
// 1. layer norm
const
auto
pre_layer_norm
=
ctx
.
Attr
<
bool
>
(
"pre_layer_norm"
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"LnScale"
);
auto
ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"LnBias"
);
auto
ln_compute
=
AttnLayerNorm
<
T
,
T
,
int8_t
>
(
dev_ctx
,
epsilon
,
bsz_seq
,
dim_embed
);
Tensor
ln_mean
,
ln_var
;
ln_mean
.
Resize
({{
bsz_seq
}});
auto
*
ln_mean_data
=
dev_ctx
.
Alloc
<
U
>
(
&
ln_mean
,
ln_mean
.
numel
()
*
sizeof
(
U
));
ln_var
.
Resize
({{
bsz_seq
}});
auto
*
ln_var_data
=
dev_ctx
.
Alloc
<
U
>
(
&
ln_var
,
ln_var
.
numel
()
*
sizeof
(
U
));
// 2. qkv
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto
qkv_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"QKVW"
);
auto
qkv_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"QKVBias"
);
const
bool
trans_qkvw
=
ctx
.
Attr
<
bool
>
(
"trans_qkvw"
);
const
auto
qkv_w_dims
=
qkv_weights
[
0
]
->
dims
();
int
num_head
=
trans_qkvw
?
qkv_w_dims
[
1
]
:
qkv_w_dims
[
2
];
int
dim_head
=
trans_qkvw
?
qkv_w_dims
[
2
]
:
qkv_w_dims
[
3
];
int
hidden_size
=
num_head
*
dim_head
;
int
output_size
=
3
*
hidden_size
;
int
input_size
=
dim_embed
;
bool
compute_bias
=
qkv_biases
.
size
()
>
0
&&
time_step
==
nullptr
;
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
AttnMatmulINT8
<
T
>
qkv_compute
(
dev_ctx
,
bsz_seq
,
output_size
,
input_size
,
compute_bias
);
Tensor
qkv_out
;
qkv_out
.
Resize
({{
bsz
,
seq_len
,
3
,
num_head
,
dim_head
}});
auto
*
qkv_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
qkv_out
,
qkv_out
.
numel
()
*
sizeof
(
T
));
// 3. fmha
AttnDropoutParam
attn_param
(
true
,
"upscale_in_train"
,
0.0
,
true
,
true
,
0
,
nullptr
);
auto
fmha_compute
=
FMHARef
<
T
>
(
dev_ctx
,
bsz
,
seq_len
,
num_head
,
dim_head
,
attn_param
);
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
auto
cache_kvs
=
ctx
.
MultiInput
<
Tensor
>
(
"CacheKV"
);
auto
cache_kv_outs
=
ctx
.
MultiOutput
<
Tensor
>
(
"CacheKVOut"
);
// auto *time_step = ctx.Input<Tensor>("TimeStep");
auto
out_seq_len
=
seq_len
;
if
(
time_step
)
{
PADDLE_ENFORCE_EQ
(
time_step
->
place
(),
platform
::
CPUPlace
(),
platform
::
errors
::
PreconditionNotMet
(
"The place of input(TimeStep) must be CPUPlace."
));
// cache_seq_len
int
time_step_value
=
time_step
->
data
<
int
>
()[
0
];
PADDLE_ENFORCE_GT
(
time_step_value
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"The value of time_step must > 0, but now is %d"
,
time_step_value
));
PADDLE_ENFORCE_EQ
(
seq_len
,
1
,
platform
::
errors
::
PreconditionNotMet
(
"In decode stage, the seq_len of input must be 1, but now is %d"
,
seq_len
));
out_seq_len
+=
time_step_value
;
}
Tensor
transpose_out_2
,
qk_out
;
transpose_out_2
.
Resize
({{
3
,
bsz
,
num_head
,
seq_len
,
dim_head
}});
auto
*
transpose_out_2_data
=
dev_ctx
.
Alloc
<
T
>
(
&
transpose_out_2
,
transpose_out_2
.
numel
()
*
sizeof
(
T
));
qk_out
.
Resize
({{
bsz
,
num_head
,
seq_len
,
out_seq_len
}});
auto
*
qk_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
qk_out
,
qk_out
.
numel
()
*
sizeof
(
T
));
Tensor
softmax_out
;
Tensor
attn_dropout_mask_out
,
attn_dropout_out
;
Tensor
qktv_out
,
fmha_out
;
softmax_out
.
Resize
({{
bsz
,
num_head
,
seq_len
,
out_seq_len
}});
auto
*
softmax_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
softmax_out
,
softmax_out
.
numel
()
*
sizeof
(
T
));
attn_dropout_mask_out
.
Resize
({{
bsz
,
num_head
,
seq_len
,
out_seq_len
}});
auto
*
attn_dropout_mask_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
attn_dropout_mask_out
,
attn_dropout_mask_out
.
numel
()
*
sizeof
(
T
));
attn_dropout_out
.
Resize
({{
bsz
,
num_head
,
seq_len
,
out_seq_len
}});
auto
*
attn_dropout_data_data
=
dev_ctx
.
Alloc
<
T
>
(
&
attn_dropout_out
,
attn_dropout_out
.
numel
()
*
sizeof
(
T
));
qktv_out
.
Resize
({{
bsz
,
num_head
,
seq_len
,
dim_head
}});
auto
*
qktv_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
qktv_out
,
qktv_out
.
numel
()
*
sizeof
(
T
));
fmha_out
.
Resize
({{
bsz
,
seq_len
,
num_head
,
dim_head
}});
auto
*
fmha_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
fmha_out
,
fmha_out
.
numel
()
*
sizeof
(
T
));
// 4. out_linear
auto
out_linear_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"OutLinearW"
);
auto
out_linear_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"OutLinearBias"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// (transA, transB, compute_bias) = (false, false, false)
AttnMatmulINT8
<
T
>
out_linear_compute
(
dev_ctx
,
bsz_seq
,
dim_embed
,
hidden_size
,
false
);
// 5. ln(residual + bias)
DropoutParam
dropout_param2
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
,
int32_t
,
int8_t
>
fused_dropout_layernorm_helper
(
dev_ctx
,
bsz_seq
,
dim_embed
,
dropout_param2
,
epsilon
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper_for_post_layernorm
(
dev_ctx
,
bsz_seq
,
dim_embed
,
dropout_param2
,
epsilon
);
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
T
*
bias_dropout_residual_out_data
=
nullptr
;
if
(
pre_layer_norm
)
{
bias_dropout_residual_out
.
Resize
({{
bsz
,
seq_len
,
dim_embed
}});
bias_dropout_residual_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
bias_dropout_residual_out
,
bias_dropout_residual_out
.
numel
()
*
sizeof
(
T
));
}
dropout_mask_out
.
Resize
({{
bsz
,
seq_len
,
dim_embed
}});
auto
*
dropout_mask_out_data
=
dev_ctx
.
Alloc
<
uint8_t
>
(
&
dropout_mask_out
,
dropout_mask_out
.
numel
()
*
sizeof
(
uint8_t
));
// 6. ffn matmul1
auto
ffn1_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN1Weight"
);
auto
ffn1_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN1Bias"
);
auto
ffn1_weight_dim
=
ffn1_weights
[
0
]
->
dims
();
int
dim_ffn
=
ffn1_weight_dim
[
0
];
AttnMatmulINT8
<
T
>
ffn1_linear_compute
(
dev_ctx
,
bsz_seq
,
dim_ffn
,
dim_embed
,
false
);
Tensor
ffn1_out
;
ffn1_out
.
Resize
({{
bsz_seq
,
dim_ffn
}});
auto
*
ffn1_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
ffn1_out
,
ffn1_out
.
numel
()
*
sizeof
(
T
));
// 7. ffn act + bias
DropoutParam
ffn1_dropout_param
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
FusedDropoutHelper
<
T
,
uint8_t
,
int32_t
,
int8_t
>
fused_act_dropout_helper
(
dev_ctx
,
bsz_seq
,
dim_ffn
,
ffn1_dropout_param
);
FusedDropoutHelper
<
T
,
uint8_t
>
fused_act_dropout_helper_for_post_layernorm
(
dev_ctx
,
bsz_seq
,
dim_ffn
,
ffn1_dropout_param
);
Tensor
ffn1_dropout_out
,
ffn1_dropout_mask
;
ffn1_dropout_out
.
Resize
({{
bsz_seq
,
dim_ffn
}});
auto
*
ffn1_dropout_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
ffn1_dropout_out
,
ffn1_dropout_out
.
numel
()
*
sizeof
(
T
));
ffn1_dropout_mask
.
Resize
({{
bsz_seq
,
dim_ffn
}});
auto
*
ffn1_dropout_mask_data
=
dev_ctx
.
Alloc
<
uint8_t
>
(
&
ffn1_dropout_mask
,
ffn1_dropout_mask
.
numel
()
*
sizeof
(
uint8_t
));
// 8. ffn2 matmul
auto
ffn2_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Weight"
);
auto
ffn2_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Bias"
);
AttnMatmulINT8
<
T
>
ffn2_linear_compute
(
dev_ctx
,
bsz_seq
,
dim_embed
,
dim_ffn
,
false
);
// 9. ffn2 residual bias
DropoutParam
ffn2_dropout_param
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
,
int32_t
,
int8_t
>
ffn2_fused_dropout_helper
(
dev_ctx
,
bsz_seq
,
dim_embed
,
ffn2_dropout_param
,
epsilon
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
,
int32_t
,
T
>
ffn2_fused_dropout_dequant_helper
(
dev_ctx
,
bsz_seq
,
dim_embed
,
ffn2_dropout_param
,
epsilon
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
ffn2_fused_dropout_helper_for_post_layernorm
(
dev_ctx
,
bsz_seq
,
dim_embed
,
ffn2_dropout_param
,
epsilon
);
// []. init workspace for cublasLt transform
Tensor
input_workspace
,
output_workspace
;
// for input and output transform data is CUBLASLT_ORDER_COL32 format,
int
m_max
=
bsz_seq
,
k_max
=
std
::
max
(
dim_embed
,
dim_ffn
),
n_max
=
std
::
max
({
output_size
,
dim_embed
,
dim_ffn
});
input_workspace
.
Resize
(
{{
32
*
((
m_max
+
32
-
1
)
/
32
),
(
k_max
+
31
)
/
32
*
32
}});
dev_ctx
.
Alloc
<
int8_t
>
(
&
input_workspace
,
input_workspace
.
numel
()
*
sizeof
(
int8_t
));
output_workspace
.
Resize
({{
n_max
*
4
,
(
m_max
+
31
)
/
32
*
32
*
4
}});
dev_ctx
.
Alloc
<
int32_t
>
(
&
output_workspace
,
output_workspace
.
numel
()
*
sizeof
(
int32_t
));
// calc
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
from_data
=
dev_ctx
.
Alloc
<
T
>
(
out
,
out
->
numel
()
*
sizeof
(
T
));
Tensor
*
from_tensor
=
out
;
Tensor
tmp_out
;
tmp_out
.
Resize
({{
bsz
,
seq_len
,
dim_embed
}});
auto
*
tmp_out_data
=
dev_ctx
.
Alloc
<
T
>
(
&
tmp_out
,
tmp_out
.
numel
()
*
sizeof
(
T
));
auto
*
x_data
=
input_x
->
data
<
T
>
();
Tensor
*
buf0
=
nullptr
;
Tensor
*
buf1
=
nullptr
;
// step0: x --> buf1
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int
layers
=
qkv_weights
.
size
();
if
(
pre_layer_norm
)
{
buf1
=
out
;
}
else
{
buf0
=
&
tmp_out
;
buf1
=
out
;
}
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
// step1. layer_norm
if
(
i
==
0
&&
pre_layer_norm
)
{
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
// TODO(wangxi): can remove mean var in inference
ln_compute
.
ComputeForward
(
x_data
,
ln_scale_data
,
ln_bias_data
,
input_workspace
.
data
<
int8_t
>
(),
ln_mean_data
,
ln_var_data
,
nullptr
,
0
,
qkv_in_scale
[
i
],
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step1"
;
#endif
// step2. qkv
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
// NOTE: in decoder stage, bias is fused in fmha
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
if
(
!
pre_layer_norm
&&
i
==
0
)
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
input_x
,
&
input_workspace
,
bias
,
&
qkv_out
,
&
output_workspace
,
&
qkv_out
,
qkv_in_scale
[
i
],
qkv_out_scale
,
i
*
qkv_out_scale_n
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
if
(
!
pre_layer_norm
)
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
buf1
,
&
input_workspace
,
bias
,
&
qkv_out
,
&
output_workspace
,
&
qkv_out
,
qkv_in_scale
[
i
],
qkv_out_scale
,
i
*
qkv_out_scale_n
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
qkv_compute
.
ComputeForwardINT8ToT
(
qkv_weights
[
i
],
qkv_in_scale
[
i
],
&
input_workspace
,
bias
,
&
qkv_out
,
&
output_workspace
,
&
qkv_out
,
qkv_out_scale
,
i
*
qkv_out_scale_n
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step2"
;
#endif
// step3. fmha
const
Tensor
*
cache_kv
=
cache_kvs
.
size
()
>
0
?
cache_kvs
[
i
]
:
nullptr
;
Tensor
*
cache_kv_out
=
cache_kv
?
cache_kv_outs
[
i
]
:
nullptr
;
if
(
time_step
)
{
// generation decoder stage
// [2, batch_size, num_head, max_seq_len, head_size]
int
max_seq_len
=
cache_kv
->
dims
()[
3
];
fmha
<
T
>
(
dev_ctx
,
qkv_out
,
*
qkv_bias
,
*
src_mask
,
cache_kv_out
,
&
fmha_out
,
bsz
,
max_seq_len
,
num_head
,
dim_head
,
time_step
->
data
<
int
>
()[
0
],
1.
/
sqrt
(
dim_head
));
}
else
if
(
cache_kv_out
)
{
// generation context stage
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
qkv_out
,
nullptr
,
src_mask
,
&
transpose_out_2
,
nullptr
,
&
qk_out
,
nullptr
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
// [3, bsz, num_head, seq_len, head_dim]
T
*
qkv_data
=
transpose_out_2_data
;
int64_t
q_size
=
bsz
*
seq_len
*
num_head
*
dim_head
;
int64_t
k_size
=
q_size
;
const
T
*
q_ptr
=
qkv_data
;
const
T
*
k_ptr
=
q_ptr
+
q_size
;
const
T
*
v_ptr
=
k_ptr
+
k_size
;
// [2, bsz, num_head, max_seq_len, head_dim]
int
max_seq_len
=
cache_kv_out
->
dims
()[
3
];
T
*
cache_kv_data
=
cache_kv_out
->
data
<
T
>
();
int64_t
cache_k_size
=
bsz
*
num_head
*
max_seq_len
*
dim_head
;
T
*
cache_k_ptr
=
cache_kv_data
;
T
*
cache_v_ptr
=
cache_kv_data
+
cache_k_size
;
write_cache_kv
<
T
>
(
dev_ctx
,
cache_k_ptr
,
cache_v_ptr
,
k_ptr
,
v_ptr
,
bsz
,
num_head
,
seq_len
,
max_seq_len
,
dim_head
);
}
else
{
// not generation
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
qkv_out
,
cache_kv
,
src_mask
,
&
transpose_out_2
,
cache_kv_out
,
&
qk_out
,
nullptr
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step3"
;
#endif
if
(
pre_layer_norm
)
{
out_linear_compute
.
ComputeForwardTToINT8
(
out_linear_weights
[
i
],
out_linear_in_scale
[
i
],
&
fmha_out
,
&
input_workspace
,
nullptr
,
&
output_workspace
,
nullptr
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
AllReduce
<
int32_t
>
(
output_workspace
,
ring_id
,
bsz
*
seq_len
*
num_head
*
dim_head
,
dev_ctx
);
}
else
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
&
input_workspace
,
nullptr
,
buf0
,
&
output_workspace
,
nullptr
,
out_linear_in_scale
[
i
],
out_linear_out_scale
,
i
*
out_linear_out_scale_n
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
buf0
->
numel
(),
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step4"
;
#endif
// step5. ln(residual + dropout(input + bias))
if
(
pre_layer_norm
)
{
auto
*
ln_scale_data
=
ffn_ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ffn_ln_biases
[
i
]
->
data
<
U
>
();
auto
*
out_linear_bias_data
=
out_linear_biases
[
i
]
->
data
<
T
>
();
// inplace
// non-inplace: buf1 -> input_workspace
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
output_workspace
.
data
<
int32_t
>
(),
x_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
bias_dropout_residual_out_data
,
dropout_mask_out_data
,
input_workspace
.
data
<
int8_t
>
(),
ln_mean_data
,
ln_var_data
,
out_linear_in_scale
[
i
],
out_linear_out_scale
->
data
<
float
>
(),
i
*
out_linear_out_scale_n
,
ffn1_in_scale
[
i
],
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
auto
*
out_linear_bias_data
=
out_linear_biases
[
i
]
->
data
<
T
>
();
auto
*
residual_data
=
(
i
==
0
?
x_data
:
buf1
->
data
<
T
>
());
fused_dropout_layernorm_helper_for_post_layernorm
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
residual_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step5"
;
#endif
// step6. ffn matmul1
if
(
pre_layer_norm
)
{
ffn1_linear_compute
.
ComputeForwardINT8ToINT8
(
ffn1_weights
[
i
],
&
input_workspace
,
nullptr
,
&
output_workspace
,
nullptr
);
}
else
{
ffn1_linear_compute
.
ComputeForward
(
ffn1_weights
[
i
],
buf1
,
&
input_workspace
,
nullptr
,
&
ffn1_out
,
&
output_workspace
,
nullptr
,
ffn1_in_scale
[
i
],
ffn1_out_scale
,
i
*
ffn1_out_scale_n
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step6"
;
#endif
// step7. act bias
// TODO(wangxi): remove dropout mask in inference
if
(
pre_layer_norm
)
{
fused_act_dropout_helper
.
DropoutActBias
(
dev_ctx
,
output_workspace
.
data
<
int32_t
>
(),
ffn1_biases
[
i
]
->
data
<
T
>
(),
"gelu"
,
input_workspace
.
data
<
int8_t
>
(),
ffn1_dropout_mask_data
,
ffn1_in_scale
[
i
],
ffn1_out_scale
->
data
<
float
>
(),
i
*
ffn1_out_scale_n
,
ffn2_in_scale
[
i
],
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
fused_act_dropout_helper_for_post_layernorm
.
DropoutActBias
(
dev_ctx
,
ffn1_out_data
,
ffn1_biases
[
i
]
->
data
<
T
>
(),
"gelu"
,
ffn1_dropout_out_data
,
ffn1_dropout_mask_data
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step7"
;
#endif
// step8. ffn matmul2
if
(
pre_layer_norm
)
{
ffn2_linear_compute
.
ComputeForwardINT8ToINT8
(
ffn2_weights
[
i
],
&
input_workspace
,
nullptr
,
&
output_workspace
,
nullptr
);
}
else
{
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
&
input_workspace
,
nullptr
,
buf0
,
&
output_workspace
,
nullptr
,
ffn2_in_scale
[
i
],
ffn2_out_scale
,
i
*
ffn2_out_scale_n
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.0"
;
#endif
if
(
pre_layer_norm
)
{
AllReduce
<
int32_t
>
(
output_workspace
,
ring_id
,
bsz
*
seq_len
*
num_head
*
dim_head
,
dev_ctx
);
}
else
{
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
buf0
->
numel
(),
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.1"
;
#endif
// step9. residual bias
if
(
pre_layer_norm
)
{
// TODO(wangxi): remove dropout mask in inference
if
(
i
<
layers
-
1
)
{
auto
*
ln_scale_data
=
ln_scales
[
i
+
1
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
+
1
]
->
data
<
U
>
();
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
output_workspace
.
data
<
int32_t
>
(),
bias_dropout_residual_out_data
,
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf1
->
data
<
T
>
(),
dropout_mask_out_data
,
input_workspace
.
data
<
int8_t
>
(),
ln_mean_data
,
ln_var_data
,
ffn2_in_scale
[
i
],
ffn2_out_scale
->
data
<
float
>
(),
i
*
ffn2_out_scale_n
,
qkv_in_scale
[
i
+
1
],
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
ffn2_fused_dropout_dequant_helper
.
ResidualDropoutBias
(
dev_ctx
,
output_workspace
.
data
<
int32_t
>
(),
bias_dropout_residual_out_data
,
ffn2_biases
[
i
]
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
dropout_mask_out_data
,
ffn2_in_scale
[
i
],
ffn2_out_scale
->
data
<
float
>
(),
i
*
ffn2_out_scale_n
,
1.0
);
}
}
else
{
auto
*
ln_scale_data
=
ffn_ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ffn_ln_biases
[
i
]
->
data
<
U
>
();
ffn2_fused_dropout_helper_for_post_layernorm
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step9"
;
#endif
if
(
pre_layer_norm
)
{
x_data
=
buf1
->
data
<
T
>
();
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
fused_multi_transformer_int8
,
ops
::
FusedMultiTransformerINT8OpKernel
<
plat
::
float16
>
,
ops
::
FusedMultiTransformerINT8OpKernel
<
float
>
);
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
// for debug
// #define _DEBUG_FUSED_MULTI_TRANSFORMER
template
<
typename
T
>
static
void
AllReduce
(
framework
::
Tensor
&
tensor
,
// NOLINT
const
int
ring_id
,
const
phi
::
GPUContext
&
ctx
)
{
if
(
ring_id
==
-
1
)
return
;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
map
=
paddle
::
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg
->
AllReduce
(
in_tensor
,
out_tensor
,
opts
);
task
->
Wait
();
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
framework
::
TransToProtoVarType
(
tensor
.
dtype
()));
int64_t
numel
=
tensor
.
numel
();
const
void
*
sendbuff
=
tensor
.
data
<
T
>
();
auto
place
=
ctx
.
GetPlace
();
void
*
recvbuff
=
ctx
.
Alloc
<
T
>
(
&
tensor
,
tensor
.
numel
()
*
sizeof
(
T
));
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
auto
stream
=
ctx
.
stream
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
dtype
,
ncclSum
,
comm
->
comm
(),
stream
));
}
#else
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."
));
#endif
}
namespace
{
namespace
plat
=
paddle
::
platform
;
using
float16
=
plat
::
float16
;
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
Masked_multihead_attention_params
{
// output buffer, [B, 1(seq_len), num_head * dim_head]
T
*
out
;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const
T
*
qkv
;
// bias, [3, num_head, dim_head]
const
T
*
qkv_bias
;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const
T
*
attn_mask
;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T
*
cache_kv
;
int
batch_size
;
int
num_head
;
int
timestep
;
// cache_seq_length
int
max_seq_length
;
// 1.f / sqrt(Dh)
float
inv_sqrt_dh
;
};
struct
Float8_
{
float2
x
;
float2
y
;
float2
z
;
float2
w
;
};
// clang-format off
template
<
typename
T
,
int
Dh
>
struct
Qk_vec_
{};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
256
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float16
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
float16
,
256
>
{
using
Type
=
uint4
;
};
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_
<
float
,
1
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_
<
float16
,
4
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
K_vec_
<
float16
,
2
>
{
using
Type
=
uint2
;
};
template
<
>
struct
K_vec_
<
float16
,
1
>
{
using
Type
=
uint4
;
};
template
<
typename
T
,
int
V_VEC_SIZE
>
struct
V_vec_
{};
template
<
>
struct
V_vec_
<
float
,
1
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_
<
float
,
4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_
<
float16
,
2
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
V_vec_
<
float16
,
4
>
{
using
Type
=
uint2
;
};
template
<
>
struct
V_vec_
<
float16
,
8
>
{
using
Type
=
uint4
;
};
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
V_vec_acum_fp32_
{};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template
<
>
struct
V_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template
<
>
struct
V_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
#endif
// clang-format on
inline
__device__
float
half_to_float
(
uint16_t
h
)
{
float
f
;
asm
volatile
(
"cvt.f32.f16 %0, %1;
\n
"
:
"=f"
(
f
)
:
"h"
(
h
));
return
f
;
}
inline
__device__
float2
half2_to_float2
(
uint32_t
v
)
{
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;
\n
"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
v
));
return
make_float2
(
half_to_float
(
lo
),
half_to_float
(
hi
));
}
inline
__device__
uint32_t
float2_to_half2
(
float2
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"cvt.rn.f16x2.f32 %0, %1, %2;
\n
"
:
"=r"
(
tmp
.
u32
)
:
"f"
(
f
.
y
),
"f"
(
f
.
x
));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
.
x
));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
1
])
:
"f"
(
f
.
y
));
#endif
return
tmp
.
u32
;
}
inline
__device__
float
add
(
float
a
,
float
b
)
{
return
a
+
b
;
}
inline
__device__
float2
add
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
inline
__device__
float4
add
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
inline
__device__
uint16_t
add
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"add.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
inline
__device__
uint32_t
add
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"add.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
inline
__device__
uint2
add
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
inline
__device__
uint4
add
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
inline
__device__
float2
add
(
uint32_t
a
,
float2
fb
)
{
float2
fa
=
half2_to_float2
(
a
);
return
add
(
fa
,
fb
);
}
inline
__device__
Float8_
add
(
uint4
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
template
<
typename
Acc
,
typename
A
,
typename
B
>
inline
__device__
Acc
mul
(
A
a
,
B
b
);
template
<
>
inline
__device__
float
mul
<
float
,
float
>
(
float
a
,
float
b
)
{
return
a
*
b
;
}
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
return
c
;
}
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
c
.
z
=
a
.
z
*
b
.
z
;
c
.
w
=
a
.
w
*
b
.
w
;
return
c
;
}
template
<
>
inline
__device__
uint16_t
mul
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"mul.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"mul.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
c
;
}
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
c
;
}
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
float
b
)
{
float2
tmp
=
half2_to_float2
(
a
);
float2
tmp_res
;
tmp_res
.
x
=
tmp
.
x
*
b
;
tmp_res
.
y
=
tmp
.
y
*
b
;
uint32_t
res
=
float2_to_half2
(
tmp_res
);
return
res
;
}
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
float
b
)
{
uint2
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
return
res
;
}
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
float
b
)
{
uint4
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
res
.
z
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
z
,
b
);
res
.
w
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
w
,
b
);
return
res
;
}
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float
b
)
{
float2
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
return
res
;
}
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float
b
)
{
float4
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
res
.
z
=
a
.
z
*
b
;
res
.
w
=
a
.
w
*
b
;
return
res
;
}
inline
__device__
float
sum
(
float
v
)
{
return
v
;
}
inline
__device__
float
sum
(
float2
v
)
{
return
v
.
x
+
v
.
y
;
}
inline
__device__
float
sum
(
float4
v
)
{
return
v
.
x
+
v
.
y
+
v
.
z
+
v
.
w
;
}
inline
__device__
float
sum
(
uint16_t
v
)
{
return
half_to_float
(
v
);
}
inline
__device__
float
sum
(
uint32_t
v
)
{
float2
tmp
=
half2_to_float2
(
v
);
return
tmp
.
x
+
tmp
.
y
;
}
inline
__device__
float
sum
(
uint2
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
return
sum
(
c
);
}
inline
__device__
float
sum
(
uint4
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
c
=
add
(
c
,
v
.
z
);
c
=
add
(
c
,
v
.
w
);
return
sum
(
c
);
}
template
<
typename
T
>
inline
__device__
float
dot
(
T
a
,
T
b
)
{
return
sum
(
mul
<
T
,
T
,
T
>
(
a
,
b
));
}
template
<
typename
A
,
typename
T
>
inline
__device__
float
dot
(
T
a
,
T
b
)
{
return
sum
(
mul
<
A
,
T
,
T
>
(
a
,
b
));
}
inline
__device__
constexpr
uint32_t
shfl_mask
(
int
threads
)
{
return
threads
==
32
?
uint32_t
(
-
1
)
:
(
1u
<<
threads
)
-
1u
;
}
template
<
typename
T
>
inline
__device__
__host__
T
div_up
(
T
m
,
T
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
inline
__device__
float
fma
(
float
a
,
float
b
,
float
c
)
{
return
a
*
b
+
c
;
}
inline
__device__
float2
fma
(
float2
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
float4
fma
(
float4
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
uint32_t
fma
(
uint32_t
a
,
uint32_t
b
,
uint32_t
c
)
{
uint32_t
d
;
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
d
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
));
return
d
;
}
inline
__device__
uint2
fma
(
uint2
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
uint4
fma
(
uint4
a
,
uint4
b
,
uint4
c
)
{
uint4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
float2
fma
(
float
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
float4
fma
(
float
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
Float8_
fma
(
float
a
,
Float8_
b
,
Float8_
c
)
{
Float8_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
uint32_t
h0_h0
(
uint16_t
a
)
{
uint32_t
b
;
asm
volatile
(
"mov.b32 %0, {%1, %1};"
:
"=r"
(
b
)
:
"h"
(
a
));
return
b
;
}
inline
__device__
uint32_t
fma
(
uint16_t
a
,
uint32_t
b
,
uint32_t
c
)
{
return
fma
(
h0_h0
(
a
),
b
,
c
);
}
inline
__device__
uint2
fma
(
uint16_t
a
,
uint2
b
,
uint2
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
uint4
fma
(
uint16_t
a
,
uint4
b
,
uint4
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
float
cast_to_float
(
float
u
)
{
return
u
;
}
inline
__device__
float2
cast_to_float
(
float2
u
)
{
return
u
;
}
inline
__device__
float4
cast_to_float
(
float4
u
)
{
return
u
;
}
inline
__device__
Float8_
cast_to_float
(
uint4
u
)
{
Float8_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
tmp
.
z
=
half2_to_float2
(
u
.
z
);
tmp
.
w
=
half2_to_float2
(
u
.
w
);
return
tmp
;
}
template
<
int
THREADS_PER_KEY
,
typename
K_vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
K_vec
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
0
],
inv_sqrt_dh
);
K_vec
qk_vec
=
mul
<
K_vec
,
K_vec
,
K_vec
>
(
inv_q
,
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
ii
],
inv_sqrt_dh
);
qk_vec
=
fma
(
inv_q
,
k
[
ii
],
qk_vec
);
}
float
qk
=
sum
(
qk_vec
);
#pragma unroll
for
(
int
mask
=
THREADS_PER_KEY
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk
,
mask
);
}
return
qk
;
}
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
Qk_dot
{
template
<
typename
K_vec
,
int
N
>
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
,
inv_sqrt_dh
);
}
};
template
<
int
WARPS_PER_BLOCK
,
int
WARP_SIZE
=
32
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
__syncthreads
();
if
(
lane
<
WARPS_PER_BLOCK
)
{
sum
=
red_smem
[
lane
];
}
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
return
__shfl_sync
(
uint32_t
(
-
1
),
sum
,
0
);
}
inline
__device__
void
convert_from_float
(
float
&
dst
,
float
src
)
{
// NOLINT
dst
=
src
;
}
inline
__device__
void
convert_from_float
(
float4
&
dst
,
float4
src
)
{
// NOLINT
dst
=
src
;
}
inline
__device__
void
convert_from_float
(
plat
::
float16
&
dst
,
// NOLINT
float
src
)
{
dst
=
static_cast
<
plat
::
float16
>
(
src
);
}
inline
__device__
void
convert_from_float
(
uint4
&
dst
,
Float8_
src
)
{
// NOLINT
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
dst
.
z
=
float2_to_half2
(
src
.
z
);
dst
.
w
=
float2_to_half2
(
src
.
w
);
}
inline
__device__
void
zero
(
uint16_t
&
dst
)
{
dst
=
uint16_t
(
0
);
}
// NOLINT
template
<
typename
T
>
inline
__device__
void
zero
(
T
&
dst
)
{
// NOLINT
constexpr
int
WORDS
=
sizeof
(
T
)
/
4
;
union
{
T
raw
;
uint32_t
words
[
WORDS
];
}
tmp
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
WORDS
;
++
ii
)
{
tmp
.
words
[
ii
]
=
0u
;
}
dst
=
tmp
.
raw
;
}
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
,
int
THREADS_PER_KEY
,
int
THREADS_PER_VALUE
,
int
THREADS_PER_BLOCK
>
__global__
void
masked_multihead_attention_kernel
(
Masked_multihead_attention_params
<
T
>
params
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert
(
Dh_MAX
%
THREADS_PER_KEY
==
0
,
""
);
static_assert
(
Dh_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
extern
__shared__
char
smem_
[];
float
*
qk_smem
=
reinterpret_cast
<
float
*>
(
smem_
);
char
*
logits_smem_
=
smem_
;
// fp32 accum for logits
float
*
logits_smem
=
reinterpret_cast
<
float
*>
(
logits_smem_
);
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
x
;
const
int
bhi
=
bi
*
params
.
num_head
+
hi
;
const
int
tid
=
threadIdx
.
x
;
float
qk_max
=
-
FLT_MAX
;
float
qk
=
0
;
// qkv [B, S=1, 3, num_head, head_dim]
int
qkv_base_offset
=
bi
*
3
*
params
.
num_head
*
Dh
+
hi
*
Dh
;
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr
int
QK_ELTS_IN_16B
=
16
/
sizeof
(
T
);
constexpr
int
QK_VECS_IN_16B
=
16
/
sizeof
(
Qk_vec
);
const
T
*
q_base
=
params
.
qkv
;
const
T
*
k_base
=
params
.
qkv
+
params
.
num_head
*
Dh
;
const
T
*
q_bias_base
=
params
.
qkv_bias
;
const
T
*
k_bias_base
=
params
.
qkv_bias
+
params
.
num_head
*
Dh
;
if
(
tid
<
QK_VECS_PER_WARP
)
{
int
qk_offset
=
qkv_base_offset
+
tid
*
QK_VEC_SIZE
;
int
qk_bias_offset
=
hi
*
Dh
+
tid
*
QK_VEC_SIZE
;
Qk_vec
q
;
zero
(
q
);
q
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_base
[
qk_offset
])
:
q
;
Qk_vec
k
;
zero
(
k
);
k
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_base
[
qk_offset
])
:
k
;
Qk_vec
q_bias
;
zero
(
q_bias
);
q_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_bias_base
[
qk_bias_offset
])
:
q_bias
;
Qk_vec
k_bias
;
zero
(
k_bias
);
k_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_bias_base
[
qk_bias_offset
])
:
k_bias
;
q
=
add
(
q
,
q_bias
);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k
=
add
(
k
,
k_bias
);
*
reinterpret_cast
<
Qk_vec
*>
(
&
q_smem
[
tid
*
QK_VEC_SIZE
])
=
q
;
int
co
=
tid
/
QK_VECS_IN_16B
;
int
ci
=
(
tid
%
QK_VECS_IN_16B
)
*
QK_VEC_SIZE
;
int
offset
=
bhi
*
params
.
max_seq_length
*
Dh
+
co
*
params
.
max_seq_length
*
QK_ELTS_IN_16B
+
params
.
timestep
*
QK_ELTS_IN_16B
+
ci
;
if
(
Dh
==
Dh_MAX
||
co
<
Dh
/
QK_ELTS_IN_16B
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
&
params
.
cache_kv
[
offset
])
=
k
;
}
qk
=
dot
<
Qk_vec
,
Qk_vec
>
(
q
,
k
);
if
(
QK_VECS_PER_WARP
<=
WARP_SIZE
)
{
#pragma unroll
for
(
int
mask
=
QK_VECS_PER_WARP
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
shfl_mask
(
QK_VECS_PER_WARP
),
qk
,
mask
);
}
}
}
if
(
QK_VECS_PER_WARP
>
WARP_SIZE
)
{
constexpr
int
WARPS_PER_RED
=
(
QK_VECS_PER_WARP
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
qk
=
block_sum
<
WARPS_PER_RED
>
(
&
red_smem
[
WARPS_PER_RED
],
qk
);
}
if
(
tid
==
0
)
{
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk
*=
params
.
inv_sqrt_dh
;
qk_max
=
qk
;
qk_smem
[
params
.
timestep
]
=
qk
;
}
__syncthreads
();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"=======q_out=======
\n
"
);
for
(
int
i
=
0
;
i
<
Dh
;
++
i
)
printf
(
"%f "
,
static_cast
<
float
>
(
q_smem
[
i
]));
printf
(
"
\n
"
);
}
__syncthreads
();
#endif
using
K_vec
=
typename
K_vec_
<
T
,
THREADS_PER_KEY
>::
Type
;
constexpr
int
K_VEC_SIZE
=
sizeof
(
K_vec
)
/
sizeof
(
T
);
static_assert
(
Dh_MAX
%
K_VEC_SIZE
==
0
,
""
);
constexpr
int
K_ELTS_PER_THREAD
=
Dh_MAX
/
THREADS_PER_KEY
;
constexpr
int
K_VECS_PER_THREAD
=
K_ELTS_PER_THREAD
/
K_VEC_SIZE
;
int
ko
=
tid
/
THREADS_PER_KEY
;
int
ki
=
(
tid
%
THREADS_PER_KEY
)
*
K_VEC_SIZE
;
static_assert
(
Dh_MAX
==
THREADS_PER_KEY
*
K_VEC_SIZE
*
K_VECS_PER_THREAD
,
""
);
K_vec
q
[
K_VECS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
K_VECS_PER_THREAD
;
++
i
)
{
q
[
i
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
q_smem
[
ki
+
i
*
THREADS_PER_KEY
*
K_VEC_SIZE
]);
}
constexpr
int
K_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_KEY
;
constexpr
int
K_PER_WARP
=
WARP_SIZE
/
THREADS_PER_KEY
;
T
*
k_cache
=
&
params
.
cache_kv
[
bhi
*
params
.
max_seq_length
*
Dh
+
ki
];
int
ti_end
=
div_up
(
params
.
timestep
,
K_PER_WARP
)
*
K_PER_WARP
;
for
(
int
ti
=
ko
;
ti
<
ti_end
;
ti
+=
K_PER_ITER
)
{
K_vec
k
[
K_VECS_PER_THREAD
];
K_vec
k_vec_zero
;
zero
(
k_vec_zero
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
int
jj
=
ii
*
params
.
max_seq_length
+
ti
;
if
(
ti
<
params
.
timestep
)
{
k
[
ii
]
=
(
Dh
==
Dh_MAX
||
jj
*
QK_ELTS_IN_16B
<
Dh
*
params
.
max_seq_length
)
?
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache
[
jj
*
QK_ELTS_IN_16B
])
:
k_vec_zero
;
}
}
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float
qk
=
Qk_dot
<
T
,
THREADS_PER_KEY
>::
dot
(
q
,
k
,
params
.
inv_sqrt_dh
);
// bool is_mask = false;
if
(
ti
<
params
.
timestep
&&
tid
%
THREADS_PER_KEY
==
0
)
{
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T
mask
=
params
.
attn_mask
[
bi
*
(
params
.
timestep
+
1
)
+
ti
];
qk
+=
static_cast
<
float
>
(
mask
);
qk_max
=
fmaxf
(
qk_max
,
qk
);
qk_smem
[
ti
]
=
qk
;
}
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREADS_PER_KEY
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
const
int
warp
=
tid
/
WARP_SIZE
;
const
int
lane
=
tid
%
WARP_SIZE
;
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
qk_max
;
}
__syncthreads
();
qk_max
=
lane
<
WARPS_PER_BLOCK
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
qk_max
=
__shfl_sync
(
uint32_t
(
-
1
),
qk_max
,
0
);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"=======qk_out=======
\n
"
);
for
(
int
i
=
0
;
i
<=
params
.
timestep
;
++
i
)
printf
(
"%f "
,
qk_smem
[
i
]);
printf
(
"qk_max=%f
\n
"
,
qk_max
);
}
__syncthreads
();
#endif
float
sum
=
0.
f
;
for
(
int
ti
=
tid
;
ti
<=
params
.
timestep
;
ti
+=
THREADS_PER_BLOCK
)
{
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float
logit
=
__expf
(
qk_smem
[
ti
]
-
qk_max
);
sum
+=
logit
;
qk_smem
[
ti
]
=
logit
;
}
sum
=
block_sum
<
WARPS_PER_BLOCK
>
(
&
red_smem
[
WARPS_PER_BLOCK
],
sum
);
// FIXME(wangxi): need add 1.e-6f?
float
inv_sum
=
__fdividef
(
1.
f
,
sum
+
1.e-6
f
);
for
(
int
ti
=
tid
;
ti
<=
params
.
timestep
;
ti
+=
THREADS_PER_BLOCK
)
{
convert_from_float
(
logits_smem
[
ti
],
qk_smem
[
ti
]
*
inv_sum
);
}
__syncthreads
();
constexpr
int
V_VEC_SIZE
=
Dh_MAX
/
THREADS_PER_VALUE
;
using
V_vec
=
typename
V_vec_
<
T
,
V_VEC_SIZE
>::
Type
;
int
vo
=
tid
/
THREADS_PER_VALUE
;
int
vi
=
(
tid
%
THREADS_PER_VALUE
)
*
V_VEC_SIZE
;
T
*
v_cache
=
&
params
.
cache_kv
[
params
.
batch_size
*
params
.
num_head
*
params
.
max_seq_length
*
Dh
+
bhi
*
params
.
max_seq_length
*
Dh
+
vi
];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using
V_vec_acum
=
typename
V_vec_acum_fp32_
<
V_vec
>::
Type
;
#else
using
V_vec_acum
=
V_vec
;
#endif
V_vec_acum
out
;
zero
(
out
);
constexpr
int
V_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_VALUE
;
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
for
(
int
ti
=
vo
;
ti
<
params
.
timestep
;
ti
+=
V_PER_ITER
)
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
v_cache
[
ti
*
Dh
]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float
logit
=
logits_smem
[
ti
];
out
=
fma
(
logit
,
cast_to_float
(
v
),
out
);
#else
T
logit
=
logits_smem
[
ti
];
// Update the partial sums.
out
=
fma
(
logit
,
v
,
out
);
#endif
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"======logits_out=====
\n
"
);
for
(
int
i
=
0
;
i
<=
params
.
timestep
;
++
i
)
printf
(
"%f "
,
logits_smem
[
i
]);
printf
(
"
\n
"
);
}
__syncthreads
();
#endif
V_vec
v_bias
;
zero
(
v_bias
);
if
(
vo
==
(
params
.
timestep
%
V_PER_ITER
)
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv
[
2
*
params
.
num_head
*
Dh
+
qkv_base_offset
+
vi
]);
v_bias
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv_bias
[
2
*
params
.
num_head
*
Dh
+
hi
*
Dh
+
vi
]);
v
=
add
(
v
,
v_bias
);
*
reinterpret_cast
<
V_vec
*>
(
&
v_cache
[
params
.
timestep
*
Dh
])
=
v
;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out
=
fma
(
logits_smem
[
params
.
timestep
],
cast_to_float
(
v
),
out
);
#else
out
=
fma
(
logits_smem
[
params
.
timestep
],
v
,
out
);
#endif
}
__syncthreads
();
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
#pragma unroll
for
(
int
active_groups
=
V_PER_ITER
;
active_groups
>=
2
;
active_groups
/=
2
)
{
int
midpoint
=
active_groups
/
2
;
if
(
vo
>=
midpoint
&&
vo
<
active_groups
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
]),
out
);
#else
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
])
=
out
;
#endif
}
__syncthreads
();
if
(
vo
<
midpoint
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
out
=
add
(
*
reinterpret_cast
<
const
V_vec
*>
(
&
out_smem
[
vo
*
Dh
+
vi
]),
out
);
}
__syncthreads
();
}
}
if
(
vo
==
0
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
]),
out
);
#else
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
])
=
out
;
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
__syncthreads
();
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"======fmha_out=====
\n
"
);
for
(
int
i
=
0
;
i
<
Dh
;
++
i
)
printf
(
"%f "
,
static_cast
<
float
>
(
params
.
out
[
i
]));
printf
(
"
\n
"
);
}
#endif
#else
assert
(
false
);
#endif
}
template
<
typename
T
>
inline
size_t
smem_size_in_bytes
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
dim_head
,
int
threads_per_value
,
int
threads_per_block
)
{
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
logits_sz
=
0
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
logits_sz
=
div_up
(
params
.
max_seq_length
,
4
)
*
4
*
sizeof
(
T
);
}
#endif
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
size_t
red_sz
=
rows_per_red
*
dim_head
*
sizeof
(
T
)
/
2
;
return
max
(
softmax_sz
,
red_sz
);
}
#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
>
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
const
cudaStream_t
&
stream
)
{
constexpr
int
THREADS_PER_VALUE
=
Dh_MAX
*
sizeof
(
T
)
/
16
;
if
(
params
.
timestep
<
32
)
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
4
,
THREADS_PER_VALUE
,
64
,
stream
);
}
else
if
(
params
.
timestep
<
2048
)
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
2
,
THREADS_PER_VALUE
,
128
,
stream
);
}
else
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
1
,
THREADS_PER_VALUE
,
256
,
stream
);
}
}
template
<
typename
T
>
void
fmha
(
const
phi
::
GPUContext
&
dev_ctx
,
const
Tensor
&
qkv_tensor
,
const
Tensor
&
qkv_bias_tensor
,
const
Tensor
&
src_mask_tensor
,
Tensor
*
cache_kv_tensor
,
Tensor
*
out_tensor
,
int
batch_size
,
int
max_seq_length
,
int
num_head
,
int
dim_head
,
int
timestep
,
float
inv_sqrt_dh
)
{
Masked_multihead_attention_params
<
T
>
params
;
params
.
out
=
out_tensor
->
data
<
T
>
();
params
.
qkv
=
qkv_tensor
.
data
<
T
>
();
params
.
qkv_bias
=
qkv_bias_tensor
.
data
<
T
>
();
params
.
attn_mask
=
src_mask_tensor
.
data
<
T
>
();
params
.
cache_kv
=
cache_kv_tensor
->
data
<
T
>
();
params
.
batch_size
=
batch_size
;
params
.
num_head
=
num_head
;
params
.
timestep
=
timestep
;
params
.
max_seq_length
=
max_seq_length
;
params
.
inv_sqrt_dh
=
inv_sqrt_dh
;
switch
(
dim_head
)
{
case
10
:
fmha_launch_kernel
<
T
,
10
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
26
:
fmha_launch_kernel
<
T
,
26
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
32
:
fmha_launch_kernel
<
T
,
32
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
64
:
fmha_launch_kernel
<
T
,
64
,
64
>
(
params
,
dev_ctx
.
stream
());
break
;
case
96
:
fmha_launch_kernel
<
T
,
96
,
128
>
(
params
,
dev_ctx
.
stream
());
break
;
case
128
:
fmha_launch_kernel
<
T
,
128
,
128
>
(
params
,
dev_ctx
.
stream
());
break
;
case
192
:
fmha_launch_kernel
<
T
,
192
,
256
>
(
params
,
dev_ctx
.
stream
());
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Dim_head = %d is unsupport!"
,
dim_head
));
}
}
// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8
constexpr
int
VEC_16B
=
16
;
template
<
typename
T
>
__global__
void
write_cache_k_kernel
(
T
*
cache_k
,
const
T
*
k
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
constexpr
int
X_ELEMS
=
VEC_16B
/
sizeof
(
T
);
// [bsz, num_head, seq_len, dim_head/x, x]
auto
k_src
=
reinterpret_cast
<
const
uint4
*>
(
k
+
bi
*
num_head
*
seq_len
*
dim_head
+
hi
*
seq_len
*
dim_head
);
// [bsz, num_head, dim_head/x, max_seq_len, x]
auto
k_dst
=
reinterpret_cast
<
uint4
*>
(
cache_k
+
bi
*
num_head
*
max_seq_len
*
dim_head
+
hi
*
max_seq_len
*
dim_head
);
const
int
out_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// vec size
int
dim_head_div_x
=
dim_head
/
X_ELEMS
;
// FIXME(wangxi): num_head is not need?
// if (out_idx >= num_head * dim_head_div_x * max_seq_len) return;
if
(
out_idx
>=
dim_head_div_x
*
max_seq_len
)
return
;
int
idx
=
out_idx
;
const
int
k_seq_len_id
=
idx
%
max_seq_len
;
// idx = (idx - k_seq_len_id) / max_seq_len;
idx
=
idx
/
max_seq_len
;
const
int
k_vec_id
=
idx
%
dim_head_div_x
;
if
(
k_seq_len_id
<
seq_len
)
{
k_dst
[
out_idx
]
=
k_src
[
k_seq_len_id
*
dim_head_div_x
+
k_vec_id
];
}
}
template
<
typename
T
>
__global__
void
write_cache_v_kernel
(
T
*
cache_v
,
const
T
*
v
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
// [bsz, num_head, seq_len, dim_head/x, x]
auto
v_src
=
reinterpret_cast
<
const
uint4
*>
(
v
+
bi
*
num_head
*
seq_len
*
dim_head
+
hi
*
seq_len
*
dim_head
);
// [bsz, num_head, max_seq_len, dim_head/x, x]
auto
v_dst
=
reinterpret_cast
<
uint4
*>
(
cache_v
+
bi
*
num_head
*
max_seq_len
*
dim_head
+
hi
*
max_seq_len
*
dim_head
);
const
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
constexpr
int
X_ELEMS
=
VEC_16B
/
sizeof
(
T
);
const
int
dim_head_div_x
=
dim_head
/
X_ELEMS
;
if
(
idx
>=
dim_head_div_x
*
seq_len
)
return
;
v_dst
[
idx
]
=
v_src
[
idx
];
}
template
<
typename
T
>
void
write_cache_kv
(
const
phi
::
GPUContext
&
dev_ctx
,
T
*
cache_k
,
T
*
cache_v
,
const
T
*
k
,
const
T
*
v
,
const
int
bsz
,
const
int
num_head
,
const
int
seq_len
,
const
int
max_seq_len
,
const
int
dim_head
)
{
constexpr
int
block_sz
=
128
;
constexpr
int
x
=
VEC_16B
/
sizeof
(
T
);
assert
(
dim_head
%
x
==
0
);
PADDLE_ENFORCE_EQ
(
dim_head
%
x
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"dim_head=%d must be divisible by vec_size=%d"
,
dim_head
,
x
));
int
max_size
=
max_seq_len
*
dim_head
/
x
;
int
size
=
seq_len
*
dim_head
/
x
;
dim3
grid
(
div_up
(
max_size
,
block_sz
),
bsz
,
num_head
);
dim3
grid_v
(
div_up
(
size
,
block_sz
),
bsz
,
num_head
);
// transpose [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, dim_head/x, max_seq_len, x]
write_cache_k_kernel
<<<
grid
,
block_sz
,
0
,
dev_ctx
.
stream
()
>>>
(
cache_k
,
k
,
num_head
,
dim_head
,
seq_len
,
max_seq_len
);
// copy [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, max_seq_len, dim_head/x, x]
write_cache_v_kernel
<<<
grid_v
,
block_sz
,
0
,
dev_ctx
.
stream
()
>>>
(
cache_v
,
v
,
num_head
,
dim_head
,
seq_len
,
max_seq_len
);
}
}
// namespace
template
<
typename
T
>
class
FusedMultiTransformerOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -1480,11 +338,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
if
(
pre_layer_norm
)
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf1
,
nullptr
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
buf1
->
numel
(),
dev_ctx
);
}
else
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf0
,
nullptr
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
buf0
->
numel
(),
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step4"
;
...
...
@@ -1563,9 +421,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
#endif
if
(
pre_layer_norm
)
{
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
buf1
->
numel
(),
dev_ctx
);
}
else
{
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
buf0
->
numel
(),
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.1"
;
...
...
paddle/fluid/operators/fused/fused_multi_transformer_op.h
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
// for debug
// #define _DEBUG_FUSED_MULTI_TRANSFORMER
template
<
typename
T
>
static
void
AllReduce
(
framework
::
Tensor
&
tensor
,
// NOLINT
const
int
ring_id
,
const
int
count
,
const
phi
::
GPUContext
&
ctx
)
{
if
(
ring_id
==
-
1
)
return
;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
map
=
paddle
::
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg
->
AllReduce
(
in_tensor
,
out_tensor
,
opts
);
task
->
Wait
();
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
framework
::
TransToProtoVarType
(
tensor
.
dtype
()));
int64_t
numel
=
tensor
.
numel
();
const
void
*
sendbuff
=
tensor
.
data
<
T
>
();
auto
place
=
ctx
.
GetPlace
();
void
*
recvbuff
=
tensor
.
mutable_data
<
T
>
(
place
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
auto
stream
=
ctx
.
stream
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
count
,
dtype
,
ncclSum
,
comm
->
comm
(),
stream
));
}
#else
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."
));
#endif
}
namespace
{
// NOLINT
namespace
plat
=
paddle
::
platform
;
using
float16
=
plat
::
float16
;
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
Masked_multihead_attention_params
{
// output buffer, [B, 1(seq_len), num_head * dim_head]
T
*
out
;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const
T
*
qkv
;
// bias, [3, num_head, dim_head]
const
T
*
qkv_bias
;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const
T
*
attn_mask
;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T
*
cache_kv
;
int
batch_size
;
int
num_head
;
int
timestep
;
// cache_seq_length
int
max_seq_length
;
// 1.f / sqrt(Dh)
float
inv_sqrt_dh
;
};
struct
Float8_
{
float2
x
;
float2
y
;
float2
z
;
float2
w
;
};
// clang-format off
template
<
typename
T
,
int
Dh
>
struct
Qk_vec_
{};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
256
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float16
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
float16
,
256
>
{
using
Type
=
uint4
;
};
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_
<
float
,
1
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_
<
float16
,
4
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
K_vec_
<
float16
,
2
>
{
using
Type
=
uint2
;
};
template
<
>
struct
K_vec_
<
float16
,
1
>
{
using
Type
=
uint4
;
};
template
<
typename
T
,
int
V_VEC_SIZE
>
struct
V_vec_
{};
template
<
>
struct
V_vec_
<
float
,
1
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_
<
float
,
4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_
<
float16
,
2
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
V_vec_
<
float16
,
4
>
{
using
Type
=
uint2
;
};
template
<
>
struct
V_vec_
<
float16
,
8
>
{
using
Type
=
uint4
;
};
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
V_vec_acum_fp32_
{};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template
<
>
struct
V_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template
<
>
struct
V_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
#endif
// clang-format on
inline
__device__
float
half_to_float
(
uint16_t
h
)
{
float
f
;
asm
volatile
(
"cvt.f32.f16 %0, %1;
\n
"
:
"=f"
(
f
)
:
"h"
(
h
));
return
f
;
}
inline
__device__
float2
half2_to_float2
(
uint32_t
v
)
{
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;
\n
"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
v
));
return
make_float2
(
half_to_float
(
lo
),
half_to_float
(
hi
));
}
inline
__device__
uint32_t
float2_to_half2
(
float2
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"cvt.rn.f16x2.f32 %0, %1, %2;
\n
"
:
"=r"
(
tmp
.
u32
)
:
"f"
(
f
.
y
),
"f"
(
f
.
x
));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
.
x
));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
1
])
:
"f"
(
f
.
y
));
#endif
return
tmp
.
u32
;
}
inline
__device__
float
add
(
float
a
,
float
b
)
{
return
a
+
b
;
}
inline
__device__
float2
add
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
inline
__device__
float4
add
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
inline
__device__
uint16_t
add
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"add.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
inline
__device__
uint32_t
add
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"add.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
inline
__device__
uint2
add
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
inline
__device__
uint4
add
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
inline
__device__
float2
add
(
uint32_t
a
,
float2
fb
)
{
float2
fa
=
half2_to_float2
(
a
);
return
add
(
fa
,
fb
);
}
inline
__device__
Float8_
add
(
uint4
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
template
<
typename
Acc
,
typename
A
,
typename
B
>
inline
__device__
Acc
mul
(
A
a
,
B
b
);
template
<
>
inline
__device__
float
mul
<
float
,
float
>
(
float
a
,
float
b
)
{
return
a
*
b
;
}
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
return
c
;
}
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
c
.
z
=
a
.
z
*
b
.
z
;
c
.
w
=
a
.
w
*
b
.
w
;
return
c
;
}
template
<
>
inline
__device__
uint16_t
mul
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"mul.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"mul.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
c
;
}
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
c
;
}
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
float
b
)
{
float2
tmp
=
half2_to_float2
(
a
);
float2
tmp_res
;
tmp_res
.
x
=
tmp
.
x
*
b
;
tmp_res
.
y
=
tmp
.
y
*
b
;
uint32_t
res
=
float2_to_half2
(
tmp_res
);
return
res
;
}
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
float
b
)
{
uint2
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
return
res
;
}
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
float
b
)
{
uint4
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
res
.
z
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
z
,
b
);
res
.
w
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
w
,
b
);
return
res
;
}
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float
b
)
{
float2
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
return
res
;
}
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float
b
)
{
float4
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
res
.
z
=
a
.
z
*
b
;
res
.
w
=
a
.
w
*
b
;
return
res
;
}
inline
__device__
float
sum
(
float
v
)
{
return
v
;
}
inline
__device__
float
sum
(
float2
v
)
{
return
v
.
x
+
v
.
y
;
}
inline
__device__
float
sum
(
float4
v
)
{
return
v
.
x
+
v
.
y
+
v
.
z
+
v
.
w
;
}
inline
__device__
float
sum
(
uint16_t
v
)
{
return
half_to_float
(
v
);
}
inline
__device__
float
sum
(
uint32_t
v
)
{
float2
tmp
=
half2_to_float2
(
v
);
return
tmp
.
x
+
tmp
.
y
;
}
inline
__device__
float
sum
(
uint2
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
return
sum
(
c
);
}
inline
__device__
float
sum
(
uint4
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
c
=
add
(
c
,
v
.
z
);
c
=
add
(
c
,
v
.
w
);
return
sum
(
c
);
}
template
<
typename
T
>
inline
__device__
float
dot
(
T
a
,
T
b
)
{
return
sum
(
mul
<
T
,
T
,
T
>
(
a
,
b
));
}
template
<
typename
A
,
typename
T
>
inline
__device__
float
dot
(
T
a
,
T
b
)
{
return
sum
(
mul
<
A
,
T
,
T
>
(
a
,
b
));
}
inline
__device__
constexpr
uint32_t
shfl_mask
(
int
threads
)
{
return
threads
==
32
?
uint32_t
(
-
1
)
:
(
1u
<<
threads
)
-
1u
;
}
template
<
typename
T
>
inline
__device__
__host__
T
div_up
(
T
m
,
T
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
inline
__device__
float
fma
(
float
a
,
float
b
,
float
c
)
{
return
a
*
b
+
c
;
}
inline
__device__
float2
fma
(
float2
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
float4
fma
(
float4
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
uint32_t
fma
(
uint32_t
a
,
uint32_t
b
,
uint32_t
c
)
{
uint32_t
d
;
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
d
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
));
return
d
;
}
inline
__device__
uint2
fma
(
uint2
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
uint4
fma
(
uint4
a
,
uint4
b
,
uint4
c
)
{
uint4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
float2
fma
(
float
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
float4
fma
(
float
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
Float8_
fma
(
float
a
,
Float8_
b
,
Float8_
c
)
{
Float8_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
uint32_t
h0_h0
(
uint16_t
a
)
{
uint32_t
b
;
asm
volatile
(
"mov.b32 %0, {%1, %1};"
:
"=r"
(
b
)
:
"h"
(
a
));
return
b
;
}
inline
__device__
uint32_t
fma
(
uint16_t
a
,
uint32_t
b
,
uint32_t
c
)
{
return
fma
(
h0_h0
(
a
),
b
,
c
);
}
inline
__device__
uint2
fma
(
uint16_t
a
,
uint2
b
,
uint2
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
inline
__device__
uint4
fma
(
uint16_t
a
,
uint4
b
,
uint4
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
inline
__device__
float
cast_to_float
(
float
u
)
{
return
u
;
}
inline
__device__
float2
cast_to_float
(
float2
u
)
{
return
u
;
}
inline
__device__
float4
cast_to_float
(
float4
u
)
{
return
u
;
}
inline
__device__
Float8_
cast_to_float
(
uint4
u
)
{
Float8_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
tmp
.
z
=
half2_to_float2
(
u
.
z
);
tmp
.
w
=
half2_to_float2
(
u
.
w
);
return
tmp
;
}
template
<
int
THREADS_PER_KEY
,
typename
K_vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
K_vec
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
0
],
inv_sqrt_dh
);
K_vec
qk_vec
=
mul
<
K_vec
,
K_vec
,
K_vec
>
(
inv_q
,
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
ii
],
inv_sqrt_dh
);
qk_vec
=
fma
(
inv_q
,
k
[
ii
],
qk_vec
);
}
float
qk
=
sum
(
qk_vec
);
#pragma unroll
for
(
int
mask
=
THREADS_PER_KEY
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk
,
mask
);
}
return
qk
;
}
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
Qk_dot
{
template
<
typename
K_vec
,
int
N
>
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
,
inv_sqrt_dh
);
}
};
template
<
int
WARPS_PER_BLOCK
,
int
WARP_SIZE
=
32
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
__syncthreads
();
if
(
lane
<
WARPS_PER_BLOCK
)
{
sum
=
red_smem
[
lane
];
}
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
return
__shfl_sync
(
uint32_t
(
-
1
),
sum
,
0
);
}
inline
__device__
void
convert_from_float
(
float
&
dst
,
float
src
)
{
// NOLINT
dst
=
src
;
}
inline
__device__
void
convert_from_float
(
float4
&
dst
,
float4
src
)
{
// NOLINT
dst
=
src
;
}
inline
__device__
void
convert_from_float
(
plat
::
float16
&
dst
,
// NOLINT
float
src
)
{
dst
=
static_cast
<
plat
::
float16
>
(
src
);
}
inline
__device__
void
convert_from_float
(
uint4
&
dst
,
Float8_
src
)
{
// NOLINT
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
dst
.
z
=
float2_to_half2
(
src
.
z
);
dst
.
w
=
float2_to_half2
(
src
.
w
);
}
inline
__device__
void
zero
(
uint16_t
&
dst
)
{
dst
=
uint16_t
(
0
);
}
// NOLINT
template
<
typename
T
>
inline
__device__
void
zero
(
T
&
dst
)
{
// NOLINT
constexpr
int
WORDS
=
sizeof
(
T
)
/
4
;
union
{
T
raw
;
uint32_t
words
[
WORDS
];
}
tmp
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
WORDS
;
++
ii
)
{
tmp
.
words
[
ii
]
=
0u
;
}
dst
=
tmp
.
raw
;
}
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
,
int
THREADS_PER_KEY
,
int
THREADS_PER_VALUE
,
int
THREADS_PER_BLOCK
>
__global__
void
masked_multihead_attention_kernel
(
Masked_multihead_attention_params
<
T
>
params
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert
(
Dh_MAX
%
THREADS_PER_KEY
==
0
,
""
);
static_assert
(
Dh_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
extern
__shared__
char
smem_
[];
float
*
qk_smem
=
reinterpret_cast
<
float
*>
(
smem_
);
char
*
logits_smem_
=
smem_
;
// fp32 accum for logits
float
*
logits_smem
=
reinterpret_cast
<
float
*>
(
logits_smem_
);
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
x
;
const
int
bhi
=
bi
*
params
.
num_head
+
hi
;
const
int
tid
=
threadIdx
.
x
;
float
qk_max
=
-
FLT_MAX
;
float
qk
=
0
;
// qkv [B, S=1, 3, num_head, head_dim]
int
qkv_base_offset
=
bi
*
3
*
params
.
num_head
*
Dh
+
hi
*
Dh
;
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr
int
QK_ELTS_IN_16B
=
16
/
sizeof
(
T
);
constexpr
int
QK_VECS_IN_16B
=
16
/
sizeof
(
Qk_vec
);
const
T
*
q_base
=
params
.
qkv
;
const
T
*
k_base
=
params
.
qkv
+
params
.
num_head
*
Dh
;
const
T
*
q_bias_base
=
params
.
qkv_bias
;
const
T
*
k_bias_base
=
params
.
qkv_bias
+
params
.
num_head
*
Dh
;
if
(
tid
<
QK_VECS_PER_WARP
)
{
int
qk_offset
=
qkv_base_offset
+
tid
*
QK_VEC_SIZE
;
int
qk_bias_offset
=
hi
*
Dh
+
tid
*
QK_VEC_SIZE
;
Qk_vec
q
;
zero
(
q
);
q
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_base
[
qk_offset
])
:
q
;
Qk_vec
k
;
zero
(
k
);
k
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_base
[
qk_offset
])
:
k
;
Qk_vec
q_bias
;
zero
(
q_bias
);
q_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_bias_base
[
qk_bias_offset
])
:
q_bias
;
Qk_vec
k_bias
;
zero
(
k_bias
);
k_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_bias_base
[
qk_bias_offset
])
:
k_bias
;
q
=
add
(
q
,
q_bias
);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k
=
add
(
k
,
k_bias
);
*
reinterpret_cast
<
Qk_vec
*>
(
&
q_smem
[
tid
*
QK_VEC_SIZE
])
=
q
;
int
co
=
tid
/
QK_VECS_IN_16B
;
int
ci
=
(
tid
%
QK_VECS_IN_16B
)
*
QK_VEC_SIZE
;
int
offset
=
bhi
*
params
.
max_seq_length
*
Dh
+
co
*
params
.
max_seq_length
*
QK_ELTS_IN_16B
+
params
.
timestep
*
QK_ELTS_IN_16B
+
ci
;
if
(
Dh
==
Dh_MAX
||
co
<
Dh
/
QK_ELTS_IN_16B
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
&
params
.
cache_kv
[
offset
])
=
k
;
}
qk
=
dot
<
Qk_vec
,
Qk_vec
>
(
q
,
k
);
if
(
QK_VECS_PER_WARP
<=
WARP_SIZE
)
{
#pragma unroll
for
(
int
mask
=
QK_VECS_PER_WARP
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
shfl_mask
(
QK_VECS_PER_WARP
),
qk
,
mask
);
}
}
}
if
(
QK_VECS_PER_WARP
>
WARP_SIZE
)
{
constexpr
int
WARPS_PER_RED
=
(
QK_VECS_PER_WARP
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
qk
=
block_sum
<
WARPS_PER_RED
>
(
&
red_smem
[
WARPS_PER_RED
],
qk
);
}
if
(
tid
==
0
)
{
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk
*=
params
.
inv_sqrt_dh
;
qk_max
=
qk
;
qk_smem
[
params
.
timestep
]
=
qk
;
}
__syncthreads
();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"=======q_out=======
\n
"
);
for
(
int
i
=
0
;
i
<
Dh
;
++
i
)
printf
(
"%f "
,
static_cast
<
float
>
(
q_smem
[
i
]));
printf
(
"
\n
"
);
}
__syncthreads
();
#endif
using
K_vec
=
typename
K_vec_
<
T
,
THREADS_PER_KEY
>::
Type
;
constexpr
int
K_VEC_SIZE
=
sizeof
(
K_vec
)
/
sizeof
(
T
);
static_assert
(
Dh_MAX
%
K_VEC_SIZE
==
0
,
""
);
constexpr
int
K_ELTS_PER_THREAD
=
Dh_MAX
/
THREADS_PER_KEY
;
constexpr
int
K_VECS_PER_THREAD
=
K_ELTS_PER_THREAD
/
K_VEC_SIZE
;
int
ko
=
tid
/
THREADS_PER_KEY
;
int
ki
=
(
tid
%
THREADS_PER_KEY
)
*
K_VEC_SIZE
;
static_assert
(
Dh_MAX
==
THREADS_PER_KEY
*
K_VEC_SIZE
*
K_VECS_PER_THREAD
,
""
);
K_vec
q
[
K_VECS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
K_VECS_PER_THREAD
;
++
i
)
{
q
[
i
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
q_smem
[
ki
+
i
*
THREADS_PER_KEY
*
K_VEC_SIZE
]);
}
constexpr
int
K_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_KEY
;
constexpr
int
K_PER_WARP
=
WARP_SIZE
/
THREADS_PER_KEY
;
T
*
k_cache
=
&
params
.
cache_kv
[
bhi
*
params
.
max_seq_length
*
Dh
+
ki
];
int
ti_end
=
div_up
(
params
.
timestep
,
K_PER_WARP
)
*
K_PER_WARP
;
for
(
int
ti
=
ko
;
ti
<
ti_end
;
ti
+=
K_PER_ITER
)
{
K_vec
k
[
K_VECS_PER_THREAD
];
K_vec
k_vec_zero
;
zero
(
k_vec_zero
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
int
jj
=
ii
*
params
.
max_seq_length
+
ti
;
if
(
ti
<
params
.
timestep
)
{
k
[
ii
]
=
(
Dh
==
Dh_MAX
||
jj
*
QK_ELTS_IN_16B
<
Dh
*
params
.
max_seq_length
)
?
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache
[
jj
*
QK_ELTS_IN_16B
])
:
k_vec_zero
;
}
}
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float
qk
=
Qk_dot
<
T
,
THREADS_PER_KEY
>::
dot
(
q
,
k
,
params
.
inv_sqrt_dh
);
// bool is_mask = false;
if
(
ti
<
params
.
timestep
&&
tid
%
THREADS_PER_KEY
==
0
)
{
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T
mask
=
params
.
attn_mask
[
bi
*
(
params
.
timestep
+
1
)
+
ti
];
qk
+=
static_cast
<
float
>
(
mask
);
qk_max
=
fmaxf
(
qk_max
,
qk
);
qk_smem
[
ti
]
=
qk
;
}
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREADS_PER_KEY
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
const
int
warp
=
tid
/
WARP_SIZE
;
const
int
lane
=
tid
%
WARP_SIZE
;
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
qk_max
;
}
__syncthreads
();
qk_max
=
lane
<
WARPS_PER_BLOCK
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
qk_max
=
__shfl_sync
(
uint32_t
(
-
1
),
qk_max
,
0
);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"=======qk_out=======
\n
"
);
for
(
int
i
=
0
;
i
<=
params
.
timestep
;
++
i
)
printf
(
"%f "
,
qk_smem
[
i
]);
printf
(
"qk_max=%f
\n
"
,
qk_max
);
}
__syncthreads
();
#endif
float
sum
=
0.
f
;
for
(
int
ti
=
tid
;
ti
<=
params
.
timestep
;
ti
+=
THREADS_PER_BLOCK
)
{
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float
logit
=
__expf
(
qk_smem
[
ti
]
-
qk_max
);
sum
+=
logit
;
qk_smem
[
ti
]
=
logit
;
}
sum
=
block_sum
<
WARPS_PER_BLOCK
>
(
&
red_smem
[
WARPS_PER_BLOCK
],
sum
);
// FIXME(wangxi): need add 1.e-6f?
float
inv_sum
=
__fdividef
(
1.
f
,
sum
+
1.e-6
f
);
for
(
int
ti
=
tid
;
ti
<=
params
.
timestep
;
ti
+=
THREADS_PER_BLOCK
)
{
convert_from_float
(
logits_smem
[
ti
],
qk_smem
[
ti
]
*
inv_sum
);
}
__syncthreads
();
constexpr
int
V_VEC_SIZE
=
Dh_MAX
/
THREADS_PER_VALUE
;
using
V_vec
=
typename
V_vec_
<
T
,
V_VEC_SIZE
>::
Type
;
int
vo
=
tid
/
THREADS_PER_VALUE
;
int
vi
=
(
tid
%
THREADS_PER_VALUE
)
*
V_VEC_SIZE
;
T
*
v_cache
=
&
params
.
cache_kv
[
params
.
batch_size
*
params
.
num_head
*
params
.
max_seq_length
*
Dh
+
bhi
*
params
.
max_seq_length
*
Dh
+
vi
];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using
V_vec_acum
=
typename
V_vec_acum_fp32_
<
V_vec
>::
Type
;
#else
using
V_vec_acum
=
V_vec
;
#endif
V_vec_acum
out
;
zero
(
out
);
constexpr
int
V_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_VALUE
;
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
for
(
int
ti
=
vo
;
ti
<
params
.
timestep
;
ti
+=
V_PER_ITER
)
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
v_cache
[
ti
*
Dh
]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float
logit
=
logits_smem
[
ti
];
out
=
fma
(
logit
,
cast_to_float
(
v
),
out
);
#else
T
logit
=
logits_smem
[
ti
];
// Update the partial sums.
out
=
fma
(
logit
,
v
,
out
);
#endif
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"======logits_out=====
\n
"
);
for
(
int
i
=
0
;
i
<=
params
.
timestep
;
++
i
)
printf
(
"%f "
,
logits_smem
[
i
]);
printf
(
"
\n
"
);
}
__syncthreads
();
#endif
V_vec
v_bias
;
zero
(
v_bias
);
if
(
vo
==
(
params
.
timestep
%
V_PER_ITER
)
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv
[
2
*
params
.
num_head
*
Dh
+
qkv_base_offset
+
vi
]);
v_bias
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv_bias
[
2
*
params
.
num_head
*
Dh
+
hi
*
Dh
+
vi
]);
v
=
add
(
v
,
v_bias
);
*
reinterpret_cast
<
V_vec
*>
(
&
v_cache
[
params
.
timestep
*
Dh
])
=
v
;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out
=
fma
(
logits_smem
[
params
.
timestep
],
cast_to_float
(
v
),
out
);
#else
out
=
fma
(
logits_smem
[
params
.
timestep
],
v
,
out
);
#endif
}
__syncthreads
();
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
#pragma unroll
for
(
int
active_groups
=
V_PER_ITER
;
active_groups
>=
2
;
active_groups
/=
2
)
{
int
midpoint
=
active_groups
/
2
;
if
(
vo
>=
midpoint
&&
vo
<
active_groups
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
]),
out
);
#else
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
])
=
out
;
#endif
}
__syncthreads
();
if
(
vo
<
midpoint
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
out
=
add
(
*
reinterpret_cast
<
const
V_vec
*>
(
&
out_smem
[
vo
*
Dh
+
vi
]),
out
);
}
__syncthreads
();
}
}
if
(
vo
==
0
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
]),
out
);
#else
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
])
=
out
;
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
__syncthreads
();
if
(
bi
==
0
&&
hi
==
0
&&
tid
==
0
)
{
printf
(
"======fmha_out=====
\n
"
);
for
(
int
i
=
0
;
i
<
Dh
;
++
i
)
printf
(
"%f "
,
static_cast
<
float
>
(
params
.
out
[
i
]));
printf
(
"
\n
"
);
}
#endif
#else
assert
(
false
);
#endif
}
template
<
typename
T
>
inline
size_t
smem_size_in_bytes
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
dim_head
,
int
threads_per_value
,
int
threads_per_block
)
{
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
logits_sz
=
0
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS // NOLINT
if
(
sizeof
(
T
)
!=
4
)
{
logits_sz
=
div_up
(
params
.
max_seq_length
,
4
)
*
4
*
sizeof
(
T
);
}
#endif // NOLINT
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
size_t
red_sz
=
rows_per_red
*
dim_head
*
sizeof
(
T
)
/
2
;
return
max
(
softmax_sz
,
red_sz
);
}
#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
>
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
const
cudaStream_t
&
stream
)
{
constexpr
int
THREADS_PER_VALUE
=
Dh_MAX
*
sizeof
(
T
)
/
16
;
if
(
params
.
timestep
<
32
)
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
4
,
THREADS_PER_VALUE
,
64
,
stream
);
}
else
if
(
params
.
timestep
<
2048
)
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
2
,
THREADS_PER_VALUE
,
128
,
stream
);
}
else
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
1
,
THREADS_PER_VALUE
,
256
,
stream
);
}
}
template
<
typename
T
>
void
fmha
(
const
phi
::
GPUContext
&
dev_ctx
,
const
Tensor
&
qkv_tensor
,
const
Tensor
&
qkv_bias_tensor
,
const
Tensor
&
src_mask_tensor
,
Tensor
*
cache_kv_tensor
,
Tensor
*
out_tensor
,
int
batch_size
,
int
max_seq_length
,
int
num_head
,
int
dim_head
,
int
timestep
,
float
inv_sqrt_dh
)
{
Masked_multihead_attention_params
<
T
>
params
;
params
.
out
=
out_tensor
->
data
<
T
>
();
params
.
qkv
=
qkv_tensor
.
data
<
T
>
();
params
.
qkv_bias
=
qkv_bias_tensor
.
data
<
T
>
();
params
.
attn_mask
=
src_mask_tensor
.
data
<
T
>
();
params
.
cache_kv
=
cache_kv_tensor
->
data
<
T
>
();
params
.
batch_size
=
batch_size
;
params
.
num_head
=
num_head
;
params
.
timestep
=
timestep
;
params
.
max_seq_length
=
max_seq_length
;
params
.
inv_sqrt_dh
=
inv_sqrt_dh
;
switch
(
dim_head
)
{
case
10
:
fmha_launch_kernel
<
T
,
10
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
26
:
fmha_launch_kernel
<
T
,
26
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
32
:
fmha_launch_kernel
<
T
,
32
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
64
:
fmha_launch_kernel
<
T
,
64
,
64
>
(
params
,
dev_ctx
.
stream
());
break
;
case
96
:
fmha_launch_kernel
<
T
,
96
,
128
>
(
params
,
dev_ctx
.
stream
());
break
;
case
128
:
fmha_launch_kernel
<
T
,
128
,
128
>
(
params
,
dev_ctx
.
stream
());
break
;
case
192
:
fmha_launch_kernel
<
T
,
192
,
256
>
(
params
,
dev_ctx
.
stream
());
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Dim_head = %d is unsupport!"
,
dim_head
));
}
}
// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8
constexpr
int
VEC_16B
=
16
;
template
<
typename
T
>
__global__
void
write_cache_k_kernel
(
T
*
cache_k
,
const
T
*
k
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
constexpr
int
X_ELEMS
=
VEC_16B
/
sizeof
(
T
);
// [bsz, num_head, seq_len, dim_head/x, x]
auto
k_src
=
reinterpret_cast
<
const
uint4
*>
(
k
+
bi
*
num_head
*
seq_len
*
dim_head
+
hi
*
seq_len
*
dim_head
);
// [bsz, num_head, dim_head/x, max_seq_len, x]
auto
k_dst
=
reinterpret_cast
<
uint4
*>
(
cache_k
+
bi
*
num_head
*
max_seq_len
*
dim_head
+
hi
*
max_seq_len
*
dim_head
);
const
int
out_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// vec size
int
dim_head_div_x
=
dim_head
/
X_ELEMS
;
// FIXME(wangxi): num_head is not need?
// if (out_idx >= num_head * dim_head_div_x * max_seq_len) return;
if
(
out_idx
>=
dim_head_div_x
*
max_seq_len
)
return
;
int
idx
=
out_idx
;
const
int
k_seq_len_id
=
idx
%
max_seq_len
;
// idx = (idx - k_seq_len_id) / max_seq_len;
idx
=
idx
/
max_seq_len
;
const
int
k_vec_id
=
idx
%
dim_head_div_x
;
if
(
k_seq_len_id
<
seq_len
)
{
k_dst
[
out_idx
]
=
k_src
[
k_seq_len_id
*
dim_head_div_x
+
k_vec_id
];
}
}
template
<
typename
T
>
__global__
void
write_cache_v_kernel
(
T
*
cache_v
,
const
T
*
v
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
// [bsz, num_head, seq_len, dim_head/x, x]
auto
v_src
=
reinterpret_cast
<
const
uint4
*>
(
v
+
bi
*
num_head
*
seq_len
*
dim_head
+
hi
*
seq_len
*
dim_head
);
// [bsz, num_head, max_seq_len, dim_head/x, x]
auto
v_dst
=
reinterpret_cast
<
uint4
*>
(
cache_v
+
bi
*
num_head
*
max_seq_len
*
dim_head
+
hi
*
max_seq_len
*
dim_head
);
const
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
constexpr
int
X_ELEMS
=
VEC_16B
/
sizeof
(
T
);
const
int
dim_head_div_x
=
dim_head
/
X_ELEMS
;
if
(
idx
>=
dim_head_div_x
*
seq_len
)
return
;
v_dst
[
idx
]
=
v_src
[
idx
];
}
template
<
typename
T
>
void
write_cache_kv
(
const
phi
::
GPUContext
&
dev_ctx
,
T
*
cache_k
,
T
*
cache_v
,
const
T
*
k
,
const
T
*
v
,
const
int
bsz
,
const
int
num_head
,
const
int
seq_len
,
const
int
max_seq_len
,
const
int
dim_head
)
{
constexpr
int
block_sz
=
128
;
constexpr
int
x
=
VEC_16B
/
sizeof
(
T
);
assert
(
dim_head
%
x
==
0
);
PADDLE_ENFORCE_EQ
(
dim_head
%
x
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"dim_head=%d must be divisible by vec_size=%d"
,
dim_head
,
x
));
int
max_size
=
max_seq_len
*
dim_head
/
x
;
int
size
=
seq_len
*
dim_head
/
x
;
dim3
grid
(
div_up
(
max_size
,
block_sz
),
bsz
,
num_head
);
dim3
grid_v
(
div_up
(
size
,
block_sz
),
bsz
,
num_head
);
// transpose [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, dim_head/x, max_seq_len, x]
write_cache_k_kernel
<<<
grid
,
block_sz
,
0
,
dev_ctx
.
stream
()
>>>
(
cache_k
,
k
,
num_head
,
dim_head
,
seq_len
,
max_seq_len
);
// copy [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, max_seq_len, dim_head/x, x]
write_cache_v_kernel
<<<
grid_v
,
block_sz
,
0
,
dev_ctx
.
stream
()
>>>
(
cache_v
,
v
,
num_head
,
dim_head
,
seq_len
,
max_seq_len
);
}
}
// namespace
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
浏览文件 @
3d7e2118
...
...
@@ -28,7 +28,9 @@ template <typename T,
int
VecSize
,
bool
ComputeLayerNorm
,
bool
Activation
,
typename
Functor
>
typename
Functor
,
typename
InType
=
T
,
typename
OutType
=
T
>
__forceinline__
__device__
void
FusedResidualDropoutBiasOneThread
(
const
int
row_id
,
const
int
col_id
,
...
...
@@ -36,30 +38,45 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
curandStatePhilox4_32_10_t
*
state
,
const
float
dropout_prob
,
const
T
factor
,
const
T
*
__restrict__
src
,
const
InType
*
__restrict__
src
,
const
T
*
__restrict__
residual
,
const
T
*
__restrict__
bias
,
T
*
dst
,
OutType
*
dst
,
MaskType
*
mask
,
const
bool
is_test
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
mean_val
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
var_val
,
Functor
act_func
)
{
Functor
act_func
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
using
LoadT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
LoadInType
=
phi
::
AlignedVector
<
InType
,
VecSize
>
;
using
LoadFloat
=
phi
::
AlignedVector
<
float
,
VecSize
>
;
using
StoreT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreOutType
=
phi
::
AlignedVector
<
OutType
,
VecSize
>
;
using
MaskStoreT
=
phi
::
AlignedVector
<
MaskType
,
VecSize
>
;
using
U
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
Load
T
src_vec
;
Load
InType
src_vec
;
LoadT
residual_vec
;
LoadT
bias_vec
;
LoadFloat
quant_out_scale_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
bias_vec
[
ii
]
=
static_cast
<
T
>
(
0
);
residual_vec
[
ii
]
=
static_cast
<
T
>
(
0
);
}
// vectorize load data from global
phi
::
Load
<
T
,
VecSize
>
(
&
src
[
row_id
*
cols
+
col_id
],
&
src_vec
);
phi
::
Load
<
InType
,
VecSize
>
(
&
src
[
row_id
*
cols
+
col_id
],
&
src_vec
);
phi
::
Load
<
float
,
VecSize
>
(
&
dequant_out_scale_data
[
quant_out_scale_offset
+
col_id
],
&
quant_out_scale_vec
);
if
(
residual
)
{
phi
::
Load
<
T
,
VecSize
>
(
&
residual
[
row_id
*
cols
+
col_id
],
&
residual_vec
);
}
...
...
@@ -84,10 +101,18 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
StoreT
dest_vec
;
StoreOutType
dest_vec_out_type
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
T
tmp
=
src_vec
[
ii
]
+
bias_vec
[
ii
];
T
tmp
;
if
(
std
::
is_same
<
InType
,
int32_t
>::
value
)
{
T
tmp0
=
static_cast
<
T
>
(
static_cast
<
float
>
(
src_vec
[
ii
])
*
quant_last_in_scale
/
quant_out_scale_vec
[
ii
]);
tmp
=
tmp0
+
bias_vec
[
ii
];
}
else
{
tmp
=
static_cast
<
T
>
(
src_vec
[
ii
])
+
bias_vec
[
ii
];
}
if
(
Activation
)
{
tmp
=
act_func
(
tmp
);
}
...
...
@@ -98,10 +123,23 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
*
mean_val
+=
tmp
;
*
var_val
+=
(
tmp
*
tmp
);
}
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
dest_vec_out_type
[
ii
]
=
quant_helper
(
dest_vec
[
ii
],
quant_next_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
}
// store result to global
phi
::
Store
<
T
,
VecSize
>
(
dest_vec
,
&
dst
[
row_id
*
cols
+
col_id
]);
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
phi
::
Store
<
OutType
,
VecSize
>
(
dest_vec_out_type
,
&
dst
[
row_id
*
cols
+
col_id
]);
}
else
{
phi
::
Store
<
T
,
VecSize
>
(
dest_vec
,
reinterpret_cast
<
T
*>
(
&
dst
[
row_id
*
cols
+
col_id
]));
}
if
(
!
is_test
)
{
phi
::
Store
<
MaskType
,
VecSize
>
(
mask_vec
,
&
mask
[
row_id
*
cols
+
col_id
]);
}
...
...
@@ -114,19 +152,28 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
* is_test: only used in inference
* mask: can be null if is_test=true
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
FusedResidualDropoutBias
(
const
size_t
rows
,
const
size_t
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
residual
,
const
T
*
__restrict__
bias
,
MaskType
*
mask
,
T
*
dst
,
uint64_t
increment
,
const
bool
is_test
)
{
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
typename
InType
=
T
,
typename
OutType
=
T
>
__global__
void
FusedResidualDropoutBias
(
const
size_t
rows
,
const
size_t
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
InType
*
__restrict__
src
,
const
T
*
__restrict__
residual
,
const
T
*
__restrict__
bias
,
MaskType
*
mask
,
OutType
*
dst
,
uint64_t
increment
,
const
bool
is_test
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
)
{
int
col_id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row_id
=
blockIdx
.
y
;
int
idx
=
row_id
*
cols
+
col_id
;
...
...
@@ -142,22 +189,27 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
VecSize
,
false
,
false
,
phi
::
funcs
::
ReluFunctor
<
T
>>
(
r
,
i
,
cols
,
&
state
,
dropout_prob
,
factor
,
src
,
residual
,
bias
,
dst
,
mask
,
is_test
,
nullptr
,
nullptr
,
relu
);
phi
::
funcs
::
ReluFunctor
<
T
>
,
InType
,
OutType
>
(
r
,
i
,
cols
,
&
state
,
dropout_prob
,
factor
,
src
,
residual
,
bias
,
dst
,
mask
,
is_test
,
nullptr
,
nullptr
,
relu
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
}
}
...
...
@@ -165,7 +217,10 @@ __global__ void FusedResidualDropoutBias(const size_t rows,
/**
* @brief dst = residual + dropout(src + bias);
*/
template
<
typename
T
,
typename
MaskType
>
template
<
typename
T
,
typename
MaskType
,
typename
InType
=
T
,
typename
OutType
=
T
>
void
LaunchResidualDropoutBias
(
const
uint32_t
rows
,
const
uint32_t
cols
,
const
int
increment
,
...
...
@@ -173,14 +228,19 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const
float
dropout_prob
,
const
bool
is_test
,
bool
is_upscale_in_train
,
const
T
*
src
,
const
InType
*
src
,
const
T
*
residual
,
const
T
*
bias
,
MaskType
*
mask_data
,
T
*
dst
,
const
phi
::
GPUContext
&
ctx
)
{
OutType
*
dst
,
const
phi
::
GPUContext
&
ctx
,
const
float
quant_last_in_scale
=
1.0
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_next_in_scale
=
1.0
)
{
// dropout_prob == 1.0f
if
(
std
::
abs
(
dropout_prob
-
1.0
f
)
<
1e-5
)
{
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
if
(
residual
==
dst
)
return
;
if
(
residual
)
{
memory
::
Copy
(
ctx
.
GetPlace
(),
...
...
@@ -202,7 +262,7 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
auto
config
=
Get1DBlocksAnd2DGrids
(
ctx
,
rows
,
cols
,
real_vec_size
);
if
(
cols
%
VecSize
==
0
)
{
FusedResidualDropoutBias
<
T
,
uint8_t
,
VecSize
>
FusedResidualDropoutBias
<
T
,
uint8_t
,
VecSize
,
InType
,
OutType
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
...
...
@@ -215,9 +275,13 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data
,
dst
,
increment
,
is_test
);
is_test
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
else
{
FusedResidualDropoutBias
<
T
,
uint8_t
,
1
>
FusedResidualDropoutBias
<
T
,
uint8_t
,
1
,
InType
,
OutType
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
...
...
@@ -230,7 +294,11 @@ void LaunchResidualDropoutBias(const uint32_t rows,
mask_data
,
dst
,
increment
,
is_test
);
is_test
,
quant_last_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
,
quant_next_in_scale
);
}
}
...
...
paddle/fluid/operators/fused/quant_dequant_kernel.h
0 → 100644
浏览文件 @
3d7e2118
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__forceinline__
__device__
int8_t
quant_helper
(
const
T
input
,
const
float
scale
,
const
int
round_type
,
const
float
max_bound
,
const
float
min_bound
)
{
float
quant_value
=
max_bound
*
inverse
(
scale
)
*
static_cast
<
float
>
(
input
);
if
(
round_type
==
0
)
{
quant_value
=
static_cast
<
float
>
(
roundWithTiesToEven
(
quant_value
));
}
else
{
quant_value
=
static_cast
<
float
>
(
round
(
quant_value
));
}
quant_value
=
quant_value
>
max_bound
?
max_bound
:
quant_value
;
quant_value
=
quant_value
<
min_bound
?
min_bound
:
quant_value
;
return
static_cast
<
int8_t
>
(
quant_value
);
}
template
<
typename
T
>
__global__
void
quantize_kernel
(
const
T
*
input
,
char4
*
output
,
const
float
scale
,
const
int
m
,
const
int
n
,
const
int
round_type
,
const
float
max_bound
,
const
float
min_bound
)
{
int
n_id
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
)
<<
2
;
int
m_id
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
bool
check
=
((
m_id
<
m
)
&&
(
n_id
<
n
));
if
(
check
)
{
char4
tmp
;
tmp
.
x
=
quant_helper
(
input
[
m_id
*
n
+
n_id
],
scale
,
round_type
,
max_bound
,
min_bound
);
tmp
.
y
=
quant_helper
(
input
[
m_id
*
n
+
n_id
+
1
],
scale
,
round_type
,
max_bound
,
min_bound
);
tmp
.
z
=
quant_helper
(
input
[
m_id
*
n
+
n_id
+
2
],
scale
,
round_type
,
max_bound
,
min_bound
);
tmp
.
w
=
quant_helper
(
input
[
m_id
*
n
+
n_id
+
3
],
scale
,
round_type
,
max_bound
,
min_bound
);
output
[(
m_id
*
n
+
n_id
)
>>
2
]
=
tmp
;
}
}
template
<
typename
T
>
void
quantize_kernel_launcher
(
const
T
*
input
,
int8_t
*
output
,
const
float
scale
,
const
int
m
,
const
int
n
,
const
int
round_type
,
const
float
max_bound
,
const
float
min_bound
,
gpuStream_t
stream
)
{
// TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1
dim3
grid
((
n
+
31
)
/
32
,
(
m
+
31
)
/
32
);
dim3
block
(
32
,
32
);
quantize_kernel
<<<
grid
,
block
,
0
,
stream
>>>
(
input
,
(
char4
*
)
output
,
// NOLINT
scale
,
m
,
n
,
round_type
,
max_bound
,
min_bound
);
}
// dequantize using weight scales and input scales
template
<
typename
T
>
__global__
void
dequantize_kernel
(
T
*
output
,
const
int32_t
*
input
,
const
int
m
,
// hidden
const
int
n
,
// batch size
const
float
quant_in_scale
,
const
float
*
dequant_out_scale_data
,
const
int
quant_out_scale_offset
)
{
int
m_id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// hidden
int
n_id
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
// batch size
bool
check
=
((
m_id
<
m
)
&&
(
n_id
<
n
));
if
(
check
)
{
float
out_scale
=
dequant_out_scale_data
[
quant_out_scale_offset
+
m_id
];
output
[
n_id
*
m
+
m_id
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
input
[
n_id
*
m
+
m_id
])
*
quant_in_scale
/
out_scale
);
}
}
template
<
typename
T
>
void
dequantize_kernel_launcher
(
const
int32_t
*
input
,
T
*
output
,
const
int
batch_size
,
// m
const
int
hidden_units
,
// n
gpuStream_t
stream
,
const
float
quant_in_scale
,
const
float
*
dequant_out_scale_data
,
const
int
quant_out_scale_offset
)
{
dim3
grid
((
hidden_units
+
31
)
/
32
,
(
batch_size
+
31
)
/
32
);
dim3
block
(
32
,
32
);
dequantize_kernel
<<<
grid
,
block
,
0
,
stream
>>>
(
output
,
input
,
hidden_units
,
batch_size
,
quant_in_scale
,
dequant_out_scale_data
,
quant_out_scale_offset
);
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/layer_norm_kernel.cu.h
浏览文件 @
3d7e2118
...
...
@@ -24,6 +24,7 @@ namespace cub = hipcub;
#include <iostream>
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
...
...
@@ -338,16 +339,24 @@ using LayerNormScaleBiasT =
template
<
typename
T
,
typename
U
,
int
BlockDim
,
bool
ScaleBiasWithSameTypeX
=
false
>
bool
ScaleBiasWithSameTypeX
=
false
,
typename
InType
=
T
,
typename
OutType
=
T
>
__global__
void
LayerNormForward
(
const
T
*
x
,
const
InType
*
x
,
const
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
*
scale
,
const
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
*
bias
,
T
*
y
,
OutType
*
y
,
U
*
mean
,
U
*
var
,
float
epsilon
,
int64_t
feature_size
)
{
int64_t
feature_size
,
const
float
*
dequant_out_scale_data
=
nullptr
,
const
int
quant_out_scale_offset
=
0
,
const
float
quant_in_scale
=
1.0
,
const
int
quant_round_type
=
1
,
const
float
quant_max_bound
=
127.0
,
const
float
quant_min_bound
=
-
127.0
)
{
__shared__
U
mean_share
;
__shared__
U
var_share
;
__shared__
U
shared_mean
[
32
];
// threadIdx.x / warpSize <= kMaxBlockDim /
...
...
@@ -387,28 +396,72 @@ __global__ void LayerNormForward(
if
(
bias
!=
nullptr
)
{
for
(
int64_t
i
=
beg_idx
,
j
=
threadIdx
.
x
;
i
<
end_idx
;
i
+=
BlockDim
,
j
+=
BlockDim
)
{
y
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
]));
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
y
[
i
]
=
quant_helper
(
static_cast
<
T
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
])),
quant_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
y
[
i
]
=
static_cast
<
OutType
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
]));
}
}
}
else
{
for
(
int64_t
i
=
beg_idx
,
j
=
threadIdx
.
x
;
i
<
end_idx
;
i
+=
BlockDim
,
j
+=
BlockDim
)
{
y
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
);
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
y
[
i
]
=
quant_helper
(
static_cast
<
T
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
),
quant_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
y
[
i
]
=
static_cast
<
OutType
>
(
static_cast
<
U
>
(
scale
[
j
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
);
}
}
}
}
else
{
// scale == nullptr
if
(
bias
!=
nullptr
)
{
for
(
int64_t
i
=
beg_idx
,
j
=
threadIdx
.
x
;
i
<
end_idx
;
i
+=
BlockDim
,
j
+=
BlockDim
)
{
y
[
i
]
=
static_cast
<
T
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
]));
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
y
[
i
]
=
quant_helper
(
static_cast
<
T
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
])),
quant_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
y
[
i
]
=
static_cast
<
OutType
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
+
static_cast
<
U
>
(
bias
[
j
]));
}
}
}
else
{
for
(
int64_t
i
=
beg_idx
,
j
=
threadIdx
.
x
;
i
<
end_idx
;
i
+=
BlockDim
,
j
+=
BlockDim
)
{
y
[
i
]
=
static_cast
<
T
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
);
if
(
std
::
is_same
<
OutType
,
int8_t
>::
value
)
{
y
[
i
]
=
quant_helper
(
static_cast
<
T
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
),
quant_in_scale
,
quant_round_type
,
quant_max_bound
,
quant_min_bound
);
}
else
{
y
[
i
]
=
static_cast
<
OutType
>
((
static_cast
<
U
>
(
x
[
i
])
-
mean_val
)
*
invvar
);
}
}
}
}
...
...
paddle/fluid/platform/dynload/cublasLt.h
浏览文件 @
3d7e2118
...
...
@@ -40,26 +40,28 @@ namespace dynload {
// APIs available after CUDA 10.1
// #if CUDA_VERSION >= 10100
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute); \
__macro(cublasLtMatmulAlgoInit); \
__macro(cublasLtMatmulAlgoConfigSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH
(
PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP
)
// #endif
...
...
paddle/fluid/pybind/op_function_generator.h
浏览文件 @
3d7e2118
...
...
@@ -71,6 +71,12 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"FFN1Bias"
,
"FFN2Weight"
,
"FFN2Bias"
}},
{
"fused_multi_transformer_int8"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"CacheKV"
,
"TimeStep"
,
"SrcMask"
,
"OutLinearW"
,
"OutLinearBias"
,
"FFNLnScale"
,
"FFNLnBias"
,
"FFN1Weight"
,
"FFN1Bias"
,
"FFN2Weight"
,
"FFN2Bias"
,
"QKVOutScale"
,
"OutLinearOutScale"
,
"FFN1OutScale"
,
"FFN2OutScale"
}},
{
"fused_bias_dropout_residual_layer_norm"
,
{
"X"
,
"Residual"
,
"Bias"
,
"LnScale"
,
"LnBias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
...
...
@@ -329,6 +335,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut"
,
"MasterParamOut"
}},
{
"fused_multi_transformer"
,
{
"CacheKVOut"
,
"Out"
}},
{
"fused_multi_transformer_int8"
,
{
"CacheKVOut"
,
"Out"
}},
{
"resnet_basic_block"
,
{
"Y"
,
"Conv1"
,
"SavedMean1"
,
"SavedInvstd1"
,
"Mean1Out"
,
"Var1Out"
,
"Conv2"
,
"SavedMean2"
,
"SavedInvstd2"
,
"Mean2Out"
,
...
...
@@ -433,6 +440,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{
"split"
,
{
"Out"
}},
{
"concat"
,
{
"Out"
}},
{
"fused_multi_transformer"
,
{
"CacheKVOut"
}},
{
"fused_multi_transformer_int8"
,
{
"CacheKVOut"
}},
{
"group_norm"
,
{
"Mean"
,
"Variance"
}},
{
"resnet_basic_block"
,
{
"Mean1Out"
,
"Var1Out"
,
"Mean2Out"
,
"Var2Out"
,
"Mean3Out"
,
"Var3Out"
}},
...
...
paddle/phi/backends/dynload/cublasLt.h
浏览文件 @
3d7e2118
...
...
@@ -54,26 +54,28 @@ extern void *cublasLt_dso_handle;
// APIs available after CUDA 10.1
// #if CUDA_VERSION >= 10100
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute); \
__macro(cublasLtMatmulAlgoInit); \
__macro(cublasLtMatmulAlgoConfigSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP
)
// #endif
...
...
paddle/phi/backends/dynload/dynamic_loader.cc
浏览文件 @
3d7e2118
...
...
@@ -326,7 +326,7 @@ void* GetCublasDsoHandle() {
void
*
GetCublasLtDsoHandle
()
{
// APIs available after CUDA 10.1
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10
10
0
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10
01
0
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libcublasLt.so"
);
#else
std
::
string
warning_msg
(
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
3d7e2118
...
...
@@ -72,6 +72,7 @@ if(NOT WITH_GPU)
list
(
REMOVE_ITEM TEST_OPS test_fused_attention_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_attention_op_api
)
list
(
REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
...
...
@@ -141,6 +142,7 @@ if(WIN32)
list
(
REMOVE_ITEM TEST_OPS test_complex_matmul
)
list
(
REMOVE_ITEM TEST_OPS test_ops_nms
)
list
(
REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias
)
list
(
REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op
)
endif
()
list
(
REMOVE_ITEM TEST_OPS test_checkpoint_saver
)
...
...
@@ -1202,6 +1204,10 @@ endif()
if
(
WITH_GPU OR WITH_ROCM
)
set_tests_properties
(
test_rank_attention_op PROPERTIES TIMEOUT 120
)
endif
()
if
(
WITH_GPU AND NOT WIN32
)
set_tests_properties
(
test_fused_multi_transformer_int8_op PROPERTIES TIMEOUT
60
)
endif
()
set_tests_properties
(
test_inplace_addto_strategy PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_eigvals_op PROPERTIES TIMEOUT 400
)
set_tests_properties
(
...
...
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py
0 → 100644
浏览文件 @
3d7e2118
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.fluid.core
as
core
import
paddle.nn.functional
as
F
import
paddle.incubate.nn.functional
as
incubate_f
from
paddle.nn.layer.norm
import
LayerNorm
from
paddle.nn.layer.common
import
Linear
,
Dropout
from
paddle.nn.layer.transformer
import
_convert_attention_mask
from
paddle
import
tensor
from
paddle.fluid
import
layers
import
unittest
from
op_test
import
OpTest
from
paddle.fluid.framework
import
default_main_program
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.nn.initializer
import
Constant
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_dtype
from
paddle.fluid.framework
import
_non_static_mode
,
default_main_program
from
paddle
import
_legacy_C_ops
default_main_program
().
random_seed
=
42
np
.
random
.
seed
(
0
)
def
fused_multi_transformer_int8
(
x
,
ln_scales
,
ln_biases
,
qkv_weights
,
qkv_biases
,
linear_weights
,
linear_biases
,
ffn_ln_scales
,
ffn_ln_biases
,
ffn1_weights
,
ffn1_biases
,
ffn2_weights
,
ffn2_biases
,
pre_layer_norm
=
True
,
epsilon
=
1e-05
,
cache_kvs
=
None
,
time_step
=
None
,
attn_mask
=
None
,
dropout_rate
=
0.0
,
activation
=
"gelu"
,
training
=
False
,
mode
=
'upscale_in_train'
,
trans_qkvw
=
True
,
ring_id
=-
1
,
name
=
None
,
qkv_out_scales
=
None
,
out_linear_out_scales
=
None
,
ffn1_out_scales
=
None
,
ffn2_out_scales
=
None
,
num_head
=
0
,
dim_head
=
0
,
dim_ffn
=
0
,
qkv_in_scale
=
[],
out_linear_in_scale
=
[],
ffn1_in_scale
=
[],
ffn2_in_scale
=
[],
):
mode
=
'downgrade_in_infer'
if
mode
==
'downscale_in_infer'
else
mode
#semantic transfer
cache_kv_out
,
final_out
=
_legacy_C_ops
.
fused_multi_transformer_int8
(
x
,
ln_scales
,
ln_biases
,
qkv_weights
,
qkv_biases
,
cache_kvs
,
time_step
,
attn_mask
,
linear_weights
,
linear_biases
,
ffn_ln_scales
,
ffn_ln_biases
,
ffn1_weights
,
ffn1_biases
,
ffn2_weights
,
ffn2_biases
,
qkv_out_scales
,
out_linear_out_scales
,
ffn1_out_scales
,
ffn2_out_scales
,
cache_kvs
,
'num_head'
,
num_head
,
'dim_head'
,
dim_head
,
'dim_ffn'
,
dim_ffn
,
'qkv_in_scale'
,
qkv_in_scale
,
'out_linear_in_scale'
,
out_linear_in_scale
,
'ffn1_in_scale'
,
ffn1_in_scale
,
'ffn2_in_scale'
,
ffn2_in_scale
,
'pre_layer_norm'
,
pre_layer_norm
,
'epsilon'
,
epsilon
,
'dropout_rate'
,
dropout_rate
,
'is_test'
,
not
training
,
'dropout_implementation'
,
mode
,
'act_method'
,
activation
,
'trans_qkvw'
,
trans_qkvw
,
'ring_id'
,
ring_id
)
if
cache_kvs
is
not
None
:
return
final_out
,
cache_kv_out
return
final_out
class
TestFusedMultiTransformerInt8Op
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
config
()
self
.
generate_input_data
()
self
.
rtol
=
1e-5
# FIXME(wangxi): Because there is a problem with the test precision
# on A100, atol is temporarily set to 1e-2, and it will be
# changed back after the precision problem is solved.
self
.
atol
=
1e-2
# make sure local development precision
if
"V100"
in
paddle
.
device
.
cuda
.
get_device_name
():
self
.
atol
=
1e-4
if
self
.
x_type
is
np
.
float16
:
self
.
atol
=
1e-1
paddle
.
set_default_dtype
(
self
.
x_type
)
self
.
__class__
.
op_type
=
"fused_multi_transformer_int8"
# use autograd to check grad in this unittest.
self
.
__class__
.
no_need_check_grad
=
True
paddle
.
set_default_dtype
(
np
.
float32
)
self
.
norm
=
LayerNorm
(
self
.
embed_dim
,
weight_attr
=
False
,
bias_attr
=
False
)
self
.
ffn_norm
=
LayerNorm
(
self
.
embed_dim
,
weight_attr
=
False
,
bias_attr
=
False
)
paddle
.
set_default_dtype
(
self
.
x_type
)
self
.
dropout
=
Dropout
(
self
.
dropout_prob
,
mode
=
"upscale_in_train"
)
self
.
activation
=
getattr
(
F
,
self
.
act_method
)
def
config
(
self
):
# for debug
self
.
debug
=
False
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
#self.attn_mask_type = np.bool
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
# has_cache_kv, gen_cache_kv, stage
# False, False, not generation
# True, True, generation context stage
# True, False, generation decoder stage
self
.
has_cache_kv
=
False
self
.
gen_cache_kv
=
False
self
.
training
=
False
self
.
layers
=
3
self
.
batch_size
=
1
self
.
query_length
=
1
self
.
cache_length
=
1
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
act_method
=
'gelu'
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
generate_input_data
(
self
):
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
).
astype
(
self
.
x_type
)
q_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
self
.
embed_dim
,
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
k_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
self
.
kdim
,
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
v_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
self
.
vdim
,
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
self
.
q_weight_tensor
=
paddle
.
to_tensor
(
q_weight
)
self
.
k_weight_tensor
=
paddle
.
to_tensor
(
k_weight
)
self
.
v_weight_tensor
=
paddle
.
to_tensor
(
v_weight
)
out_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
self
.
embed_dim
,
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
ffn1_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
self
.
embed_dim
,
4
*
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
ffn2_weight
=
np
.
random
.
randint
(
-
64
,
64
,
[
4
*
self
.
embed_dim
,
self
.
embed_dim
],
np
.
int32
).
astype
(
'float64'
)
self
.
out_weight_tensor
=
paddle
.
to_tensor
(
out_weight
)
self
.
ffn1_weight_tensor
=
paddle
.
to_tensor
(
ffn1_weight
)
self
.
ffn2_weight_tensor
=
paddle
.
to_tensor
(
ffn2_weight
)
q_proj_bias
=
np
.
random
.
rand
(
self
.
embed_dim
).
astype
(
self
.
x_type
)
k_proj_bias
=
np
.
random
.
rand
(
self
.
embed_dim
).
astype
(
self
.
x_type
)
v_proj_bias
=
np
.
random
.
rand
(
self
.
embed_dim
).
astype
(
self
.
x_type
)
self
.
q_proj_bias_tensor
=
paddle
.
to_tensor
(
q_proj_bias
)
self
.
k_proj_bias_tensor
=
paddle
.
to_tensor
(
k_proj_bias
)
self
.
v_proj_bias_tensor
=
paddle
.
to_tensor
(
v_proj_bias
)
out_linear_proj_bias
=
np
.
random
.
rand
(
self
.
embed_dim
).
astype
(
self
.
x_type
)
ffn1_proj_bias
=
np
.
random
.
rand
(
4
*
self
.
embed_dim
).
astype
(
self
.
x_type
)
ffn2_proj_bias
=
np
.
random
.
rand
(
self
.
embed_dim
).
astype
(
self
.
x_type
)
self
.
out_linear_proj_bias_tensor
=
paddle
.
to_tensor
(
out_linear_proj_bias
)
self
.
ffn1_proj_bias_tensor
=
paddle
.
to_tensor
(
ffn1_proj_bias
)
self
.
ffn2_proj_bias_tensor
=
paddle
.
to_tensor
(
ffn2_proj_bias
)
out_seq_len
=
self
.
key_length
self
.
qkv_in_scales
=
[]
self
.
qkv_out_scales
=
[]
self
.
out_linear_in_scales
=
[]
self
.
out_linear_out_scales
=
[]
self
.
ffn1_in_scales
=
[]
self
.
ffn1_out_scales
=
[]
self
.
ffn2_in_scales
=
[]
self
.
ffn2_out_scales
=
[]
if
self
.
has_cache_kv
:
self
.
cache_kv
=
np
.
random
.
rand
(
2
,
self
.
batch_size
,
self
.
num_heads
,
self
.
cache_length
,
self
.
head_dim
).
astype
(
self
.
x_type
)
if
self
.
gen_cache_kv
:
self
.
cache_kv
[:]
=
0
else
:
out_seq_len
+=
self
.
cache_length
else
:
self
.
cache_kv
=
None
if
self
.
has_attn_mask
:
# [B, n_head, seq_len, out_seq_len]
self
.
attn_mask
=
np
.
ones
(
(
self
.
batch_size
,
1
,
self
.
query_length
,
out_seq_len
),
dtype
=
self
.
attn_mask_type
)
if
self
.
attn_mask_type
==
np
.
int64
:
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
elif
self
.
attn_mask_type
==
np
.
float64
:
if
self
.
has_cache_kv
and
not
self
.
gen_cache_kv
:
# NOTE: decoder stage, -1(out_seq_len) should no mask
self
.
attn_mask
[:,
:,
:,
-
2
]
=
0.0
self
.
attn_mask
=
(
self
.
attn_mask
-
1.0
)
*
1e4
else
:
self
.
attn_mask
=
(
np
.
tril
(
self
.
attn_mask
)
-
1.0
)
*
1e4
elif
self
.
attn_mask_type
==
np
.
bool_
:
if
self
.
has_cache_kv
and
not
self
.
gen_cache_kv
:
self
.
attn_mask
[:,
:,
:,
-
2
]
=
0
else
:
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
else
:
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
else
:
self
.
attn_mask
=
None
def
fake_quant
(
self
,
input
,
scale
):
quant_value
=
127.0
*
(
1.0
/
scale
)
*
paddle
.
cast
(
input
,
'float32'
)
quant_value
=
paddle
.
round
(
quant_value
)
# No need to clip here because scale is the max value
return
paddle
.
cast
(
quant_value
,
'float64'
)
def
GetBaselineOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
tensor_query
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
cache_kvs
=
[]
cache_kv
=
None
if
self
.
has_cache_kv
:
cache_kv
=
paddle
.
to_tensor
(
self
.
cache_kv
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
attn_mask
=
None
for
i
in
range
(
self
.
layers
):
residual
=
tensor_query
ln1_out
=
tensor_query
if
self
.
pre_layer_norm
:
ln1_out
=
self
.
norm
(
tensor_query
)
max_v
=
paddle
.
max
(
paddle
.
abs
(
paddle
.
cast
(
ln1_out
,
'float32'
)))[
0
]
# self.qkv_in_scales.append(127.0 / max_v)
self
.
qkv_in_scales
.
append
(
max_v
)
self
.
qkv_out_scales
.
append
(
127.0
*
127.0
)
# print('qkv_in_scales ', i, self.qkv_in_scales[i])
# print('qkv_out_scales ', i, self.qkv_out_scales[i])
# quant ln1_out
ln1_out
=
self
.
fake_quant
(
ln1_out
,
self
.
qkv_in_scales
[
i
])
q
=
paddle
.
nn
.
functional
.
linear
(
ln1_out
,
self
.
q_weight_tensor
)
# de quant
q
=
paddle
.
cast
(
paddle
.
cast
(
q
,
'float32'
)
*
self
.
qkv_in_scales
[
i
]
/
self
.
qkv_out_scales
[
i
],
self
.
x_type
)
q
=
q
+
self
.
q_proj_bias_tensor
q
=
tensor
.
reshape
(
x
=
q
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
q_out
=
tensor
.
transpose
(
x
=
q
,
perm
=
[
0
,
2
,
1
,
3
])
k
=
paddle
.
nn
.
functional
.
linear
(
ln1_out
,
self
.
k_weight_tensor
)
k
=
paddle
.
cast
(
paddle
.
cast
(
k
,
'float32'
)
*
self
.
qkv_in_scales
[
i
]
/
self
.
qkv_out_scales
[
i
],
self
.
x_type
)
k
=
k
+
self
.
k_proj_bias_tensor
v
=
paddle
.
nn
.
functional
.
linear
(
ln1_out
,
self
.
v_weight_tensor
)
v
=
paddle
.
cast
(
paddle
.
cast
(
v
,
'float32'
)
*
self
.
qkv_in_scales
[
i
]
/
self
.
qkv_out_scales
[
i
],
self
.
x_type
)
v
=
v
+
self
.
v_proj_bias_tensor
k
=
tensor
.
reshape
(
x
=
k
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
k_out
=
tensor
.
transpose
(
x
=
k
,
perm
=
[
0
,
2
,
1
,
3
])
v
=
tensor
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
v_out
=
tensor
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
if
self
.
has_cache_kv
:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k
,
cache_v
=
paddle
.
split
(
cache_kv
,
2
)
cache_k
=
paddle
.
squeeze
(
cache_k
,
axis
=
0
)
cache_v
=
paddle
.
squeeze
(
cache_v
,
axis
=
0
)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
if
self
.
debug
:
print
(
'q out is'
)
print
(
q_out
[
0
,
0
,
:,
:])
print
(
'cache k out seq=128'
)
print
(
k_out
[
0
,
0
,
:,
:])
if
self
.
gen_cache_kv
:
cache_kvs
.
append
((
k_out
,
v_out
))
else
:
k_out
=
paddle
.
concat
([
cache_k
,
k_out
],
axis
=-
2
)
v_out
=
paddle
.
concat
([
cache_v
,
v_out
],
axis
=-
2
)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out
=
layers
.
matmul
(
x
=
q_out
,
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
if
self
.
debug
:
print
(
'qk out is'
)
print
(
qk_out
[
0
][
0
][
0
])
if
attn_mask
is
not
None
:
attn_mask
=
_convert_attention_mask
(
attn_mask
,
qk_out
.
dtype
)
attn_mask_out
=
qk_out
+
attn_mask
if
self
.
debug
:
print
(
'attn mask out is'
)
print
(
attn_mask_out
[
0
][
0
][
0
])
softmax_out
=
F
.
softmax
(
attn_mask_out
)
else
:
softmax_out
=
F
.
softmax
(
qk_out
)
if
self
.
debug
:
print
(
'softmax out is'
)
print
(
softmax_out
[
0
][
0
][
0
])
if
self
.
dropout_prob
:
dropout_out
=
F
.
dropout
(
softmax_out
,
self
.
dropout_prob
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out
=
tensor
.
matmul
(
dropout_out
,
v_out
)
else
:
qktv_out
=
tensor
.
matmul
(
softmax_out
,
v_out
)
fmha_out
=
tensor
.
transpose
(
qktv_out
,
perm
=
[
0
,
2
,
1
,
3
])
if
self
.
debug
:
print
(
'fmha out is'
)
print
(
fmha_out
[
0
][
0
][
0
])
out_linear_in
=
tensor
.
reshape
(
x
=
fmha_out
,
shape
=
[
0
,
0
,
fmha_out
.
shape
[
2
]
*
fmha_out
.
shape
[
3
]])
max_v
=
paddle
.
max
(
paddle
.
abs
(
paddle
.
cast
(
out_linear_in
,
'float32'
)))[
0
]
# self.out_linear_in_scales.append(127.0 / max_v)
self
.
out_linear_in_scales
.
append
(
max_v
)
self
.
out_linear_out_scales
.
append
((
127.0
*
127.0
))
out_linear_in
=
self
.
fake_quant
(
out_linear_in
,
self
.
out_linear_in_scales
[
i
])
out
=
paddle
.
nn
.
functional
.
linear
(
out_linear_in
,
self
.
out_weight_tensor
)
out
=
paddle
.
cast
(
paddle
.
cast
(
out
,
'float32'
)
*
self
.
out_linear_in_scales
[
i
]
/
self
.
out_linear_out_scales
[
i
],
self
.
x_type
)
out
=
out
+
self
.
out_linear_proj_bias_tensor
residual_out
=
residual
+
self
.
dropout
(
out
)
if
not
self
.
pre_layer_norm
:
attn_out
=
self
.
norm
(
residual_out
)
else
:
attn_out
=
residual_out
ffn_ln_out
=
attn_out
if
self
.
pre_layer_norm
:
ffn_ln_out
=
self
.
ffn_norm
(
attn_out
)
max_v
=
paddle
.
max
(
paddle
.
abs
(
paddle
.
cast
(
ffn_ln_out
,
'float32'
)))[
0
]
self
.
ffn1_in_scales
.
append
(
max_v
)
self
.
ffn1_out_scales
.
append
((
127.0
*
127.0
))
ffn_ln_out
=
self
.
fake_quant
(
ffn_ln_out
,
self
.
ffn1_in_scales
[
i
])
ffn1_out
=
paddle
.
nn
.
functional
.
linear
(
ffn_ln_out
,
self
.
ffn1_weight_tensor
)
ffn1_out
=
paddle
.
cast
(
paddle
.
cast
(
ffn1_out
,
'float32'
)
*
self
.
ffn1_in_scales
[
i
]
/
self
.
ffn1_out_scales
[
i
],
self
.
x_type
)
ffn1_out
=
ffn1_out
+
self
.
ffn1_proj_bias_tensor
ffn1_out
=
self
.
dropout
(
self
.
activation
(
ffn1_out
))
max_v
=
paddle
.
max
(
paddle
.
abs
(
paddle
.
cast
(
ffn1_out
,
'float32'
)))[
0
]
# self.ffn2_in_scales.append(127.0 / max_v)
self
.
ffn2_in_scales
.
append
(
max_v
)
self
.
ffn2_out_scales
.
append
((
127.0
*
127.0
))
# print('ffn2_in_scales ', i, self.ffn2_in_scales[i])
ffn1_out
=
self
.
fake_quant
(
ffn1_out
,
self
.
ffn2_in_scales
[
i
])
ffn2_out
=
paddle
.
nn
.
functional
.
linear
(
ffn1_out
,
self
.
ffn2_weight_tensor
)
ffn2_out
=
paddle
.
cast
(
paddle
.
cast
(
ffn2_out
,
'float32'
)
*
self
.
ffn2_in_scales
[
i
]
/
self
.
ffn2_out_scales
[
i
],
self
.
x_type
)
ffn2_out
=
ffn2_out
+
self
.
ffn2_proj_bias_tensor
residual_out
=
attn_out
+
self
.
dropout
(
ffn2_out
)
# print("residual ", attn_out)
# print("residual_out ", residual_out)
final_out
=
residual_out
if
not
self
.
pre_layer_norm
:
final_out
=
self
.
ffn_norm
(
residual_out
)
tensor_query
=
final_out
if
self
.
has_cache_kv
and
self
.
gen_cache_kv
:
return
final_out
,
cache_kvs
return
final_out
def
GetFusedMultiTransformerOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
ln_scale
=
paddle
.
ones
([
self
.
embed_dim
],
'float32'
)
ln_bias
=
paddle
.
zeros
([
self
.
embed_dim
],
'float32'
)
ffn_ln_scale
=
ln_scale
ffn_ln_bias
=
ln_bias
q_proj_weight
=
self
.
q_weight_tensor
.
numpy
().
transpose
((
1
,
0
))
k_proj_weight
=
self
.
k_weight_tensor
.
numpy
().
transpose
((
1
,
0
))
v_proj_weight
=
self
.
v_weight_tensor
.
numpy
().
transpose
((
1
,
0
))
qkv_weight
=
np
.
concatenate
(
(
q_proj_weight
,
k_proj_weight
,
v_proj_weight
))
qkv_weight
=
qkv_weight
.
reshape
(
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
qkv_weight_tensor
=
paddle
.
to_tensor
(
qkv_weight
)
qkv_weight_tensor
=
paddle
.
cast
(
qkv_weight_tensor
,
'int8'
)
out_weight_tensor
=
paddle
.
cast
(
paddle
.
to_tensor
(
self
.
out_weight_tensor
.
numpy
().
transpose
((
1
,
0
))),
'int8'
)
ffn1_weight_tensor
=
paddle
.
cast
(
paddle
.
to_tensor
(
self
.
ffn1_weight_tensor
.
numpy
().
transpose
((
1
,
0
))),
'int8'
)
ffn2_weight_tensor
=
paddle
.
cast
(
paddle
.
to_tensor
(
self
.
ffn2_weight_tensor
.
numpy
().
transpose
((
1
,
0
))),
'int8'
)
qkv_bias
=
np
.
concatenate
(
(
self
.
q_proj_bias_tensor
.
numpy
(),
self
.
k_proj_bias_tensor
.
numpy
(),
self
.
v_proj_bias_tensor
.
numpy
()))
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
)
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
True
)
cache_kvs
,
cache_kv
=
None
,
None
time_step
=
None
if
self
.
has_cache_kv
:
cache_kvs
=
[]
max_seq_length
=
(
self
.
cache_length
+
128
)
//
128
*
128
cache_kv
=
np
.
zeros
([
2
,
self
.
batch_size
,
self
.
num_heads
,
max_seq_length
,
self
.
head_dim
],
dtype
=
self
.
x_type
)
elems
=
4
if
self
.
x_type
is
np
.
float16
:
elems
=
8
assert
self
.
head_dim
%
elems
==
0
v_elems
=
self
.
head_dim
//
elems
# [B, num_head, 128, head_dim]
# cache_k_tmp = self.cache_kv[0, :]
# [B, num_head, 128, head_dim / 4, 4]
cache_k_tmp
=
self
.
cache_kv
[
0
].
reshape
([
self
.
batch_size
,
self
.
num_heads
,
self
.
cache_length
,
v_elems
,
elems
])
# [B, num_head, head_dim / 4, 128, 4]
cache_k_tmp
=
cache_k_tmp
.
transpose
([
0
,
1
,
3
,
2
,
4
])
cache_kv
[
0
,
:].
reshape
([
self
.
batch_size
,
self
.
num_heads
,
v_elems
,
max_seq_length
,
elems
])[:,
:,
:,
:
self
.
cache_length
,
:]
=
cache_k_tmp
cache_kv
[
1
,
:,
:,
:
self
.
cache_length
,
:]
=
self
.
cache_kv
[
1
]
if
self
.
gen_cache_kv
:
assert
self
.
query_length
==
self
.
cache_length
cache_kv
[:]
=
0
else
:
time_step
=
paddle
.
to_tensor
([
self
.
cache_length
],
dtype
=
'int32'
,
place
=
paddle
.
CPUPlace
())
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
True
)
else
:
attn_mask
=
None
epsilon
=
1e-05
ln2_epsilon
=
1e-05
if
attn_mask
is
not
None
and
self
.
attn_mask_type
!=
np
.
bool_
:
attn_mask
=
_convert_attention_mask
(
attn_mask
,
x
.
dtype
)
qkv_weights
,
qkv_biases
=
[],
[]
out_weights
,
out_biases
=
[],
[]
ln_scales
,
ln_biases
=
[],
[]
ffn1_weights
,
ffn1_biases
=
[],
[]
ffn2_weights
,
ffn2_biases
=
[],
[]
ffn_ln_scales
,
ffn_ln_biases
=
[],
[]
qkv_in_scale
=
[]
out_linear_in_scale
=
[]
ffn1_in_scale
=
[]
ffn2_in_scale
=
[]
qkv_out_scales_tensor
=
paddle
.
ones
([
self
.
layers
,
3
*
self
.
embed_dim
],
'float32'
)
out_linear_out_scales_tensor
=
paddle
.
ones
(
[
self
.
layers
,
self
.
embed_dim
],
'float32'
)
ffn1_out_scales_tensor
=
paddle
.
ones
([
self
.
layers
,
4
*
self
.
embed_dim
],
'float32'
)
ffn2_out_scales_tensor
=
paddle
.
ones
([
self
.
layers
,
self
.
embed_dim
],
'float32'
)
for
i
in
range
(
self
.
layers
):
qkv_weights
.
append
(
qkv_weight_tensor
)
qkv_biases
.
append
(
qkv_bias_tensor
)
out_weights
.
append
(
out_weight_tensor
)
out_biases
.
append
(
self
.
out_linear_proj_bias_tensor
)
ln_scales
.
append
(
ln_scale
)
ln_biases
.
append
(
ln_bias
)
ffn1_weights
.
append
(
ffn1_weight_tensor
)
ffn1_biases
.
append
(
self
.
ffn1_proj_bias_tensor
)
ffn2_weights
.
append
(
ffn2_weight_tensor
)
ffn2_biases
.
append
(
self
.
ffn2_proj_bias_tensor
)
ffn_ln_scales
.
append
(
ffn_ln_scale
)
ffn_ln_biases
.
append
(
ffn_ln_bias
)
qkv_in_scale
.
append
(
self
.
qkv_in_scales
[
i
])
out_linear_in_scale
.
append
(
self
.
out_linear_in_scales
[
i
])
ffn1_in_scale
.
append
(
self
.
ffn1_in_scales
[
i
])
ffn2_in_scale
.
append
(
self
.
ffn2_in_scales
[
i
])
qkv_out_scales_tensor
[
i
,
:]
*=
self
.
qkv_out_scales
[
i
]
out_linear_out_scales_tensor
[
i
,
:]
*=
self
.
out_linear_out_scales
[
i
]
ffn1_out_scales_tensor
[
i
,
:]
*=
self
.
ffn1_out_scales
[
i
]
ffn2_out_scales_tensor
[
i
,
:]
*=
self
.
ffn2_out_scales
[
i
]
if
self
.
has_cache_kv
:
cache_kvs
.
append
(
paddle
.
to_tensor
(
cache_kv
,
stop_gradient
=
True
))
final_out
=
fused_multi_transformer_int8
(
x
,
ln_scales
,
ln_biases
,
qkv_weights
,
qkv_biases
,
out_weights
,
out_biases
,
ffn_ln_scales
,
ffn_ln_biases
,
ffn1_weights
,
ffn1_biases
,
ffn2_weights
,
ffn2_biases
,
pre_layer_norm
=
self
.
pre_layer_norm
,
epsilon
=
epsilon
,
cache_kvs
=
cache_kvs
,
time_step
=
time_step
,
attn_mask
=
attn_mask
,
dropout_rate
=
self
.
dropout_prob
,
training
=
self
.
training
,
mode
=
'upscale_in_train'
,
trans_qkvw
=
True
,
ring_id
=-
1
,
name
=
None
,
qkv_out_scales
=
qkv_out_scales_tensor
,
out_linear_out_scales
=
out_linear_out_scales_tensor
,
ffn1_out_scales
=
ffn1_out_scales_tensor
,
ffn2_out_scales
=
ffn2_out_scales_tensor
,
num_head
=
self
.
num_heads
,
dim_head
=
self
.
head_dim
,
dim_ffn
=
4
*
self
.
embed_dim
,
qkv_in_scale
=
qkv_in_scale
,
out_linear_in_scale
=
out_linear_in_scale
,
ffn1_in_scale
=
ffn1_in_scale
,
ffn2_in_scale
=
ffn2_in_scale
)
if
self
.
has_cache_kv
:
return
final_out
[
0
],
final_out
[
1
]
return
final_out
def
test_fused_multi_transformer_op
(
self
):
final_out_ref
=
self
.
GetBaselineOut
()
final_out
=
self
.
GetFusedMultiTransformerOut
()
if
self
.
has_cache_kv
:
final_out
,
cache_kv_out
=
final_out
s
=
cache_kv_out
[
0
].
shape
bsz
=
s
[
1
]
num_head
=
s
[
2
]
max_seq_len
=
s
[
3
]
head_dim
=
s
[
4
]
elems
=
8
if
self
.
x_type
is
np
.
float16
else
4
v_elems
=
head_dim
//
elems
if
self
.
debug
:
print
(
"cache_k out timestep=128"
)
print
(
cache_kv_out
[
0
].
reshape
(
[
2
,
bsz
,
num_head
,
v_elems
,
max_seq_len
,
elems
])[
0
,
0
,
0
,
:,
self
.
cache_length
,
:])
print
(
"cache_v out timestep=128"
)
print
(
cache_kv_out
[
0
][
1
,
0
,
0
,
self
.
cache_length
,
:])
if
self
.
gen_cache_kv
:
final_out_ref
,
cache_kvs
=
final_out_ref
for
i
in
range
(
self
.
layers
):
cache_k_ref
=
cache_kvs
[
i
][
0
]
cache_v_ref
=
cache_kvs
[
i
][
1
]
cache_k
=
cache_kv_out
[
i
][
0
,
:]
cache_k
=
cache_k
.
reshape
(
[
bsz
,
num_head
,
v_elems
,
max_seq_len
,
elems
])
cache_k
=
cache_k
[:,
:,
:,
:
self
.
cache_length
,
:]
cache_k
=
cache_k
.
transpose
([
0
,
1
,
3
,
2
,
4
])
cache_k
=
cache_k
.
reshape
(
[
bsz
,
num_head
,
self
.
cache_length
,
head_dim
])
cache_v
=
cache_kv_out
[
i
][
1
,
:,
:,
:
self
.
cache_length
,
:]
np
.
testing
.
assert_allclose
(
cache_k_ref
,
cache_k
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
np
.
testing
.
assert_allclose
(
cache_v_ref
,
cache_v
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
if
i
==
0
:
break
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
class
TestFusedMultiTransformerInt8OpFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerInt8OpCacheKV
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
super
().
generate_input_data
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerInt8OpCacheKVFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
x_type
=
np
.
float16
class
TestFusedMultiTransformerInt8OpGenCacheKV
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
class
TestFusedMultiTransformerInt8OpGenCacheKVFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerInt8OpPostLayerNormFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerInt8OpCacheKVPostLayerNorm
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerInt8OpCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
x_type
=
np
.
float16
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNorm
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerInt8Op
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录