Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1882c496
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1882c496
编写于
3月 11, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] Support tensor parallel and cache structure for fused attention op. (#40101)
上级
e24ca55e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
608 addition
and
147 deletion
+608
-147
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+30
-9
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+69
-12
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+59
-10
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+12
-7
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py
.../tests/unittests/static_model_parallel_fused_attention.py
+297
-0
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+66
-93
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py
...s/unittests/test_static_model_parallel_fused_attention.py
+45
-0
python/paddle/incubate/nn/functional/fused_transformer.py
python/paddle/incubate/nn/functional/fused_transformer.py
+28
-16
未找到文件。
paddle/fluid/operators/fused/fmha_ref.h
浏览文件 @
1882c496
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace
paddle
{
...
...
@@ -69,20 +70,21 @@ class FMHARef {
~
FMHARef
()
{}
void
ComputeForward
(
const
Tensor
&
qkv_input_tensor
,
const
Tensor
*
cache_kv_tensor
,
const
Tensor
*
src_mask_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
cache_kv_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
src_mask_out_tensor
,
Tensor
*
softmax_out_tensor
,
Tensor
*
dropout_mask_out_tensor
,
Tensor
*
dropout_out_tensor
,
Tensor
*
qktv_out_tensor
,
Tensor
*
fmha_out_tensor
)
{
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0,
1, 3
, 4],
// transpose with perm [2, 0,
3, 1
, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
int
ndims
=
5
;
std
::
vector
<
int
>
perm_1
=
{
2
,
0
,
3
,
1
,
4
};
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
qkv_input_tensor
,
perm_1
,
transpose_2_out_tensor
);
T
*
qkv_data
=
transpose_2_out_tensor
->
data
<
T
>
();
T
*
qk_out_data
=
qk_out_tensor
->
data
<
T
>
();
T
*
qktv_out_data
=
qktv_out_tensor
->
data
<
T
>
();
...
...
@@ -90,11 +92,30 @@ class FMHARef {
T
*
dropout_out_data
=
dropout_out_tensor
->
data
<
T
>
();
T
*
fmha_out_data
=
fmha_out_tensor
->
data
<
T
>
();
int
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
int
k_size
=
q_size
;
auto
out_seq_len
=
seq_len_
;
if
(
cache_kv_tensor
)
{
// kv [2, bs, num_head, seq_len, head_dim]
auto
kv_tensor
=
transpose_2_out_tensor
->
Slice
(
1
,
3
);
phi
::
funcs
::
ConcatFunctor
<
phi
::
GPUContext
,
T
>
concat
;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat
(
dev_ctx_
,
{
*
cache_kv_tensor
,
kv_tensor
},
3
,
cache_kv_out_tensor
);
out_seq_len
=
cache_kv_out_tensor
->
dims
()[
3
];
}
int64_t
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
T
*
q_ptr
=
qkv_data
;
T
*
k_ptr
=
q_ptr
+
q_size
;
T
*
v_ptr
=
k_ptr
+
k_size
;
T
*
k_ptr
=
nullptr
;
T
*
v_ptr
=
nullptr
;
if
(
cache_kv_tensor
)
{
int64_t
k_size
=
cache_kv_out_tensor
->
numel
()
/
2
;
k_ptr
=
cache_kv_out_tensor
->
data
<
T
>
();
v_ptr
=
k_ptr
+
k_size
;
}
else
{
int64_t
k_size
=
q_size
;
k_ptr
=
q_ptr
+
q_size
;
v_ptr
=
k_ptr
+
k_size
;
}
// q*k^t, batched_gemm
CBLAS_TRANSPOSE
transA
=
CblasNoTrans
;
...
...
@@ -102,7 +123,7 @@ class FMHARef {
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
int
gemm_batch_size
=
batch_size_
*
num_head_
;
int
gemm_m
=
seq_len_
;
int
gemm_n
=
seq_len_
;
int
gemm_n
=
out_seq_len
;
int
gemm_k
=
head_dim_
;
T
alpha
=
static_cast
<
T
>
(
1.0
/
sqrt
(
head_dim_
));
T
beta
=
static_cast
<
T
>
(
0.0
);
...
...
@@ -133,7 +154,7 @@ class FMHARef {
transB
=
CblasNoTrans
;
gemm_m
=
seq_len_
;
gemm_n
=
head_dim_
;
gemm_k
=
seq_len_
;
gemm_k
=
out_seq_len
;
alpha
=
static_cast
<
T
>
(
1.0
);
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
...
...
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
1882c496
...
...
@@ -61,6 +61,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKTVOut"
),
"Output"
,
"QKTVOut"
,
"FusedAttentionOp"
);
if
(
ctx
->
HasInput
(
"CacheKV"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"CacheKVOut"
),
"Output"
,
"CacheKVOut"
,
"FusedAttentionOp"
);
}
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SrcMaskOut"
),
"Output"
,
"SrcMaskOut"
,
"FusedAttentionOp"
);
...
...
@@ -105,12 +109,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
if
(
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
)
==
-
1
)
{
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
}
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
...
...
@@ -132,20 +138,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [3, batch_size, num_head, seq_len, head_size]
ctx
->
SetOutputDim
(
"TransposeOut2"
,
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
// [batch, num_head, seq_len, seq_len]
ctx
->
SetOutputDim
(
"QKOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
// cache_seq_len + seq_len if cache else seq_len
auto
out_seq_len
=
x_dim
[
1
];
if
(
ctx
->
HasInput
(
"CacheKV"
))
{
// [2, batch_size, num_head, cache_seq_len, head_size]
auto
c_dim
=
ctx
->
GetInputDim
(
"CacheKV"
);
PADDLE_ENFORCE_EQ
(
c_dim
.
size
(),
5
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The CacheKV must be 5 dims, but got %d"
,
c_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
c_dim
[
0
],
2
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The first dim of CacheKV must be 2, but got %d"
,
c_dim
[
0
]));
// 2
PADDLE_ENFORCE_EQ
(
c_dim
[
1
],
x_dim
[
0
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d"
,
x_dim
[
0
],
c_dim
[
1
]));
// batch_size
PADDLE_ENFORCE_EQ
(
c_dim
[
2
],
y_dim
[
1
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d"
,
y_dim
[
1
],
c_dim
[
2
]));
// num_head
PADDLE_ENFORCE_GE
(
c_dim
[
3
],
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The forth dim of CacheKV must be greater than 0, but got %d"
,
c_dim
[
3
]));
// cache_seq_len
PADDLE_ENFORCE_EQ
(
c_dim
[
4
],
y_dim
[
2
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d"
,
y_dim
[
2
],
c_dim
[
4
]));
// head_size
out_seq_len
+=
c_dim
[
3
];
// [3, batch_size, num_head, cache_seq_len + seq_len, head_size]
ctx
->
SetOutputDim
(
"CacheKVOut"
,
{
c_dim
[
0
],
c_dim
[
1
],
c_dim
[
2
],
out_seq_len
,
c_dim
[
4
]});
}
// [batch, num_head, seq_len, out_seq_len]
ctx
->
SetOutputDim
(
"QKOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
ctx
->
SetOutputDim
(
"SrcMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"SrcMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
}
// the same as QKOut's shape.
ctx
->
SetOutputDim
(
"AttnDropoutOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]
});
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"attn_dropout_is_test"
)
==
false
)
{
ctx
->
SetOutputDim
(
"AttnDropoutMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]
});
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
}
ctx
->
SetOutputDim
(
"SoftmaxOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"SoftmaxOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
// [batch_size, num_heads, seq_len, head_dim]
ctx
->
SetOutputDim
(
"QKTVOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
// [batch_size, seq_len, number of heads*head size]
...
...
@@ -182,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.
AsDispensable
();
AddInput
(
"QKVW"
,
"The qkv weight tensor."
);
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
).
AsDispensable
();
AddInput
(
"CacheKV"
,
"(optional) The cached KV for generation inference."
)
.
AsDispensable
();
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
.
AsDispensable
();
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
);
...
...
@@ -217,6 +269,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"BiasDropoutResidualOut"
,
"Result of residual + dropout(src + bias)."
)
.
AsIntermediate
();
AddOutput
(
"CacheKVOut"
,
"The udpated cache KV."
);
AddOutput
(
"Y"
,
"Result after attention."
);
AddAttr
<
bool
>
(
"pre_layer_norm"
,
...
...
@@ -324,6 +377,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"0.0 and 0.001, But received [%s]."
,
ln_epsilon
));
});
AddAttr
<
int
>
(
"ring_id"
,
"ring id for tensor model parallel. distributed training and inference"
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
Add fused attention op whose logic is as follows:
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
1882c496
...
...
@@ -27,11 +27,39 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
static
void
AllReduce
(
framework
::
Tensor
&
tensor
,
// NOLINT
const
int
ring_id
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
if
(
ring_id
==
-
1
)
return
;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
dtype
=
platform
::
ToNCCLDataType
(
framework
::
TransToProtoVarType
(
tensor
.
dtype
()));
int64_t
numel
=
tensor
.
numel
();
const
void
*
sendbuff
=
tensor
.
data
<
T
>
();
auto
place
=
ctx
.
GetPlace
();
void
*
recvbuff
=
tensor
.
mutable_data
<
T
>
(
place
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
auto
stream
=
ctx
.
stream
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
dtype
,
ncclSum
,
comm
->
comm
(),
stream
));
#else
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."
));
#endif
}
template
<
typename
T
>
class
FusedAttentionOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -56,6 +84,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
auto
*
transpose_out_2
=
ctx
.
Output
<
Tensor
>
(
"TransposeOut2"
);
auto
*
cache_kv
=
ctx
.
Input
<
Tensor
>
(
"CacheKV"
);
auto
*
cache_kv_out
=
ctx
.
Output
<
Tensor
>
(
"CacheKVOut"
);
auto
*
qk_out
=
ctx
.
Output
<
Tensor
>
(
"QKOut"
);
auto
*
qktv_out
=
ctx
.
Output
<
Tensor
>
(
"QKTVOut"
);
auto
*
softmax_out
=
ctx
.
Output
<
Tensor
>
(
"SoftmaxOut"
);
...
...
@@ -86,6 +116,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// final output.
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
...
...
@@ -105,6 +136,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for FMHA.
auto
*
transpose_out_2_data
=
transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
cache_kv_out_data
=
(
cache_kv_out
==
nullptr
)
?
nullptr
:
cache_kv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qk_out_data
=
qk_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qktv_out_data
=
qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
src_mask_out_data
=
...
...
@@ -161,9 +196,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
output_size
=
hidden_size
;
// (transA, transB, compute_bias) = (false, false, false)
// NOTE(Yuang Liu): For general input size == output size, change the
// position won't have effects. For mp, the output size is mp_head * dkey
// which is actually the input size. While the input size is hidden size,
// which is actually the output size. So for out linear, switch the
// input size and output size.
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
false
,
false
,
bsz_seq
,
output_size
,
in
put_size
,
false
);
input_size
,
out
put_size
,
false
);
DropoutParam
dropout_param2
(
ctx
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
...
...
@@ -186,15 +226,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_bias_out
);
}
if
(
qkv_bias
==
nullptr
)
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_out
,
src_mask
,
transpose_out_2
,
qk_out
,
src_mask_out
,
softmax
_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
fmha_ref_compute
.
ComputeForward
(
*
qkv_out
,
cache_kv
,
src_mask
,
transpose_out_2
,
cache_kv_out
,
qk
_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
}
else
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
src_mask
,
transpose_out_2
,
qk_out
,
src_mask_out
,
softmax
_out
,
attn_dropout_mask_out
,
attn_dropout
_out
,
qktv_out
,
fmha_out
);
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
cache_kv
,
src_mask
,
transpose_out_2
,
cache_kv
_out
,
qk_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask
_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
...
...
@@ -202,6 +242,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute
.
ComputeForward
(
out_linear_weight
,
fmha_out
,
nullptr
,
out_linear_out
,
nullptr
);
// tensor model parallel
AllReduce
<
T
>
(
*
out_linear_out
,
ring_id
,
ctx
.
cuda_device_context
());
if
(
pre_layer_norm
)
{
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper
.
ResidualDropoutBias
(
...
...
@@ -244,6 +287,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// get inputs.
auto
*
d_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
...
...
@@ -399,9 +443,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
transA
=
false
;
transB
=
false
;
bool
compute_bias
=
false
;
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
output_size
,
in
put_size
,
compute_bias
);
input_size
,
out
put_size
,
compute_bias
);
DropoutParam
dropout_param2
(
ctx
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
...
...
@@ -475,6 +520,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_out
,
d_ln_out
,
d_qkv_weight
,
d_qkv_bias
);
}
// tensor model parallel
AllReduce
<
T
>
(
*
d_ln_out
,
ring_id
,
ctx
.
cuda_device_context
());
layer_norm_compute
.
ComputeBackward
(
x_data
,
d_ln_out_data
,
ln_scale_data
,
ln_mean_data
,
ln_var_data
,
d_x_data
,
d_ln_scale_data
,
d_ln_bias_data
);
...
...
@@ -486,6 +533,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_out
,
d_x
,
d_qkv_weight
,
d_qkv_bias
);
}
// tensor model parallel
AllReduce
<
T
>
(
*
d_x
,
ring_id
,
ctx
.
cuda_device_context
());
}
// gradient accumulation
std
::
vector
<
const
Tensor
*>
ins
;
...
...
paddle/fluid/pybind/op_function_generator.h
浏览文件 @
1882c496
...
...
@@ -30,8 +30,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"bincount"
,
{
"X"
,
"Weights"
}},
{
"fused_attention"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"
SrcMask"
,
"OutLinearW
"
,
"OutLinearBias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"
CacheKV"
,
"SrcMask
"
,
"OutLinear
W"
,
"OutLinear
Bias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
...
...
@@ -104,11 +104,16 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{
"batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
{
"fused_attention"
,
{
"LnMean"
,
"LnVariance"
,
"LnOut"
,
"QKVOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"QKOut"
,
"QKTVOut"
,
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"Y"
}},
{
"fused_attention"
,
{
"LnMean"
,
"LnVariance"
,
"LnOut"
,
"QKVOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"QKOut"
,
"QKTVOut"
,
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"CacheKVOut"
,
"Y"
}},
{
"sync_batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
1882c496
...
...
@@ -24,6 +24,7 @@ list(APPEND DIST_TEST_OPS test_pipeline)
list
(
APPEND DIST_TEST_OPS test_ir_pass_pipeline
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height
)
...
...
@@ -1155,6 +1156,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties
(
test_ir_pass_pipeline PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_static_model_parallel PROPERTIES TIMEOUT 240
)
set_tests_properties
(
test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_split_embedding
test_collective_split_embedding_none_divisible
test_collective_split_row_linear
...
...
python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py
0 → 100644
浏览文件 @
1882c496
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
import
paddle.distributed.fleet
as
fleet
import
paddle.incubate.nn.functional
as
incubate_f
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_dtype
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid
import
core
from
paddle.nn.initializer
import
Constant
paddle
.
enable_static
()
def
_set_var_distributed
(
var
):
if
var
is
None
:
return
var
.
is_distributed
=
True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block
=
paddle
.
static
.
default_startup_program
().
current_block
()
main_block
=
paddle
.
static
.
default_main_program
().
current_block
()
startup_block
.
_find_var_recursive
(
var
.
name
).
is_distributed
=
True
main_block
.
_find_var_recursive
(
var
.
name
).
is_distributed
=
True
class
ParallelFusedMultiHeadAttention
(
Layer
):
def
__init__
(
self
,
embed_dim
,
num_heads
,
dropout_rate
=
0.5
,
attn_dropout_rate
=
0.5
,
kdim
=
None
,
vdim
=
None
,
normalize_before
=
False
,
need_weights
=
False
,
qkv_weight_attr
=
None
,
qkv_bias_attr
=
None
,
linear_weight_attr
=
None
,
linear_bias_attr
=
None
,
pre_ln_scale_attr
=
None
,
pre_ln_bias_attr
=
None
,
ln_scale_attr
=
None
,
ln_bias_attr
=
None
,
epsilon
=
1e-5
,
nranks
=
1
,
ring_id
=-
1
,
name
=
None
):
super
(
ParallelFusedMultiHeadAttention
,
self
).
__init__
()
assert
embed_dim
>
0
,
(
"Expected embed_dim to be greater than 0, "
"but recieved {}"
.
format
(
embed_dim
))
assert
num_heads
>
0
,
(
"Expected nhead to be greater than 0, "
"but recieved {}"
.
format
(
num_heads
))
self
.
normalize_before
=
normalize_before
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_epsilon
=
epsilon
self
.
_ring_id
=
ring_id
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
head_dim
=
embed_dim
//
num_heads
self
.
kdim
=
kdim
self
.
vdim
=
vdim
self
.
need_weights
=
need_weights
assert
self
.
head_dim
*
num_heads
==
embed_dim
,
"embed_dim must be divisible by num_heads"
assert
need_weights
==
False
,
"Only support need_weight is False now."
# tensor model parallel
assert
num_heads
%
nranks
==
0
num_heads
=
num_heads
//
nranks
self
.
qkv_weight
=
self
.
create_parameter
(
shape
=
[
3
,
num_heads
,
self
.
head_dim
,
embed_dim
],
attr
=
qkv_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
qkv_bias
=
self
.
create_parameter
(
shape
=
[
3
,
num_heads
,
self
.
head_dim
],
attr
=
qkv_bias_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
self
.
linear_weight
=
self
.
create_parameter
(
shape
=
[
num_heads
*
self
.
head_dim
,
embed_dim
],
attr
=
linear_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
linear_bias
=
self
.
create_parameter
(
shape
=
[
embed_dim
],
attr
=
linear_bias_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
# tensor model parallel
if
nranks
>
1
:
assert
ring_id
!=
-
1
# column parallel
_set_var_distributed
(
self
.
qkv_weight
)
_set_var_distributed
(
self
.
qkv_bias
)
# row parallel
_set_var_distributed
(
self
.
linear_weight
)
if
normalize_before
:
self
.
pre_ln_scale
=
self
.
create_parameter
(
attr
=
pre_ln_scale_attr
,
shape
=
[
embed_dim
],
default_initializer
=
Constant
(
value
=
1.0
))
self
.
pre_ln_bias
=
self
.
create_parameter
(
attr
=
pre_ln_bias_attr
,
shape
=
[
embed_dim
],
is_bias
=
True
)
self
.
ln_scale
=
None
self
.
ln_bias
=
None
else
:
self
.
pre_ln_scale
=
None
self
.
pre_ln_bias
=
None
self
.
ln_scale
=
self
.
create_parameter
(
attr
=
ln_scale_attr
,
shape
=
[
embed_dim
],
default_initializer
=
Constant
(
value
=
1.0
))
self
.
ln_bias
=
self
.
create_parameter
(
attr
=
ln_bias_attr
,
shape
=
[
embed_dim
],
is_bias
=
True
)
self
.
dropout_rate
=
dropout_rate
self
.
attn_dropout_rate
=
attn_dropout_rate
self
.
name
=
name
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
attn_mask
=
None
,
cache
=
None
):
out
=
incubate_f
.
fused_multi_head_attention
(
x
=
query
,
qkv_weight
=
self
.
qkv_weight
,
linear_weight
=
self
.
linear_weight
,
pre_layer_norm
=
self
.
normalize_before
,
pre_ln_scale
=
self
.
pre_ln_scale
,
pre_ln_bias
=
self
.
pre_ln_bias
,
ln_scale
=
self
.
ln_scale
,
ln_bias
=
self
.
ln_bias
,
pre_ln_epsilon
=
self
.
_epsilon
,
qkv_bias
=
self
.
qkv_bias
,
linear_bias
=
self
.
linear_bias
,
attn_mask
=
attn_mask
,
dropout_rate
=
self
.
dropout_rate
,
attn_dropout_rate
=
self
.
attn_dropout_rate
,
ln_epsilon
=
self
.
_epsilon
,
training
=
self
.
training
,
ring_id
=
self
.
_ring_id
,
name
=
self
.
name
)
return
out
def
get_param_attr
(
weight
,
bias
):
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
weight
))
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
bias
))
return
weight_attr
,
bias_attr
DTYPE
=
"float32"
MODEL_PARALLEL_SIZE
=
2
n_head
=
2
*
MODEL_PARALLEL_SIZE
d_key
=
4
hidden
=
n_head
*
d_key
def
create_model
(
data
,
rank
):
np
.
random
.
seed
(
2021
)
pre_ln_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
pre_ln_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
qkv_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
n_head
,
d_key
,
hidden
)).
astype
(
DTYPE
)
qkv_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
n_head
,
d_key
)).
astype
(
DTYPE
)
linear_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
n_head
*
d_key
,
hidden
)).
astype
(
DTYPE
)
linear_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
data
.
stop_gradient
=
False
if
rank
is
not
None
:
start
=
0
if
rank
==
0
else
n_head
//
MODEL_PARALLEL_SIZE
end
=
start
+
n_head
//
MODEL_PARALLEL_SIZE
col_qkv_w
=
qkv_w
[:,
start
:
end
,
:,
:]
col_qkv_b
=
qkv_b
[:,
start
:
end
,
:]
row_linear_w
=
linear_w
[(
start
*
d_key
):(
end
*
d_key
),
:]
pre_ln_w_attr
,
pre_ln_b_attr
=
get_param_attr
(
pre_ln_w
,
pre_ln_b
)
qkv_w_attr
,
qkv_b_attr
=
get_param_attr
(
col_qkv_w
,
col_qkv_b
)
linear_w_attr
,
linear_b_attr
=
get_param_attr
(
row_linear_w
,
linear_b
)
attn
=
ParallelFusedMultiHeadAttention
(
hidden
,
n_head
,
dropout_rate
=
0.0
,
attn_dropout_rate
=
0.0
,
normalize_before
=
False
,
qkv_weight_attr
=
qkv_w_attr
,
qkv_bias_attr
=
qkv_b_attr
,
linear_weight_attr
=
linear_w_attr
,
linear_bias_attr
=
linear_b_attr
,
pre_ln_scale_attr
=
pre_ln_w_attr
,
pre_ln_bias_attr
=
pre_ln_b_attr
,
ln_scale_attr
=
pre_ln_w_attr
,
ln_bias_attr
=
pre_ln_b_attr
,
nranks
=
MODEL_PARALLEL_SIZE
,
ring_id
=
0
)
result
=
attn
(
data
)
else
:
pre_ln_w_attr
,
pre_ln_b_attr
=
get_param_attr
(
pre_ln_w
,
pre_ln_b
)
qkv_w_attr
,
qkv_b_attr
=
get_param_attr
(
qkv_w
,
qkv_b
)
linear_w_attr
,
linear_b_attr
=
get_param_attr
(
linear_w
,
linear_b
)
attn
=
ParallelFusedMultiHeadAttention
(
hidden
,
n_head
,
dropout_rate
=
0.0
,
attn_dropout_rate
=
0.0
,
normalize_before
=
False
,
qkv_weight_attr
=
qkv_w_attr
,
qkv_bias_attr
=
qkv_b_attr
,
linear_weight_attr
=
linear_w_attr
,
linear_bias_attr
=
linear_b_attr
,
pre_ln_scale_attr
=
pre_ln_w_attr
,
pre_ln_bias_attr
=
pre_ln_b_attr
,
ln_scale_attr
=
pre_ln_w_attr
,
ln_bias_attr
=
pre_ln_b_attr
)
result
=
attn
(
data
)
predict
=
paddle
.
sum
(
result
)
return
predict
class
TestModelParallel
(
TestDistRunnerBase
):
def
get_model
(
self
,
batch_size
=
2
,
use_dgc
=
False
,
dist_strategy
=
None
):
# Input data
seq_len
=
2
data_in
=
fluid
.
data
(
name
=
'data_in'
,
shape
=
[
batch_size
,
seq_len
,
hidden
],
dtype
=
DTYPE
)
if
dist_strategy
:
data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
data_in
],
capacity
=
64
,
use_double_buffer
=
False
,
iterable
=
False
)
if
dist_strategy
:
fleet
.
init
(
is_collective
=
True
)
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
tensor_parallel
=
True
strategy
.
tensor_parallel_configs
=
{
'tensor_parallel_degree'
:
2
}
rank
=
fleet
.
worker_index
()
if
dist_strategy
else
None
avg_cost
=
create_model
(
data_in
,
rank
)
opt
=
fluid
.
optimizer
.
SGD
(
0.1
)
if
dist_strategy
:
dist_opt
=
fleet
.
distributed_optimizer
(
optimizer
=
opt
,
strategy
=
strategy
)
dist_opt
.
minimize
(
avg_cost
)
else
:
opt
.
minimize
(
avg_cost
)
def
gen_data
():
np
.
random
.
seed
(
2021
)
while
True
:
data
=
[
np
.
random
.
random
([
seq_len
,
hidden
]).
astype
(
DTYPE
)]
yield
data
train_reader
=
paddle
.
batch
(
gen_data
,
batch_size
=
batch_size
)
if
dist_strategy
:
return
None
,
avg_cost
,
train_reader
,
None
,
None
,
None
,
data_loader
else
:
return
None
,
avg_cost
,
train_reader
,
None
,
None
,
None
if
__name__
==
"__main__"
:
runtime_main
(
TestModelParallel
)
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
1882c496
...
...
@@ -70,10 +70,12 @@ class TestFusedAttentionOp(OpTest):
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
has_cache_kv
=
False
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
cache_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
...
...
@@ -88,10 +90,22 @@ class TestFusedAttentionOp(OpTest):
def
generate_input_data
(
self
):
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
).
astype
(
self
.
x_type
)
out_seq_len
=
self
.
key_length
if
self
.
has_cache_kv
:
assert
self
.
training
is
False
,
ValueError
(
'cache_kv can only used in inference'
)
self
.
cache_kv
=
np
.
random
.
rand
(
2
,
self
.
batch_size
,
self
.
num_heads
,
self
.
cache_length
,
self
.
head_dim
).
astype
(
self
.
x_type
)
out_seq_len
+=
self
.
cache_length
else
:
self
.
cache_kv
=
None
if
self
.
has_attn_mask
:
# [B, n_head, seq_len, out_seq_len]
self
.
attn_mask
=
np
.
ones
(
(
self
.
batch_size
,
self
.
num_heads
,
self
.
query_length
,
self
.
key_length
),
out_seq_len
),
dtype
=
self
.
attn_mask_type
)
if
self
.
attn_mask_type
==
np
.
int64
:
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
...
...
@@ -110,6 +124,11 @@ class TestFusedAttentionOp(OpTest):
def
GetBaselineOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
tensor_query
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
cache_kv
=
None
if
self
.
has_cache_kv
:
cache_kv
=
paddle
.
to_tensor
(
self
.
cache_kv
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
...
...
@@ -130,6 +149,18 @@ class TestFusedAttentionOp(OpTest):
v
=
tensor
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
v_out
=
tensor
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
if
self
.
has_cache_kv
:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k
,
cache_v
=
paddle
.
split
(
cache_kv
,
2
)
cache_k
=
paddle
.
squeeze
(
cache_k
,
axis
=
0
)
cache_v
=
paddle
.
squeeze
(
cache_v
,
axis
=
0
)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
k_out
=
paddle
.
concat
([
cache_k
,
k_out
],
axis
=-
2
)
v_out
=
paddle
.
concat
([
cache_v
,
v_out
],
axis
=-
2
)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out
=
layers
.
matmul
(
x
=
q_out
,
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
...
...
@@ -146,6 +177,8 @@ class TestFusedAttentionOp(OpTest):
self
.
dropout_prob
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out
=
tensor
.
matmul
(
dropout_out
,
v_out
)
else
:
qktv_out
=
tensor
.
matmul
(
softmax_out
,
v_out
)
...
...
@@ -160,6 +193,10 @@ class TestFusedAttentionOp(OpTest):
final_out
=
self
.
norm1
(
residual_out
)
else
:
final_out
=
residual_out
if
self
.
has_cache_kv
:
return
final_out
paddle
.
autograd
.
backward
(
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
return
final_out
,
tensor_query
.
grad
...
...
@@ -206,6 +243,9 @@ class TestFusedAttentionOp(OpTest):
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
cache_kv
=
None
if
self
.
has_cache_kv
:
cache_kv
=
paddle
.
to_tensor
(
self
.
cache_kv
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
...
...
@@ -219,8 +259,12 @@ class TestFusedAttentionOp(OpTest):
final_out
=
incubate_f
.
fused_multi_head_attention
(
x
,
qkv_weight_tensor
,
out_linear_weight
,
self
.
pre_layer_norm
,
ln1_scale
,
ln1_bias
,
ln2_scale
,
ln2_bias
,
epsilon
,
qkv_bias_tensor
,
out_linear_bias
,
attn_mask
,
self
.
dropout_prob
,
out_linear_bias
,
cache_kv
,
attn_mask
,
self
.
dropout_prob
,
self
.
attn_dropout_prob
,
ln2_epsilon
)
if
self
.
has_cache_kv
:
return
final_out
[
0
],
final_out
[
1
]
paddle
.
autograd
.
backward
(
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
return
final_out
,
x
.
grad
...
...
@@ -236,114 +280,27 @@ class TestFusedAttentionOp(OpTest):
class
TestFusedAttentionOpBiasIsNone
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
super
().
config
()
self
.
bias_attr
=
False
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
super
().
config
()
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
super
().
config
()
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
False
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
...
...
@@ -354,5 +311,21 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
class
TestFusedAttentionOpCacheKV
(
TestFusedAttentionOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
training
=
False
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
def
test_fused_attention_op
(
self
):
with
paddle
.
no_grad
():
final_out_ref
=
self
.
GetBaselineOut
()
final_out
,
cache_kv_out
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py
0 → 100644
浏览文件 @
1882c496
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
test_dist_base
import
TestDistBase
import
os
import
paddle
paddle
.
enable_static
()
flag_name
=
os
.
path
.
splitext
(
__file__
)[
0
]
class
TestStaticModelParallel
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_use_reduce
=
False
self
.
_use_reader_alloc
=
False
self
.
_nccl_comm_num
=
1
self
.
_pipeline_mode
=
True
def
test_dist_static_model_parallel_fused_feedforward
(
self
):
import
paddle.fluid
as
fluid
if
fluid
.
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
"static_model_parallel_fused_attention.py"
,
delta
=
1e-5
,
check_error_log
=
True
,
log_name
=
flag_name
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/incubate/nn/functional/fused_transformer.py
浏览文件 @
1882c496
...
...
@@ -223,12 +223,14 @@ def fused_multi_head_attention(x,
pre_ln_epsilon
=
1e-05
,
qkv_bias
=
None
,
linear_bias
=
None
,
cache_kv
=
None
,
attn_mask
=
None
,
dropout_rate
=
0.5
,
attn_dropout_rate
=
0.5
,
ln_epsilon
=
1e-05
,
training
=
True
,
mode
=
'upscale_in_train'
,
ring_id
=-
1
,
name
=
None
):
r
"""
Attention mapps queries and a set of key-value pairs to outputs, and
...
...
@@ -242,8 +244,8 @@ def fused_multi_head_attention(x,
out = layer_norm(x)
out = linear(out) + qkv) + bias
else:
out = linear(x) + bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
out = linear(x) + bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1,::]
k = out[1:2,::]
...
...
@@ -257,8 +259,8 @@ def fused_multi_head_attention(x,
out = out_linear(out)
if pre_layer_norm:
out = x + dropout(linear_bias + out)
else:
out = layer_norm(x + dropout(linear_bias + out))
else:
out = layer_norm(x + dropout(linear_bias + out))
Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
...
...
@@ -276,6 +278,7 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
...
...
@@ -303,6 +306,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
...
...
@@ -333,7 +337,7 @@ def fused_multi_head_attention(x,
output = F.fused_multi_head_attention(
x, qkv_weight, linear_weight, False,
None, None, None, None, 1e-5, qkv_bias,
linear_bias, attn_mask)
linear_bias,
None,
attn_mask)
# [2, 4, 128]
print(output.shape)
"""
...
...
@@ -359,17 +363,20 @@ def fused_multi_head_attention(x,
assert
qkv_weight
.
shape
[
1
]
*
qkv_weight
.
shape
[
2
]
==
qkv_weight
.
shape
[
3
],
"embed_dim must be divisible by num_heads."
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
final_out
=
_C_ops
.
fused_attention
(
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
attn_mask
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
'pre_layer_norm'
,
pre_layer_norm
,
'epsilon'
,
pre_ln_epsilon
,
'dropout_rate'
,
dropout_rate
,
'attn_dropout_rate'
,
attn_dropout_rate
,
'ln_epsilon'
,
ln_epsilon
,
'attn_dropout_is_test'
,
not
training
,
'dropout_is_test'
,
not
training
,
'attn_dropout_fix_seed'
,
seed
is
not
None
,
'dropout_fix_seed'
,
seed
is
not
None
,
'attn_dropout_seed'
,
seed
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
cache_kv_out
,
final_out
=
_C_ops
.
fused_attention
(
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
cache_kv
,
attn_mask
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
'pre_layer_norm'
,
pre_layer_norm
,
'epsilon'
,
pre_ln_epsilon
,
'dropout_rate'
,
dropout_rate
,
'attn_dropout_rate'
,
attn_dropout_rate
,
'ln_epsilon'
,
ln_epsilon
,
'attn_dropout_is_test'
,
not
training
,
'dropout_is_test'
,
not
training
,
'attn_dropout_fix_seed'
,
seed
is
not
None
,
'dropout_fix_seed'
,
seed
is
not
None
,
'attn_dropout_seed'
,
seed
if
seed
is
not
None
else
0
,
'dropout_seed'
,
seed
if
seed
is
not
None
else
0
,
'attn_dropout_implementation'
,
mode
,
'dropout_implementation'
,
mode
)
'dropout_implementation'
,
mode
,
'ring_id'
,
ring_id
)
if
cache_kv
is
not
None
:
return
final_out
,
cache_kv_out
return
final_out
else
:
helper
=
LayerHelper
(
'fused_multi_head_attention'
,
**
locals
())
...
...
@@ -398,6 +405,7 @@ def fused_multi_head_attention(x,
inputs
[
'Ln2Scale'
]
=
[
ln_scale
]
if
ln_bias
:
inputs
[
'Ln2Bias'
]
=
[
ln_bias
]
if
cache_kv
:
inputs
[
'CacheKV'
]
=
[
cache_kv
]
if
(
seed
is
None
or
seed
==
0
)
and
helper
.
main_program
.
random_seed
!=
0
:
seed
=
helper
.
main_program
.
random_seed
...
...
@@ -417,6 +425,7 @@ def fused_multi_head_attention(x,
'dropout_seed'
:
seed
if
seed
is
not
None
else
0
,
'attn_dropout_implementation'
:
mode
,
'dropout_implementation'
:
mode
,
'ring_id'
:
ring_id
}
# set outputs
...
...
@@ -449,6 +458,7 @@ def fused_multi_head_attention(x,
bias_dropout_residual_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
final_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
cache_kv_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'fused_attention'
,
...
...
@@ -472,7 +482,9 @@ def fused_multi_head_attention(x,
"Ln2Mean"
:
ln_mean_out
,
"Ln2Variance"
:
ln_variance_out
,
"BiasDropoutResidualOut"
:
bias_dropout_residual_out
,
'Y'
:
final_out
'Y'
:
final_out
,
'CacheKVOut'
:
cache_kv_out
},
attrs
=
attrs
)
return
final_out
return
(
final_out
,
cache_kv_out
)
if
cache_kv
else
final_out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录