Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c5f4a9cc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c5f4a9cc
编写于
8月 09, 2022
作者:
C
carryyu
提交者:
GitHub
8月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add post layer norm (#44931)
上级
9336dd3e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
397 addition
and
192 deletion
+397
-192
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+230
-84
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
.../fluid/tests/unittests/test_fused_multi_transformer_op.py
+167
-108
未找到文件。
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
c5f4a9cc
...
@@ -530,7 +530,10 @@ inline __device__ void zero(T &dst) { // NOLINT
...
@@ -530,7 +530,10 @@ inline __device__ void zero(T &dst) { // NOLINT
dst
=
tmp
.
raw
;
dst
=
tmp
.
raw
;
}
}
template
<
typename
T
,
int
Dh
,
int
THREADS_PER_KEY
,
int
THREADS_PER_VALUE
,
template
<
typename
T
,
int
Dh
,
int
THREADS_PER_KEY
,
int
THREADS_PER_VALUE
,
int
THREADS_PER_BLOCK
>
int
THREADS_PER_BLOCK
>
__global__
void
masked_multihead_attention_kernel
(
__global__
void
masked_multihead_attention_kernel
(
Masked_multihead_attention_params
<
T
>
params
)
{
Masked_multihead_attention_params
<
T
>
params
)
{
...
@@ -830,8 +833,10 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -830,8 +833,10 @@ __global__ void masked_multihead_attention_kernel(
template
<
typename
T
>
template
<
typename
T
>
inline
size_t
smem_size_in_bytes
(
inline
size_t
smem_size_in_bytes
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
dim_head
,
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
threads_per_value
,
int
threads_per_block
)
{
int
dim_head
,
int
threads_per_value
,
int
threads_per_block
)
{
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
logits_sz
=
0
;
size_t
logits_sz
=
0
;
...
@@ -848,14 +853,17 @@ inline size_t smem_size_in_bytes(
...
@@ -848,14 +853,17 @@ inline size_t smem_size_in_bytes(
return
max
(
softmax_sz
,
red_sz
);
return
max
(
softmax_sz
,
red_sz
);
}
}
#define MMHA_LAUNCH_KERNEL(
T, Dh, THDS_PER_KEY, THDS_PER_VALUE,
\
#define MMHA_LAUNCH_KERNEL(
\
THDS_PER_BLOCK, stream)
\
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream)
\
size_t smem_sz = \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel< \
masked_multihead_attention_kernel<T, \
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
Dh, \
THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template
<
typename
T
,
int
Dh
>
template
<
typename
T
,
int
Dh
>
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
...
@@ -871,10 +879,17 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> ¶ms,
...
@@ -871,10 +879,17 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> ¶ms,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
fmha
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
Tensor
&
qkv_tensor
,
void
fmha
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
Tensor
&
qkv_bias_tensor
,
const
Tensor
&
src_mask_tensor
,
const
Tensor
&
qkv_tensor
,
Tensor
*
cache_kv_tensor
,
Tensor
*
out_tensor
,
int
batch_size
,
const
Tensor
&
qkv_bias_tensor
,
int
max_seq_length
,
int
num_head
,
int
dim_head
,
int
timestep
,
const
Tensor
&
src_mask_tensor
,
Tensor
*
cache_kv_tensor
,
Tensor
*
out_tensor
,
int
batch_size
,
int
max_seq_length
,
int
num_head
,
int
dim_head
,
int
timestep
,
float
inv_sqrt_dh
)
{
float
inv_sqrt_dh
)
{
Masked_multihead_attention_params
<
T
>
params
;
Masked_multihead_attention_params
<
T
>
params
;
params
.
out
=
out_tensor
->
data
<
T
>
();
params
.
out
=
out_tensor
->
data
<
T
>
();
...
@@ -911,8 +926,11 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
...
@@ -911,8 +926,11 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
constexpr
int
VEC_16B
=
16
;
constexpr
int
VEC_16B
=
16
;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
write_cache_k_kernel
(
T
*
cache_k
,
const
T
*
k
,
const
int
num_head
,
__global__
void
write_cache_k_kernel
(
T
*
cache_k
,
const
int
dim_head
,
const
int
seq_len
,
const
T
*
k
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
const
int
hi
=
blockIdx
.
z
;
...
@@ -946,8 +964,11 @@ __global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head,
...
@@ -946,8 +964,11 @@ __global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
write_cache_v_kernel
(
T
*
cache_v
,
const
T
*
v
,
const
int
num_head
,
__global__
void
write_cache_v_kernel
(
T
*
cache_v
,
const
int
dim_head
,
const
int
seq_len
,
const
T
*
v
,
const
int
num_head
,
const
int
dim_head
,
const
int
seq_len
,
const
int
max_seq_len
)
{
const
int
max_seq_len
)
{
const
int
bi
=
blockIdx
.
y
;
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
z
;
const
int
hi
=
blockIdx
.
z
;
...
@@ -970,16 +991,23 @@ __global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head,
...
@@ -970,16 +991,23 @@ __global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
write_cache_kv
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
T
*
cache_k
,
void
write_cache_kv
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
T
*
cache_v
,
const
T
*
k
,
const
T
*
v
,
const
int
bsz
,
T
*
cache_k
,
const
int
num_head
,
const
int
seq_len
,
T
*
cache_v
,
const
int
max_seq_len
,
const
int
dim_head
)
{
const
T
*
k
,
const
T
*
v
,
const
int
bsz
,
const
int
num_head
,
const
int
seq_len
,
const
int
max_seq_len
,
const
int
dim_head
)
{
constexpr
int
block_sz
=
128
;
constexpr
int
block_sz
=
128
;
constexpr
int
x
=
VEC_16B
/
sizeof
(
T
);
constexpr
int
x
=
VEC_16B
/
sizeof
(
T
);
assert
(
dim_head
%
x
==
0
);
assert
(
dim_head
%
x
==
0
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dim_head
%
x
,
0
,
dim_head
%
x
,
0
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"dim_head=%d must be divisible by vec_size=%d"
,
dim_head
,
x
));
"dim_head=%d must be divisible by vec_size=%d"
,
dim_head
,
x
));
...
@@ -1043,15 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1043,15 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
bool
compute_bias
=
qkv_biases
.
size
()
>
0
&&
time_step
==
nullptr
;
bool
compute_bias
=
qkv_biases
.
size
()
>
0
&&
time_step
==
nullptr
;
// (transA, transB, compute_bias) = (false, true, false)
// (transA, transB, compute_bias) = (false, true, false)
auto
qkv_compute
=
AttnMatMul
<
T
>
(
dev_ctx
,
false
,
true
,
bsz_seq
,
output_size
,
auto
qkv_compute
=
AttnMatMul
<
T
>
(
input_size
,
compute_bias
);
dev_ctx
,
false
,
true
,
bsz_seq
,
output_size
,
input_size
,
compute_bias
);
Tensor
qkv_out
;
Tensor
qkv_out
;
auto
*
qkv_out_data
=
auto
*
qkv_out_data
=
qkv_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
3
,
num_head
,
dim_head
},
place
);
qkv_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
3
,
num_head
,
dim_head
},
place
);
// 3. fmha
// 3. fmha
AttnDropoutParam
attn_param
(
true
,
"upscale_in_train"
,
0.0
,
true
,
true
,
0
,
AttnDropoutParam
attn_param
(
nullptr
);
true
,
"upscale_in_train"
,
0.0
,
true
,
true
,
0
,
nullptr
);
auto
fmha_compute
=
auto
fmha_compute
=
FMHARef
<
T
>
(
dev_ctx
,
bsz
,
seq_len
,
num_head
,
dim_head
,
attn_param
);
FMHARef
<
T
>
(
dev_ctx
,
bsz
,
seq_len
,
num_head
,
dim_head
,
attn_param
);
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
...
@@ -1061,17 +1089,20 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1061,17 +1089,20 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
out_seq_len
=
seq_len
;
auto
out_seq_len
=
seq_len
;
if
(
time_step
)
{
if
(
time_step
)
{
PADDLE_ENFORCE_EQ
(
time_step
->
place
(),
platform
::
CPUPlace
(),
PADDLE_ENFORCE_EQ
(
time_step
->
place
(),
platform
::
CPUPlace
(),
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The place of input(TimeStep) must be CPUPlace."
));
"The place of input(TimeStep) must be CPUPlace."
));
// cache_seq_len
// cache_seq_len
int
time_step_value
=
time_step
->
data
<
int
>
()[
0
];
int
time_step_value
=
time_step
->
data
<
int
>
()[
0
];
PADDLE_ENFORCE_GT
(
time_step_value
,
0
,
PADDLE_ENFORCE_GT
(
time_step_value
,
0
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The value of time_step must > 0, but now is %d"
,
"The value of time_step must > 0, but now is %d"
,
time_step_value
));
time_step_value
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
seq_len
,
1
,
seq_len
,
1
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"In decode stage, the seq_len of input must be 1, but now is %d"
,
"In decode stage, the seq_len of input must be 1, but now is %d"
,
seq_len
));
seq_len
));
...
@@ -1107,8 +1138,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1107,8 +1138,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
out_linear_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"OutLinearBias"
);
auto
out_linear_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"OutLinearBias"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// (transA, transB, compute_bias) = (false, false, false)
// (transA, transB, compute_bias) = (false, false, false)
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
dev_ctx
,
false
,
false
,
bsz_seq
,
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
dim_embed
,
hidden_size
,
false
);
dev_ctx
,
false
,
false
,
bsz_seq
,
dim_embed
,
hidden_size
,
false
);
// 5. ln(residual + bias)
// 5. ln(residual + bias)
DropoutParam
dropout_param2
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
DropoutParam
dropout_param2
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
...
@@ -1117,9 +1148,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1117,9 +1148,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
auto
*
bias_dropout_residual_out_data
=
T
*
bias_dropout_residual_out_data
=
nullptr
;
if
(
pre_layer_norm
)
{
bias_dropout_residual_out_data
=
bias_dropout_residual_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
dim_embed
},
bias_dropout_residual_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
dim_embed
},
place
);
place
);
}
auto
*
dropout_mask_out_data
=
dropout_mask_out
.
mutable_data
<
uint8_t
>
(
auto
*
dropout_mask_out_data
=
dropout_mask_out
.
mutable_data
<
uint8_t
>
(
{
bsz
,
seq_len
,
dim_embed
},
place
);
{
bsz
,
seq_len
,
dim_embed
},
place
);
...
@@ -1129,8 +1163,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1129,8 +1163,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
ffn1_weight_dim
=
ffn1_weights
[
0
]
->
dims
();
auto
ffn1_weight_dim
=
ffn1_weights
[
0
]
->
dims
();
int
dim_ffn
=
ffn1_weight_dim
[
1
];
int
dim_ffn
=
ffn1_weight_dim
[
1
];
auto
ffn1_linear_compute
=
AttnMatMul
<
T
>
(
dev_ctx
,
false
,
false
,
bsz_seq
,
auto
ffn1_linear_compute
=
AttnMatMul
<
T
>
(
dim_ffn
,
dim_embed
,
false
);
dev_ctx
,
false
,
false
,
bsz_seq
,
dim_ffn
,
dim_embed
,
false
);
Tensor
ffn1_out
;
Tensor
ffn1_out
;
auto
*
ffn1_out_data
=
ffn1_out
.
mutable_data
<
T
>
({
bsz_seq
,
dim_ffn
},
place
);
auto
*
ffn1_out_data
=
ffn1_out
.
mutable_data
<
T
>
({
bsz_seq
,
dim_ffn
},
place
);
...
@@ -1147,8 +1181,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1147,8 +1181,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// 8. ffn2 matmul
// 8. ffn2 matmul
auto
ffn2_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Weight"
);
auto
ffn2_weights
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Weight"
);
auto
ffn2_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Bias"
);
auto
ffn2_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFN2Bias"
);
auto
ffn2_linear_compute
=
AttnMatMul
<
T
>
(
dev_ctx
,
false
,
false
,
bsz_seq
,
auto
ffn2_linear_compute
=
AttnMatMul
<
T
>
(
dim_embed
,
dim_ffn
,
false
);
dev_ctx
,
false
,
false
,
bsz_seq
,
dim_embed
,
dim_ffn
,
false
);
// 9. ffn2 residual bias
// 9. ffn2 residual bias
DropoutParam
ffn2_dropout_param
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
DropoutParam
ffn2_dropout_param
(
true
,
0
,
true
,
true
,
0.0
,
nullptr
,
0
);
...
@@ -1171,6 +1205,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1171,6 +1205,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step1: buf1 --> buf0
// step2: buf0 --> buf1
// step2: buf0 --> buf1
int
layers
=
qkv_weights
.
size
();
int
layers
=
qkv_weights
.
size
();
if
(
pre_layer_norm
)
{
if
(
layers
&
1
)
{
if
(
layers
&
1
)
{
// odd, set buf1 as out
// odd, set buf1 as out
buf0
=
&
tmp_out
;
buf0
=
&
tmp_out
;
...
@@ -1180,6 +1215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1180,6 +1215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
buf0
=
out
;
buf0
=
out
;
buf1
=
&
tmp_out
;
buf1
=
&
tmp_out
;
}
}
}
else
{
buf0
=
&
tmp_out
;
buf1
=
out
;
}
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
// step1. layer_norm
// step1. layer_norm
...
@@ -1187,11 +1226,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1187,11 +1226,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
// TODO(wangxi): can remove mean var in inference
// TODO(wangxi): can remove mean var in inference
ln_compute
.
ComputeForward
(
x_data
,
ln_scale_data
,
ln_bias_data
,
ln_compute
.
ComputeForward
(
x_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
ln_scale_data
,
}
else
if
(
!
pre_layer_norm
)
{
ln_bias_data
,
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
buf1
->
data
<
T
>
(),
"Unimplemented post_layer_norm for now."
));
ln_mean_data
,
ln_var_data
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step1"
;
VLOG
(
0
)
<<
"step1"
;
...
@@ -1201,8 +1241,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1201,8 +1241,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
// NOTE: in decoder stage, bias is fused in fmha
// NOTE: in decoder stage, bias is fused in fmha
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
buf1
,
bias
,
&
qkv_out
,
if
(
!
pre_layer_norm
&&
i
==
0
)
{
&
qkv_out
);
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
input_x
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
else
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
buf1
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step2"
;
VLOG
(
0
)
<<
"step2"
;
#endif
#endif
...
@@ -1214,15 +1259,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1214,15 +1259,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
if
(
time_step
)
{
// generation decoder stage
if
(
time_step
)
{
// generation decoder stage
// [2, batch_size, num_head, max_seq_len, head_size]
// [2, batch_size, num_head, max_seq_len, head_size]
int
max_seq_len
=
cache_kv
->
dims
()[
3
];
int
max_seq_len
=
cache_kv
->
dims
()[
3
];
fmha
<
T
>
(
dev_ctx
,
qkv_out
,
*
qkv_bias
,
*
src_mask
,
cache_kv_out
,
&
fmha_out
,
fmha
<
T
>
(
dev_ctx
,
bsz
,
max_seq_len
,
num_head
,
dim_head
,
time_step
->
data
<
int
>
()[
0
],
qkv_out
,
*
qkv_bias
,
*
src_mask
,
cache_kv_out
,
&
fmha_out
,
bsz
,
max_seq_len
,
num_head
,
dim_head
,
time_step
->
data
<
int
>
()[
0
],
1.
/
sqrt
(
dim_head
));
1.
/
sqrt
(
dim_head
));
}
else
if
(
cache_kv_out
)
{
// generation context stage
}
else
if
(
cache_kv_out
)
{
// generation context stage
// TODO(wangxi): can remove dropout in inference
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
fmha_compute
.
ComputeForward
(
qkv_out
,
qkv_out
,
nullptr
,
src_mask
,
&
transpose_out_2
,
nullptr
,
&
qk_out
,
nullptr
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
src_mask
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
&
transpose_out_2
,
nullptr
,
&
qk_out
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
// [3, bsz, num_head, seq_len, head_dim]
// [3, bsz, num_head, seq_len, head_dim]
T
*
qkv_data
=
transpose_out_2_data
;
T
*
qkv_data
=
transpose_out_2_data
;
int64_t
q_size
=
bsz
*
seq_len
*
num_head
*
dim_head
;
int64_t
q_size
=
bsz
*
seq_len
*
num_head
*
dim_head
;
...
@@ -1239,23 +1301,45 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1239,23 +1301,45 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
T
*
cache_k_ptr
=
cache_kv_data
;
T
*
cache_k_ptr
=
cache_kv_data
;
T
*
cache_v_ptr
=
cache_kv_data
+
cache_k_size
;
T
*
cache_v_ptr
=
cache_kv_data
+
cache_k_size
;
write_cache_kv
<
T
>
(
dev_ctx
,
cache_k_ptr
,
cache_v_ptr
,
k_ptr
,
v_ptr
,
bsz
,
write_cache_kv
<
T
>
(
dev_ctx
,
num_head
,
seq_len
,
max_seq_len
,
dim_head
);
cache_k_ptr
,
cache_v_ptr
,
k_ptr
,
v_ptr
,
bsz
,
num_head
,
seq_len
,
max_seq_len
,
dim_head
);
}
else
{
// not generation
}
else
{
// not generation
// TODO(wangxi): can remove dropout in inference
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
fmha_compute
.
ComputeForward
(
qkv_out
,
qkv_out
,
cache_kv
,
src_mask
,
&
transpose_out_2
,
cache_kv_out
,
cache_kv
,
&
qk_out
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
src_mask
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
&
transpose_out_2
,
cache_kv_out
,
&
qk_out
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step3"
;
VLOG
(
0
)
<<
"step3"
;
#endif
#endif
// step4. out_linear
// step4. out_linear
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
if
(
pre_layer_norm
)
{
nullptr
,
buf1
,
nullptr
);
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf1
,
nullptr
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf0
,
nullptr
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step4"
;
VLOG
(
0
)
<<
"step4"
;
#endif
#endif
...
@@ -1268,39 +1352,75 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1268,39 +1352,75 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// inplace
// inplace
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf1
->
data
<
T
>
(),
x_data
,
out_linear_bias_data
,
dev_ctx
,
ln_scale_data
,
ln_bias_data
,
bias_dropout_residual_out_data
,
buf1
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
x_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
bias_dropout_residual_out_data
,
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
else
{
}
else
{
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
auto
*
out_linear_bias_data
=
out_linear_biases
[
i
]
->
data
<
T
>
();
auto
*
residual_data
=
(
i
==
0
?
x_data
:
buf1
->
data
<
T
>
());
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
residual_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step5"
;
VLOG
(
0
)
<<
"step5"
;
#endif
#endif
// step6. ffn matmul1
// step6. ffn matmul1
ffn1_linear_compute
.
ComputeForward
(
ffn1_weights
[
i
],
buf1
,
nullptr
,
ffn1_linear_compute
.
ComputeForward
(
&
ffn1_out
,
nullptr
);
ffn1_weights
[
i
],
buf1
,
nullptr
,
&
ffn1_out
,
nullptr
);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step6"
;
VLOG
(
0
)
<<
"step6"
;
#endif
#endif
// step7. act bias
// step7. act bias
// TODO(wangxi): remove dropout mask in inference
// TODO(wangxi): remove dropout mask in inference
fused_act_dropout_helper
.
DropoutActBias
(
fused_act_dropout_helper
.
DropoutActBias
(
dev_ctx
,
dev_ctx
,
ffn1_out_data
,
ffn1_biases
[
i
]
->
data
<
T
>
(),
"gelu"
,
ffn1_out_data
,
ffn1_dropout_out_data
,
ffn1_dropout_mask_data
);
ffn1_biases
[
i
]
->
data
<
T
>
(),
"gelu"
,
ffn1_dropout_out_data
,
ffn1_dropout_mask_data
);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step7"
;
VLOG
(
0
)
<<
"step7"
;
#endif
#endif
// step8. ffn matmul2
// step8. ffn matmul2
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
if
(
pre_layer_norm
)
{
nullptr
,
buf1
,
nullptr
);
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf1
,
nullptr
);
}
else
{
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf0
,
nullptr
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.0"
;
VLOG
(
0
)
<<
"step8.0"
;
#endif
#endif
if
(
pre_layer_norm
)
{
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.1"
;
VLOG
(
0
)
<<
"step8.1"
;
#endif
#endif
...
@@ -1312,25 +1432,51 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1312,25 +1432,51 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
*
ln_scale_data
=
ln_scales
[
i
+
1
]
->
data
<
U
>
();
auto
*
ln_scale_data
=
ln_scales
[
i
+
1
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
+
1
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
+
1
]
->
data
<
U
>
();
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf1
->
data
<
T
>
(),
bias_dropout_residual_out_data
,
dev_ctx
,
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf1
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
dropout_mask_out_data
,
buf0
->
data
<
T
>
(),
bias_dropout_residual_out_data
,
ln_mean_data
,
ln_var_data
);
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf1
->
data
<
T
>
(),
dropout_mask_out_data
,
buf0
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
else
{
}
else
{
ffn2_fused_dropout_helper
.
ResidualDropoutBias
(
ffn2_fused_dropout_helper
.
ResidualDropoutBias
(
dev_ctx
,
buf1
->
data
<
T
>
(),
bias_dropout_residual_out_data
,
dev_ctx
,
ffn2_biases
[
i
]
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
bias_dropout_residual_out_data
,
ffn2_biases
[
i
]
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
dropout_mask_out_data
);
dropout_mask_out_data
);
}
}
}
else
{
}
else
{
auto
*
ln_scale_data
=
ffn_ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ffn_ln_biases
[
i
]
->
data
<
U
>
();
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step9"
;
VLOG
(
0
)
<<
"step9"
;
#endif
#endif
if
(
pre_layer_norm
)
{
x_data
=
buf1
->
data
<
T
>
();
x_data
=
buf1
->
data
<
T
>
();
std
::
swap
(
buf0
,
buf1
);
std
::
swap
(
buf0
,
buf1
);
}
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
浏览文件 @
c5f4a9cc
...
@@ -39,6 +39,7 @@ default_main_program().random_seed = 42
...
@@ -39,6 +39,7 @@ default_main_program().random_seed = 42
class
TestFusedMultiTransformerOp
(
OpTest
):
class
TestFusedMultiTransformerOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
config
()
self
.
config
()
self
.
generate_input_data
()
self
.
generate_input_data
()
...
@@ -61,36 +62,30 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -61,36 +62,30 @@ class TestFusedMultiTransformerOp(OpTest):
bias_attr
=
paddle
.
fluid
.
ParamAttr
(
bias_attr
=
paddle
.
fluid
.
ParamAttr
(
initializer
=
paddle
.
fluid
.
initializer
.
Constant
(
value
=
0.0005
))
initializer
=
paddle
.
fluid
.
initializer
.
Constant
(
value
=
0.0005
))
self
.
q_proj
=
Linear
(
self
.
q_proj
=
Linear
(
self
.
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
bias_attr
)
bias_attr
=
bias_attr
)
#bias_attr=self.bias_attr)
#bias_attr=self.bias_attr)
self
.
k_proj
=
Linear
(
self
.
k_proj
=
Linear
(
self
.
kdim
,
self
.
kdim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
self
.
bias_attr
)
bias_attr
=
self
.
bias_attr
)
self
.
v_proj
=
Linear
(
self
.
v_proj
=
Linear
(
self
.
vdim
,
self
.
vdim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
self
.
bias_attr
)
bias_attr
=
self
.
bias_attr
)
self
.
out_proj
=
Linear
(
self
.
out_proj
=
Linear
(
self
.
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
self
.
bias_attr
)
bias_attr
=
self
.
bias_attr
)
self
.
ffn1_proj
=
Linear
(
self
.
ffn1_proj
=
Linear
(
self
.
embed_dim
,
self
.
embed_dim
,
4
*
self
.
embed_dim
,
4
*
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
self
.
bias_attr
)
bias_attr
=
self
.
bias_attr
)
self
.
ffn2_proj
=
Linear
(
self
.
ffn2_proj
=
Linear
(
4
*
self
.
embed_dim
,
4
*
self
.
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
self
.
weight_attr
,
self
.
weight_attr
,
bias_attr
=
self
.
bias_attr
)
bias_attr
=
self
.
bias_attr
)
...
@@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest):
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
# --> [B, n_head, seq_len, out_seq_len]
qk_out
=
layers
.
matmul
(
qk_out
=
layers
.
matmul
(
x
=
q_out
,
x
=
q_out
,
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
if
self
.
debug
:
if
self
.
debug
:
print
(
'qk out is'
)
print
(
'qk out is'
)
...
@@ -249,8 +246,7 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -249,8 +246,7 @@ class TestFusedMultiTransformerOp(OpTest):
print
(
'softmax out is'
)
print
(
'softmax out is'
)
print
(
softmax_out
[
0
][
0
][
0
])
print
(
softmax_out
[
0
][
0
][
0
])
if
self
.
dropout_prob
:
if
self
.
dropout_prob
:
dropout_out
=
F
.
dropout
(
dropout_out
=
F
.
dropout
(
softmax_out
,
softmax_out
,
self
.
dropout_prob
,
self
.
dropout_prob
,
training
=
self
.
training
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
mode
=
"upscale_in_train"
)
...
@@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest):
print
(
'fmha out is'
)
print
(
'fmha out is'
)
print
(
fmha_out
[
0
][
0
][
0
])
print
(
fmha_out
[
0
][
0
][
0
])
out_linear_in
=
tensor
.
reshape
(
out_linear_in
=
tensor
.
reshape
(
x
=
fmha_out
,
x
=
fmha_out
,
shape
=
[
0
,
0
,
fmha_out
.
shape
[
2
]
*
fmha_out
.
shape
[
3
]])
shape
=
[
0
,
0
,
fmha_out
.
shape
[
2
]
*
fmha_out
.
shape
[
3
]])
out
=
self
.
out_proj
(
out_linear_in
)
out
=
self
.
out_proj
(
out_linear_in
)
residual_out
=
residual
+
self
.
dropout
(
out
)
residual_out
=
residual
+
self
.
dropout
(
out
)
...
@@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest):
def
GetFusedMultiTransformerOut
(
self
):
def
GetFusedMultiTransformerOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
q_proj_weight
=
paddle
.
to_tensor
(
q_proj_weight
=
paddle
.
to_tensor
(
self
.
q_proj
.
weight
,
self
.
q_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
k_proj_weight
=
paddle
.
to_tensor
(
k_proj_weight
=
paddle
.
to_tensor
(
self
.
k_proj
.
weight
,
self
.
k_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
v_proj_weight
=
paddle
.
to_tensor
(
v_proj_weight
=
paddle
.
to_tensor
(
self
.
v_proj
.
weight
,
self
.
v_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
out_linear_weight
=
paddle
.
to_tensor
(
out_linear_weight
=
paddle
.
to_tensor
(
self
.
out_proj
.
weight
,
self
.
out_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ffn1_weight
=
paddle
.
to_tensor
(
ffn1_weight
=
paddle
.
to_tensor
(
self
.
ffn1_proj
.
weight
,
self
.
ffn1_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ffn2_weight
=
paddle
.
to_tensor
(
ffn2_weight
=
paddle
.
to_tensor
(
self
.
ffn2_proj
.
weight
,
self
.
ffn2_proj
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
if
self
.
bias_attr
is
False
:
if
self
.
bias_attr
is
False
:
qkv_bias_tensor
=
None
qkv_bias_tensor
=
None
out_linear_bias
=
None
out_linear_bias
=
None
else
:
else
:
q_proj_bias
=
paddle
.
to_tensor
(
q_proj_bias
=
paddle
.
to_tensor
(
self
.
q_proj
.
bias
,
self
.
q_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
k_proj_bias
=
paddle
.
to_tensor
(
k_proj_bias
=
paddle
.
to_tensor
(
self
.
k_proj
.
bias
,
self
.
k_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
v_proj_bias
=
paddle
.
to_tensor
(
v_proj_bias
=
paddle
.
to_tensor
(
self
.
v_proj
.
bias
,
self
.
v_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
qkv_bias
=
np
.
concatenate
(
qkv_bias
=
np
.
concatenate
(
(
q_proj_bias
.
numpy
(),
k_proj_bias
.
numpy
(),
v_proj_bias
.
numpy
()))
(
q_proj_bias
.
numpy
(),
k_proj_bias
.
numpy
(),
v_proj_bias
.
numpy
()))
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
,
stop_gradient
=
False
)
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
,
stop_gradient
=
False
)
out_linear_bias
=
paddle
.
to_tensor
(
out_linear_bias
=
paddle
.
to_tensor
(
self
.
out_proj
.
bias
,
self
.
out_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ffn1_bias
=
paddle
.
to_tensor
(
ffn1_bias
=
paddle
.
to_tensor
(
self
.
ffn1_proj
.
bias
,
self
.
ffn1_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ffn2_bias
=
paddle
.
to_tensor
(
ffn2_bias
=
paddle
.
to_tensor
(
self
.
ffn2_proj
.
bias
,
self
.
ffn2_proj
.
bias
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ln_scale
=
paddle
.
to_tensor
(
self
.
norm
.
weight
,
stop_gradient
=
False
)
ln_scale
=
paddle
.
to_tensor
(
self
.
norm
.
weight
,
stop_gradient
=
False
)
ln_bias
=
paddle
.
to_tensor
(
self
.
norm
.
bias
,
stop_gradient
=
False
)
ln_bias
=
paddle
.
to_tensor
(
self
.
norm
.
bias
,
stop_gradient
=
False
)
ffn_ln_scale
=
paddle
.
to_tensor
(
ffn_ln_scale
=
paddle
.
to_tensor
(
self
.
ffn_norm
.
weight
,
self
.
ffn_norm
.
weight
,
stop_gradient
=
False
)
stop_gradient
=
False
)
ffn_ln_bias
=
paddle
.
to_tensor
(
self
.
ffn_norm
.
bias
,
stop_gradient
=
False
)
ffn_ln_bias
=
paddle
.
to_tensor
(
self
.
ffn_norm
.
bias
,
stop_gradient
=
False
)
q_proj_weight
=
q_proj_weight
.
numpy
().
transpose
((
1
,
0
))
q_proj_weight
=
q_proj_weight
.
numpy
().
transpose
((
1
,
0
))
...
@@ -351,8 +346,7 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -351,8 +346,7 @@ class TestFusedMultiTransformerOp(OpTest):
cache_kvs
=
[]
cache_kvs
=
[]
max_seq_length
=
(
self
.
cache_length
+
128
)
//
128
*
128
max_seq_length
=
(
self
.
cache_length
+
128
)
//
128
*
128
cache_kv
=
np
.
zeros
(
cache_kv
=
np
.
zeros
([
[
2
,
self
.
batch_size
,
self
.
num_heads
,
max_seq_length
,
2
,
self
.
batch_size
,
self
.
num_heads
,
max_seq_length
,
self
.
head_dim
self
.
head_dim
],
],
...
@@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest):
assert
self
.
query_length
==
self
.
cache_length
assert
self
.
query_length
==
self
.
cache_length
cache_kv
[:]
=
0
cache_kv
[:]
=
0
else
:
else
:
time_step
=
paddle
.
to_tensor
(
time_step
=
paddle
.
to_tensor
([
self
.
cache_length
],
[
self
.
cache_length
],
dtype
=
'int32'
,
place
=
paddle
.
CPUPlace
())
dtype
=
'int32'
,
place
=
paddle
.
CPUPlace
())
if
self
.
has_attn_mask
:
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
else
:
...
@@ -417,12 +412,10 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -417,12 +412,10 @@ class TestFusedMultiTransformerOp(OpTest):
ffn_ln_scales
.
append
(
ffn_ln_scale
)
ffn_ln_scales
.
append
(
ffn_ln_scale
)
ffn_ln_biases
.
append
(
ffn_ln_bias
)
ffn_ln_biases
.
append
(
ffn_ln_bias
)
if
self
.
has_cache_kv
:
if
self
.
has_cache_kv
:
cache_kvs
.
append
(
cache_kvs
.
append
(
paddle
.
to_tensor
(
cache_kv
,
paddle
.
to_tensor
(
stop_gradient
=
False
))
cache_kv
,
stop_gradient
=
False
))
final_out
=
fused_multi_transformer
(
final_out
=
fused_multi_transformer
(
x
,
x
,
ln_scales
,
ln_scales
,
ln_biases
,
ln_biases
,
qkv_weights
,
qkv_weights
,
...
@@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest):
if
self
.
debug
:
if
self
.
debug
:
print
(
"cache_k out timestep=128"
)
print
(
"cache_k out timestep=128"
)
print
(
cache_kv_out
[
0
].
reshape
(
[
print
(
cache_kv_out
[
0
].
reshape
(
2
,
bsz
,
num_head
,
v_elems
,
max_seq_len
,
elems
[
2
,
bsz
,
num_head
,
v_elems
,
max_seq_len
,
])[
0
,
0
,
0
,
:,
self
.
cache_length
,
:])
elems
])[
0
,
0
,
0
,
:,
self
.
cache_length
,
:])
print
(
"cache_v out timestep=128"
)
print
(
"cache_v out timestep=128"
)
print
(
cache_kv_out
[
0
][
1
,
0
,
0
,
self
.
cache_length
,
:])
print
(
cache_kv_out
[
0
][
1
,
0
,
0
,
self
.
cache_length
,
:])
...
@@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest):
...
@@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest):
cache_v
=
cache_kv_out
[
i
][
1
,
:,
:,
:
self
.
cache_length
,
:]
cache_v
=
cache_kv_out
[
i
][
1
,
:,
:,
:
self
.
cache_length
,
:]
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
cache_k_ref
,
cache_k_ref
,
cache_k
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
cache_k
,
np
.
testing
.
assert_allclose
(
rtol
=
self
.
rtol
,
cache_v_ref
,
cache_v
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
atol
=
self
.
atol
)
np
.
testing
.
assert_allclose
(
cache_v_ref
,
cache_v
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
if
i
==
0
:
if
i
==
0
:
break
break
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out_ref
,
final_out
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
final_out
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
)
class
TestFusedMultiTransformerOpFp16
(
TestFusedMultiTransformerOp
):
class
TestFusedMultiTransformerOpFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
x_type
=
np
.
float16
...
@@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
...
@@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
class
TestFusedMultiTransformerOpCacheKV
(
TestFusedMultiTransformerOp
):
class
TestFusedMultiTransformerOpCacheKV
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
super
().
config
()
self
.
has_cache_kv
=
True
self
.
has_cache_kv
=
True
...
@@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
...
@@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
class
TestFusedMultiTransformerOpCacheKVFp16
(
TestFusedMultiTransformerOp
):
class
TestFusedMultiTransformerOpCacheKVFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
super
().
config
()
self
.
has_cache_kv
=
True
self
.
has_cache_kv
=
True
...
@@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
...
@@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
class
TestFusedMultiTransformerOpGenCacheKV
(
TestFusedMultiTransformerOp
):
class
TestFusedMultiTransformerOpGenCacheKV
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
super
().
config
()
self
.
has_cache_kv
=
True
self
.
has_cache_kv
=
True
...
@@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
...
@@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
class
TestFusedMultiTransformerOpGenCacheKVFp16
(
TestFusedMultiTransformerOp
):
class
TestFusedMultiTransformerOpGenCacheKVFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerOpPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
x_type
=
np
.
float16
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
super
().
config
()
self
.
has_cache_kv
=
True
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录