Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1882c496
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1882c496
编写于
3月 11, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] Support tensor parallel and cache structure for fused attention op. (#40101)
上级
e24ca55e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
608 addition
and
147 deletion
+608
-147
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+30
-9
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+69
-12
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+59
-10
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+12
-7
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py
.../tests/unittests/static_model_parallel_fused_attention.py
+297
-0
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+66
-93
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py
...s/unittests/test_static_model_parallel_fused_attention.py
+45
-0
python/paddle/incubate/nn/functional/fused_transformer.py
python/paddle/incubate/nn/functional/fused_transformer.py
+28
-16
未找到文件。
paddle/fluid/operators/fused/fmha_ref.h
浏览文件 @
1882c496
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -69,20 +70,21 @@ class FMHARef {
...
@@ -69,20 +70,21 @@ class FMHARef {
~
FMHARef
()
{}
~
FMHARef
()
{}
void
ComputeForward
(
const
Tensor
&
qkv_input_tensor
,
void
ComputeForward
(
const
Tensor
&
qkv_input_tensor
,
const
Tensor
*
cache_kv_tensor
,
const
Tensor
*
src_mask_tensor
,
const
Tensor
*
src_mask_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
cache_kv_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
src_mask_out_tensor
,
Tensor
*
softmax_out_tensor
,
Tensor
*
src_mask_out_tensor
,
Tensor
*
softmax_out_tensor
,
Tensor
*
dropout_mask_out_tensor
,
Tensor
*
dropout_mask_out_tensor
,
Tensor
*
dropout_out_tensor
,
Tensor
*
qktv_out_tensor
,
Tensor
*
dropout_out_tensor
,
Tensor
*
qktv_out_tensor
,
Tensor
*
fmha_out_tensor
)
{
Tensor
*
fmha_out_tensor
)
{
// input shape: [bs, seq_len, 3, num_head, head_dim]
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0,
1, 3
, 4],
// transpose with perm [2, 0,
3, 1
, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
// output_shape: [3, bs, num_head, seq_len, head_dim]
int
ndims
=
5
;
int
ndims
=
5
;
std
::
vector
<
int
>
perm_1
=
{
2
,
0
,
3
,
1
,
4
};
std
::
vector
<
int
>
perm_1
=
{
2
,
0
,
3
,
1
,
4
};
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
qkv_input_tensor
,
perm_1
,
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
qkv_input_tensor
,
perm_1
,
transpose_2_out_tensor
);
transpose_2_out_tensor
);
T
*
qkv_data
=
transpose_2_out_tensor
->
data
<
T
>
();
T
*
qkv_data
=
transpose_2_out_tensor
->
data
<
T
>
();
T
*
qk_out_data
=
qk_out_tensor
->
data
<
T
>
();
T
*
qk_out_data
=
qk_out_tensor
->
data
<
T
>
();
T
*
qktv_out_data
=
qktv_out_tensor
->
data
<
T
>
();
T
*
qktv_out_data
=
qktv_out_tensor
->
data
<
T
>
();
...
@@ -90,11 +92,30 @@ class FMHARef {
...
@@ -90,11 +92,30 @@ class FMHARef {
T
*
dropout_out_data
=
dropout_out_tensor
->
data
<
T
>
();
T
*
dropout_out_data
=
dropout_out_tensor
->
data
<
T
>
();
T
*
fmha_out_data
=
fmha_out_tensor
->
data
<
T
>
();
T
*
fmha_out_data
=
fmha_out_tensor
->
data
<
T
>
();
int
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
auto
out_seq_len
=
seq_len_
;
int
k_size
=
q_size
;
if
(
cache_kv_tensor
)
{
// kv [2, bs, num_head, seq_len, head_dim]
auto
kv_tensor
=
transpose_2_out_tensor
->
Slice
(
1
,
3
);
phi
::
funcs
::
ConcatFunctor
<
phi
::
GPUContext
,
T
>
concat
;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat
(
dev_ctx_
,
{
*
cache_kv_tensor
,
kv_tensor
},
3
,
cache_kv_out_tensor
);
out_seq_len
=
cache_kv_out_tensor
->
dims
()[
3
];
}
int64_t
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
T
*
q_ptr
=
qkv_data
;
T
*
q_ptr
=
qkv_data
;
T
*
k_ptr
=
q_ptr
+
q_size
;
T
*
k_ptr
=
nullptr
;
T
*
v_ptr
=
k_ptr
+
k_size
;
T
*
v_ptr
=
nullptr
;
if
(
cache_kv_tensor
)
{
int64_t
k_size
=
cache_kv_out_tensor
->
numel
()
/
2
;
k_ptr
=
cache_kv_out_tensor
->
data
<
T
>
();
v_ptr
=
k_ptr
+
k_size
;
}
else
{
int64_t
k_size
=
q_size
;
k_ptr
=
q_ptr
+
q_size
;
v_ptr
=
k_ptr
+
k_size
;
}
// q*k^t, batched_gemm
// q*k^t, batched_gemm
CBLAS_TRANSPOSE
transA
=
CblasNoTrans
;
CBLAS_TRANSPOSE
transA
=
CblasNoTrans
;
...
@@ -102,7 +123,7 @@ class FMHARef {
...
@@ -102,7 +123,7 @@ class FMHARef {
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
int
gemm_batch_size
=
batch_size_
*
num_head_
;
int
gemm_batch_size
=
batch_size_
*
num_head_
;
int
gemm_m
=
seq_len_
;
int
gemm_m
=
seq_len_
;
int
gemm_n
=
seq_len_
;
int
gemm_n
=
out_seq_len
;
int
gemm_k
=
head_dim_
;
int
gemm_k
=
head_dim_
;
T
alpha
=
static_cast
<
T
>
(
1.0
/
sqrt
(
head_dim_
));
T
alpha
=
static_cast
<
T
>
(
1.0
/
sqrt
(
head_dim_
));
T
beta
=
static_cast
<
T
>
(
0.0
);
T
beta
=
static_cast
<
T
>
(
0.0
);
...
@@ -133,7 +154,7 @@ class FMHARef {
...
@@ -133,7 +154,7 @@ class FMHARef {
transB
=
CblasNoTrans
;
transB
=
CblasNoTrans
;
gemm_m
=
seq_len_
;
gemm_m
=
seq_len_
;
gemm_n
=
head_dim_
;
gemm_n
=
head_dim_
;
gemm_k
=
seq_len_
;
gemm_k
=
out_seq_len
;
alpha
=
static_cast
<
T
>
(
1.0
);
alpha
=
static_cast
<
T
>
(
1.0
);
stride_a
=
gemm_m
*
gemm_k
;
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
stride_b
=
gemm_k
*
gemm_n
;
...
...
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
1882c496
...
@@ -61,6 +61,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -61,6 +61,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKTVOut"
),
"Output"
,
"QKTVOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKTVOut"
),
"Output"
,
"QKTVOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
if
(
ctx
->
HasInput
(
"CacheKV"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"CacheKVOut"
),
"Output"
,
"CacheKVOut"
,
"FusedAttentionOp"
);
}
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SrcMaskOut"
),
"Output"
,
"SrcMaskOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SrcMaskOut"
),
"Output"
,
"SrcMaskOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
...
@@ -105,12 +109,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -105,12 +109,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]"
,
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
x_dim
,
y_dim
));
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
if
(
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
)
==
-
1
)
{
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
"The dimensions of qkv_weight must be 4"
platform
::
errors
::
InvalidArgument
(
"(3, num_head, dim_head, dim_embed),"
"The dimensions of qkv_weight must be 4"
"and must satisfy the limitations: "
"(3, num_head, dim_head, dim_embed),"
"(num_head * dim_head == dim_embed)"
));
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
}
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
...
@@ -132,20 +138,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -132,20 +138,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [3, batch_size, num_head, seq_len, head_size]
// [3, batch_size, num_head, seq_len, head_size]
ctx
->
SetOutputDim
(
"TransposeOut2"
,
ctx
->
SetOutputDim
(
"TransposeOut2"
,
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
// [batch, num_head, seq_len, seq_len]
ctx
->
SetOutputDim
(
"QKOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
// cache_seq_len + seq_len if cache else seq_len
auto
out_seq_len
=
x_dim
[
1
];
if
(
ctx
->
HasInput
(
"CacheKV"
))
{
// [2, batch_size, num_head, cache_seq_len, head_size]
auto
c_dim
=
ctx
->
GetInputDim
(
"CacheKV"
);
PADDLE_ENFORCE_EQ
(
c_dim
.
size
(),
5
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The CacheKV must be 5 dims, but got %d"
,
c_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
c_dim
[
0
],
2
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The first dim of CacheKV must be 2, but got %d"
,
c_dim
[
0
]));
// 2
PADDLE_ENFORCE_EQ
(
c_dim
[
1
],
x_dim
[
0
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d"
,
x_dim
[
0
],
c_dim
[
1
]));
// batch_size
PADDLE_ENFORCE_EQ
(
c_dim
[
2
],
y_dim
[
1
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d"
,
y_dim
[
1
],
c_dim
[
2
]));
// num_head
PADDLE_ENFORCE_GE
(
c_dim
[
3
],
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The forth dim of CacheKV must be greater than 0, but got %d"
,
c_dim
[
3
]));
// cache_seq_len
PADDLE_ENFORCE_EQ
(
c_dim
[
4
],
y_dim
[
2
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d"
,
y_dim
[
2
],
c_dim
[
4
]));
// head_size
out_seq_len
+=
c_dim
[
3
];
// [3, batch_size, num_head, cache_seq_len + seq_len, head_size]
ctx
->
SetOutputDim
(
"CacheKVOut"
,
{
c_dim
[
0
],
c_dim
[
1
],
c_dim
[
2
],
out_seq_len
,
c_dim
[
4
]});
}
// [batch, num_head, seq_len, out_seq_len]
ctx
->
SetOutputDim
(
"QKOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
ctx
->
SetOutputDim
(
"SrcMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"SrcMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
}
}
// the same as QKOut's shape.
// the same as QKOut's shape.
ctx
->
SetOutputDim
(
"AttnDropoutOut"
,
ctx
->
SetOutputDim
(
"AttnDropoutOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]
});
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"attn_dropout_is_test"
)
==
false
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"attn_dropout_is_test"
)
==
false
)
{
ctx
->
SetOutputDim
(
"AttnDropoutMaskOut"
,
ctx
->
SetOutputDim
(
"AttnDropoutMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]
});
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
}
}
ctx
->
SetOutputDim
(
"SoftmaxOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"SoftmaxOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
out_seq_len
});
// [batch_size, num_heads, seq_len, head_dim]
// [batch_size, num_heads, seq_len, head_dim]
ctx
->
SetOutputDim
(
"QKTVOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
ctx
->
SetOutputDim
(
"QKTVOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
// [batch_size, seq_len, number of heads*head size]
// [batch_size, seq_len, number of heads*head size]
...
@@ -182,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -182,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"QKVW"
,
"The qkv weight tensor."
);
AddInput
(
"QKVW"
,
"The qkv weight tensor."
);
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
).
AsDispensable
();
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
).
AsDispensable
();
AddInput
(
"CacheKV"
,
"(optional) The cached KV for generation inference."
)
.
AsDispensable
();
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
);
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
);
...
@@ -217,6 +269,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -217,6 +269,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"BiasDropoutResidualOut"
,
AddOutput
(
"BiasDropoutResidualOut"
,
"Result of residual + dropout(src + bias)."
)
"Result of residual + dropout(src + bias)."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"CacheKVOut"
,
"The udpated cache KV."
);
AddOutput
(
"Y"
,
"Result after attention."
);
AddOutput
(
"Y"
,
"Result after attention."
);
AddAttr
<
bool
>
(
"pre_layer_norm"
,
AddAttr
<
bool
>
(
"pre_layer_norm"
,
...
@@ -324,6 +377,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -324,6 +377,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"0.0 and 0.001, But received [%s]."
,
"0.0 and 0.001, But received [%s]."
,
ln_epsilon
));
ln_epsilon
));
});
});
AddAttr
<
int
>
(
"ring_id"
,
"ring id for tensor model parallel. distributed training and inference"
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Add fused attention op whose logic is as follows:
Add fused attention op whose logic is as follows:
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
1882c496
...
@@ -27,11 +27,39 @@ limitations under the License. */
...
@@ -27,11 +27,39 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
static
void
AllReduce
(
framework
::
Tensor
&
tensor
,
// NOLINT
const
int
ring_id
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
if
(
ring_id
==
-
1
)
return
;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
dtype
=
platform
::
ToNCCLDataType
(
framework
::
TransToProtoVarType
(
tensor
.
dtype
()));
int64_t
numel
=
tensor
.
numel
();
const
void
*
sendbuff
=
tensor
.
data
<
T
>
();
auto
place
=
ctx
.
GetPlace
();
void
*
recvbuff
=
tensor
.
mutable_data
<
T
>
(
place
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
auto
stream
=
ctx
.
stream
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
dtype
,
ncclSum
,
comm
->
comm
(),
stream
));
#else
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."
));
#endif
}
template
<
typename
T
>
template
<
typename
T
>
class
FusedAttentionOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FusedAttentionOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -56,6 +84,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -56,6 +84,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
auto
*
src_mask
=
ctx
.
Input
<
Tensor
>
(
"SrcMask"
);
auto
*
transpose_out_2
=
ctx
.
Output
<
Tensor
>
(
"TransposeOut2"
);
auto
*
transpose_out_2
=
ctx
.
Output
<
Tensor
>
(
"TransposeOut2"
);
auto
*
cache_kv
=
ctx
.
Input
<
Tensor
>
(
"CacheKV"
);
auto
*
cache_kv_out
=
ctx
.
Output
<
Tensor
>
(
"CacheKVOut"
);
auto
*
qk_out
=
ctx
.
Output
<
Tensor
>
(
"QKOut"
);
auto
*
qk_out
=
ctx
.
Output
<
Tensor
>
(
"QKOut"
);
auto
*
qktv_out
=
ctx
.
Output
<
Tensor
>
(
"QKTVOut"
);
auto
*
qktv_out
=
ctx
.
Output
<
Tensor
>
(
"QKTVOut"
);
auto
*
softmax_out
=
ctx
.
Output
<
Tensor
>
(
"SoftmaxOut"
);
auto
*
softmax_out
=
ctx
.
Output
<
Tensor
>
(
"SoftmaxOut"
);
...
@@ -86,6 +116,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -86,6 +116,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// final output.
// final output.
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
...
@@ -105,6 +136,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -105,6 +136,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for FMHA.
// get data ptr for FMHA.
auto
*
transpose_out_2_data
=
auto
*
transpose_out_2_data
=
transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
cache_kv_out_data
=
(
cache_kv_out
==
nullptr
)
?
nullptr
:
cache_kv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qk_out_data
=
qk_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qk_out_data
=
qk_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qktv_out_data
=
qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qktv_out_data
=
qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
src_mask_out_data
=
auto
*
src_mask_out_data
=
...
@@ -161,9 +196,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -161,9 +196,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
output_size
=
hidden_size
;
output_size
=
hidden_size
;
// (transA, transB, compute_bias) = (false, false, false)
// (transA, transB, compute_bias) = (false, false, false)
// NOTE(Yuang Liu): For general input size == output size, change the
// position won't have effects. For mp, the output size is mp_head * dkey
// which is actually the input size. While the input size is hidden size,
// which is actually the output size. So for out linear, switch the
// input size and output size.
auto
out_linear_compute
=
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
false
,
false
,
bsz_seq
,
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
false
,
false
,
bsz_seq
,
output_size
,
in
put_size
,
false
);
input_size
,
out
put_size
,
false
);
DropoutParam
dropout_param2
(
ctx
,
0
);
DropoutParam
dropout_param2
(
ctx
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
...
@@ -186,15 +226,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -186,15 +226,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_bias_out
);
qkv_bias_out
);
}
}
if
(
qkv_bias
==
nullptr
)
{
if
(
qkv_bias
==
nullptr
)
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_out
,
src_mask
,
transpose_out_2
,
fmha_ref_compute
.
ComputeForward
(
qk_out
,
src_mask_out
,
softmax
_out
,
*
qkv_out
,
cache_kv
,
src_mask
,
transpose_out_2
,
cache_kv_out
,
qk
_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
qktv_out
,
fmha_out
);
}
else
{
}
else
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
src_mask
,
transpose_out_2
,
fmha_ref_compute
.
ComputeForward
(
qk_out
,
src_mask_out
,
softmax
_out
,
*
qkv_bias_out
,
cache_kv
,
src_mask
,
transpose_out_2
,
cache_kv
_out
,
attn_dropout_mask_out
,
attn_dropout
_out
,
qk_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask
_out
,
qktv_out
,
fmha_out
);
attn_dropout_out
,
qktv_out
,
fmha_out
);
}
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// fmha_out: [batch_size, seq_len, num_head, head_dim]
...
@@ -202,6 +242,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -202,6 +242,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// out_linear_out: [batch_size, seq_len, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute
.
ComputeForward
(
out_linear_weight
,
fmha_out
,
nullptr
,
out_linear_compute
.
ComputeForward
(
out_linear_weight
,
fmha_out
,
nullptr
,
out_linear_out
,
nullptr
);
out_linear_out
,
nullptr
);
// tensor model parallel
AllReduce
<
T
>
(
*
out_linear_out
,
ring_id
,
ctx
.
cuda_device_context
());
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
// output = (residual + dropout(input + bias))
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper
.
ResidualDropoutBias
(
fused_dropout_layernorm_helper
.
ResidualDropoutBias
(
...
@@ -244,6 +287,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -244,6 +287,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
auto
*
seed_1
=
ctx
.
HasInput
(
"Seed1"
)
?
ctx
.
Input
<
Tensor
>
(
"Seed1"
)
:
nullptr
;
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
bool
is_fix_seed_1
=
ctx
.
Attr
<
bool
>
(
"attn_dropout_fix_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
seed_val_1
=
ctx
.
Attr
<
int
>
(
"attn_dropout_seed"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// get inputs.
// get inputs.
auto
*
d_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
d_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
...
@@ -399,9 +443,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -399,9 +443,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
transA
=
false
;
transA
=
false
;
transB
=
false
;
transB
=
false
;
bool
compute_bias
=
false
;
bool
compute_bias
=
false
;
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
auto
out_linear_compute
=
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
output_size
,
in
put_size
,
compute_bias
);
input_size
,
out
put_size
,
compute_bias
);
DropoutParam
dropout_param2
(
ctx
,
0
);
DropoutParam
dropout_param2
(
ctx
,
0
);
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
FusedDropoutLayerNormHelper
<
T
,
uint8_t
>
fused_dropout_layernorm_helper
(
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
...
@@ -475,6 +520,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -475,6 +520,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_out
,
d_ln_out
,
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_out
,
d_ln_out
,
d_qkv_weight
,
d_qkv_bias
);
d_qkv_weight
,
d_qkv_bias
);
}
}
// tensor model parallel
AllReduce
<
T
>
(
*
d_ln_out
,
ring_id
,
ctx
.
cuda_device_context
());
layer_norm_compute
.
ComputeBackward
(
x_data
,
d_ln_out_data
,
ln_scale_data
,
layer_norm_compute
.
ComputeBackward
(
x_data
,
d_ln_out_data
,
ln_scale_data
,
ln_mean_data
,
ln_var_data
,
d_x_data
,
ln_mean_data
,
ln_var_data
,
d_x_data
,
d_ln_scale_data
,
d_ln_bias_data
);
d_ln_scale_data
,
d_ln_bias_data
);
...
@@ -486,6 +533,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -486,6 +533,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_out
,
d_x
,
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_out
,
d_x
,
d_qkv_weight
,
d_qkv_bias
);
d_qkv_weight
,
d_qkv_bias
);
}
}
// tensor model parallel
AllReduce
<
T
>
(
*
d_x
,
ring_id
,
ctx
.
cuda_device_context
());
}
}
// gradient accumulation
// gradient accumulation
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
const
Tensor
*>
ins
;
...
...
paddle/fluid/pybind/op_function_generator.h
浏览文件 @
1882c496
...
@@ -30,8 +30,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
...
@@ -30,8 +30,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"bincount"
,
{
"X"
,
"Weights"
}},
{
"bincount"
,
{
"X"
,
"Weights"
}},
{
"fused_attention"
,
{
"fused_attention"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"
SrcMask"
,
"OutLinearW
"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"
CacheKV"
,
"SrcMask
"
,
"OutLinearBias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
"OutLinear
W"
,
"OutLinear
Bias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
...
@@ -104,11 +104,16 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
...
@@ -104,11 +104,16 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{
"batch_norm"
,
{
"batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
"ReserveSpace"
}},
{
"fused_attention"
,
{
"fused_attention"
,
{
"LnMean"
,
"LnVariance"
,
{
"LnMean"
,
"LnVariance"
,
"LnOut"
,
"QKVOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"LnOut"
,
"QKVOut"
,
"QKOut"
,
"QKTVOut"
,
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"QKOut"
,
"QKTVOut"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"Y"
}},
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"CacheKVOut"
,
"Y"
}},
{
"sync_batch_norm"
,
{
"sync_batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
"ReserveSpace"
}},
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
1882c496
...
@@ -24,6 +24,7 @@ list(APPEND DIST_TEST_OPS test_pipeline)
...
@@ -24,6 +24,7 @@ list(APPEND DIST_TEST_OPS test_pipeline)
list
(
APPEND DIST_TEST_OPS test_ir_pass_pipeline
)
list
(
APPEND DIST_TEST_OPS test_ir_pass_pipeline
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height
)
...
@@ -1155,6 +1156,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
...
@@ -1155,6 +1156,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties
(
test_ir_pass_pipeline PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_ir_pass_pipeline PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_static_model_parallel PROPERTIES TIMEOUT 240
)
set_tests_properties
(
test_static_model_parallel PROPERTIES TIMEOUT 240
)
set_tests_properties
(
test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_split_embedding
set_tests_properties
(
test_collective_split_embedding
test_collective_split_embedding_none_divisible
test_collective_split_embedding_none_divisible
test_collective_split_row_linear
test_collective_split_row_linear
...
...
python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py
0 → 100644
浏览文件 @
1882c496
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
import
paddle.distributed.fleet
as
fleet
import
paddle.incubate.nn.functional
as
incubate_f
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_dtype
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid
import
core
from
paddle.nn.initializer
import
Constant
paddle
.
enable_static
()
def
_set_var_distributed
(
var
):
if
var
is
None
:
return
var
.
is_distributed
=
True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block
=
paddle
.
static
.
default_startup_program
().
current_block
()
main_block
=
paddle
.
static
.
default_main_program
().
current_block
()
startup_block
.
_find_var_recursive
(
var
.
name
).
is_distributed
=
True
main_block
.
_find_var_recursive
(
var
.
name
).
is_distributed
=
True
class
ParallelFusedMultiHeadAttention
(
Layer
):
def
__init__
(
self
,
embed_dim
,
num_heads
,
dropout_rate
=
0.5
,
attn_dropout_rate
=
0.5
,
kdim
=
None
,
vdim
=
None
,
normalize_before
=
False
,
need_weights
=
False
,
qkv_weight_attr
=
None
,
qkv_bias_attr
=
None
,
linear_weight_attr
=
None
,
linear_bias_attr
=
None
,
pre_ln_scale_attr
=
None
,
pre_ln_bias_attr
=
None
,
ln_scale_attr
=
None
,
ln_bias_attr
=
None
,
epsilon
=
1e-5
,
nranks
=
1
,
ring_id
=-
1
,
name
=
None
):
super
(
ParallelFusedMultiHeadAttention
,
self
).
__init__
()
assert
embed_dim
>
0
,
(
"Expected embed_dim to be greater than 0, "
"but recieved {}"
.
format
(
embed_dim
))
assert
num_heads
>
0
,
(
"Expected nhead to be greater than 0, "
"but recieved {}"
.
format
(
num_heads
))
self
.
normalize_before
=
normalize_before
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_epsilon
=
epsilon
self
.
_ring_id
=
ring_id
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
head_dim
=
embed_dim
//
num_heads
self
.
kdim
=
kdim
self
.
vdim
=
vdim
self
.
need_weights
=
need_weights
assert
self
.
head_dim
*
num_heads
==
embed_dim
,
"embed_dim must be divisible by num_heads"
assert
need_weights
==
False
,
"Only support need_weight is False now."
# tensor model parallel
assert
num_heads
%
nranks
==
0
num_heads
=
num_heads
//
nranks
self
.
qkv_weight
=
self
.
create_parameter
(
shape
=
[
3
,
num_heads
,
self
.
head_dim
,
embed_dim
],
attr
=
qkv_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
qkv_bias
=
self
.
create_parameter
(
shape
=
[
3
,
num_heads
,
self
.
head_dim
],
attr
=
qkv_bias_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
self
.
linear_weight
=
self
.
create_parameter
(
shape
=
[
num_heads
*
self
.
head_dim
,
embed_dim
],
attr
=
linear_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
linear_bias
=
self
.
create_parameter
(
shape
=
[
embed_dim
],
attr
=
linear_bias_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
# tensor model parallel
if
nranks
>
1
:
assert
ring_id
!=
-
1
# column parallel
_set_var_distributed
(
self
.
qkv_weight
)
_set_var_distributed
(
self
.
qkv_bias
)
# row parallel
_set_var_distributed
(
self
.
linear_weight
)
if
normalize_before
:
self
.
pre_ln_scale
=
self
.
create_parameter
(
attr
=
pre_ln_scale_attr
,
shape
=
[
embed_dim
],
default_initializer
=
Constant
(
value
=
1.0
))
self
.
pre_ln_bias
=
self
.
create_parameter
(
attr
=
pre_ln_bias_attr
,
shape
=
[
embed_dim
],
is_bias
=
True
)
self
.
ln_scale
=
None
self
.
ln_bias
=
None
else
:
self
.
pre_ln_scale
=
None
self
.
pre_ln_bias
=
None
self
.
ln_scale
=
self
.
create_parameter
(
attr
=
ln_scale_attr
,
shape
=
[
embed_dim
],
default_initializer
=
Constant
(
value
=
1.0
))
self
.
ln_bias
=
self
.
create_parameter
(
attr
=
ln_bias_attr
,
shape
=
[
embed_dim
],
is_bias
=
True
)
self
.
dropout_rate
=
dropout_rate
self
.
attn_dropout_rate
=
attn_dropout_rate
self
.
name
=
name
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
attn_mask
=
None
,
cache
=
None
):
out
=
incubate_f
.
fused_multi_head_attention
(
x
=
query
,
qkv_weight
=
self
.
qkv_weight
,
linear_weight
=
self
.
linear_weight
,
pre_layer_norm
=
self
.
normalize_before
,
pre_ln_scale
=
self
.
pre_ln_scale
,
pre_ln_bias
=
self
.
pre_ln_bias
,
ln_scale
=
self
.
ln_scale
,
ln_bias
=
self
.
ln_bias
,
pre_ln_epsilon
=
self
.
_epsilon
,
qkv_bias
=
self
.
qkv_bias
,
linear_bias
=
self
.
linear_bias
,
attn_mask
=
attn_mask
,
dropout_rate
=
self
.
dropout_rate
,
attn_dropout_rate
=
self
.
attn_dropout_rate
,
ln_epsilon
=
self
.
_epsilon
,
training
=
self
.
training
,
ring_id
=
self
.
_ring_id
,
name
=
self
.
name
)
return
out
def
get_param_attr
(
weight
,
bias
):
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
weight
))
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
bias
))
return
weight_attr
,
bias_attr
DTYPE
=
"float32"
MODEL_PARALLEL_SIZE
=
2
n_head
=
2
*
MODEL_PARALLEL_SIZE
d_key
=
4
hidden
=
n_head
*
d_key
def
create_model
(
data
,
rank
):
np
.
random
.
seed
(
2021
)
pre_ln_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
pre_ln_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
qkv_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
n_head
,
d_key
,
hidden
)).
astype
(
DTYPE
)
qkv_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
n_head
,
d_key
)).
astype
(
DTYPE
)
linear_w
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
n_head
*
d_key
,
hidden
)).
astype
(
DTYPE
)
linear_b
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
hidden
,
)).
astype
(
DTYPE
)
data
.
stop_gradient
=
False
if
rank
is
not
None
:
start
=
0
if
rank
==
0
else
n_head
//
MODEL_PARALLEL_SIZE
end
=
start
+
n_head
//
MODEL_PARALLEL_SIZE
col_qkv_w
=
qkv_w
[:,
start
:
end
,
:,
:]
col_qkv_b
=
qkv_b
[:,
start
:
end
,
:]
row_linear_w
=
linear_w
[(
start
*
d_key
):(
end
*
d_key
),
:]
pre_ln_w_attr
,
pre_ln_b_attr
=
get_param_attr
(
pre_ln_w
,
pre_ln_b
)
qkv_w_attr
,
qkv_b_attr
=
get_param_attr
(
col_qkv_w
,
col_qkv_b
)
linear_w_attr
,
linear_b_attr
=
get_param_attr
(
row_linear_w
,
linear_b
)
attn
=
ParallelFusedMultiHeadAttention
(
hidden
,
n_head
,
dropout_rate
=
0.0
,
attn_dropout_rate
=
0.0
,
normalize_before
=
False
,
qkv_weight_attr
=
qkv_w_attr
,
qkv_bias_attr
=
qkv_b_attr
,
linear_weight_attr
=
linear_w_attr
,
linear_bias_attr
=
linear_b_attr
,
pre_ln_scale_attr
=
pre_ln_w_attr
,
pre_ln_bias_attr
=
pre_ln_b_attr
,
ln_scale_attr
=
pre_ln_w_attr
,
ln_bias_attr
=
pre_ln_b_attr
,
nranks
=
MODEL_PARALLEL_SIZE
,
ring_id
=
0
)
result
=
attn
(
data
)
else
:
pre_ln_w_attr
,
pre_ln_b_attr
=
get_param_attr
(
pre_ln_w
,
pre_ln_b
)
qkv_w_attr
,
qkv_b_attr
=
get_param_attr
(
qkv_w
,
qkv_b
)
linear_w_attr
,
linear_b_attr
=
get_param_attr
(
linear_w
,
linear_b
)
attn
=
ParallelFusedMultiHeadAttention
(
hidden
,
n_head
,
dropout_rate
=
0.0
,
attn_dropout_rate
=
0.0
,
normalize_before
=
False
,
qkv_weight_attr
=
qkv_w_attr
,
qkv_bias_attr
=
qkv_b_attr
,
linear_weight_attr
=
linear_w_attr
,
linear_bias_attr
=
linear_b_attr
,
pre_ln_scale_attr
=
pre_ln_w_attr
,
pre_ln_bias_attr
=
pre_ln_b_attr
,
ln_scale_attr
=
pre_ln_w_attr
,
ln_bias_attr
=
pre_ln_b_attr
)
result
=
attn
(
data
)
predict
=
paddle
.
sum
(
result
)
return
predict
class
TestModelParallel
(
TestDistRunnerBase
):
def
get_model
(
self
,
batch_size
=
2
,
use_dgc
=
False
,
dist_strategy
=
None
):
# Input data
seq_len
=
2
data_in
=
fluid
.
data
(
name
=
'data_in'
,
shape
=
[
batch_size
,
seq_len
,
hidden
],
dtype
=
DTYPE
)
if
dist_strategy
:
data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
data_in
],
capacity
=
64
,
use_double_buffer
=
False
,
iterable
=
False
)
if
dist_strategy
:
fleet
.
init
(
is_collective
=
True
)
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
tensor_parallel
=
True
strategy
.
tensor_parallel_configs
=
{
'tensor_parallel_degree'
:
2
}
rank
=
fleet
.
worker_index
()
if
dist_strategy
else
None
avg_cost
=
create_model
(
data_in
,
rank
)
opt
=
fluid
.
optimizer
.
SGD
(
0.1
)
if
dist_strategy
:
dist_opt
=
fleet
.
distributed_optimizer
(
optimizer
=
opt
,
strategy
=
strategy
)
dist_opt
.
minimize
(
avg_cost
)
else
:
opt
.
minimize
(
avg_cost
)
def
gen_data
():
np
.
random
.
seed
(
2021
)
while
True
:
data
=
[
np
.
random
.
random
([
seq_len
,
hidden
]).
astype
(
DTYPE
)]
yield
data
train_reader
=
paddle
.
batch
(
gen_data
,
batch_size
=
batch_size
)
if
dist_strategy
:
return
None
,
avg_cost
,
train_reader
,
None
,
None
,
None
,
data_loader
else
:
return
None
,
avg_cost
,
train_reader
,
None
,
None
,
None
if
__name__
==
"__main__"
:
runtime_main
(
TestModelParallel
)
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
1882c496
...
@@ -70,10 +70,12 @@ class TestFusedAttentionOp(OpTest):
...
@@ -70,10 +70,12 @@ class TestFusedAttentionOp(OpTest):
self
.
attn_mask_type
=
np
.
float64
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
has_attn_mask
=
True
self
.
has_cache_kv
=
False
self
.
training
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
query_length
=
128
self
.
cache_length
=
128
self
.
head_dim
=
64
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
...
@@ -88,10 +90,22 @@ class TestFusedAttentionOp(OpTest):
...
@@ -88,10 +90,22 @@ class TestFusedAttentionOp(OpTest):
def
generate_input_data
(
self
):
def
generate_input_data
(
self
):
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
).
astype
(
self
.
x_type
)
self
.
embed_dim
).
astype
(
self
.
x_type
)
out_seq_len
=
self
.
key_length
if
self
.
has_cache_kv
:
assert
self
.
training
is
False
,
ValueError
(
'cache_kv can only used in inference'
)
self
.
cache_kv
=
np
.
random
.
rand
(
2
,
self
.
batch_size
,
self
.
num_heads
,
self
.
cache_length
,
self
.
head_dim
).
astype
(
self
.
x_type
)
out_seq_len
+=
self
.
cache_length
else
:
self
.
cache_kv
=
None
if
self
.
has_attn_mask
:
if
self
.
has_attn_mask
:
# [B, n_head, seq_len, out_seq_len]
self
.
attn_mask
=
np
.
ones
(
self
.
attn_mask
=
np
.
ones
(
(
self
.
batch_size
,
self
.
num_heads
,
self
.
query_length
,
(
self
.
batch_size
,
self
.
num_heads
,
self
.
query_length
,
self
.
key_length
),
out_seq_len
),
dtype
=
self
.
attn_mask_type
)
dtype
=
self
.
attn_mask_type
)
if
self
.
attn_mask_type
==
np
.
int64
:
if
self
.
attn_mask_type
==
np
.
int64
:
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
...
@@ -110,6 +124,11 @@ class TestFusedAttentionOp(OpTest):
...
@@ -110,6 +124,11 @@ class TestFusedAttentionOp(OpTest):
def
GetBaselineOut
(
self
):
def
GetBaselineOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
tensor_query
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
tensor_query
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
cache_kv
=
None
if
self
.
has_cache_kv
:
cache_kv
=
paddle
.
to_tensor
(
self
.
cache_kv
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
else
:
...
@@ -130,6 +149,18 @@ class TestFusedAttentionOp(OpTest):
...
@@ -130,6 +149,18 @@ class TestFusedAttentionOp(OpTest):
v
=
tensor
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
v
=
tensor
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
v_out
=
tensor
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
v_out
=
tensor
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
if
self
.
has_cache_kv
:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k
,
cache_v
=
paddle
.
split
(
cache_kv
,
2
)
cache_k
=
paddle
.
squeeze
(
cache_k
,
axis
=
0
)
cache_v
=
paddle
.
squeeze
(
cache_v
,
axis
=
0
)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
k_out
=
paddle
.
concat
([
cache_k
,
k_out
],
axis
=-
2
)
v_out
=
paddle
.
concat
([
cache_v
,
v_out
],
axis
=-
2
)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out
=
layers
.
matmul
(
qk_out
=
layers
.
matmul
(
x
=
q_out
,
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
x
=
q_out
,
y
=
k_out
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
...
@@ -146,6 +177,8 @@ class TestFusedAttentionOp(OpTest):
...
@@ -146,6 +177,8 @@ class TestFusedAttentionOp(OpTest):
self
.
dropout_prob
,
self
.
dropout_prob
,
training
=
self
.
training
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
mode
=
"upscale_in_train"
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out
=
tensor
.
matmul
(
dropout_out
,
v_out
)
qktv_out
=
tensor
.
matmul
(
dropout_out
,
v_out
)
else
:
else
:
qktv_out
=
tensor
.
matmul
(
softmax_out
,
v_out
)
qktv_out
=
tensor
.
matmul
(
softmax_out
,
v_out
)
...
@@ -160,6 +193,10 @@ class TestFusedAttentionOp(OpTest):
...
@@ -160,6 +193,10 @@ class TestFusedAttentionOp(OpTest):
final_out
=
self
.
norm1
(
residual_out
)
final_out
=
self
.
norm1
(
residual_out
)
else
:
else
:
final_out
=
residual_out
final_out
=
residual_out
if
self
.
has_cache_kv
:
return
final_out
paddle
.
autograd
.
backward
(
paddle
.
autograd
.
backward
(
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
return
final_out
,
tensor_query
.
grad
return
final_out
,
tensor_query
.
grad
...
@@ -206,6 +243,9 @@ class TestFusedAttentionOp(OpTest):
...
@@ -206,6 +243,9 @@ class TestFusedAttentionOp(OpTest):
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
cache_kv
=
None
if
self
.
has_cache_kv
:
cache_kv
=
paddle
.
to_tensor
(
self
.
cache_kv
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
else
:
...
@@ -219,8 +259,12 @@ class TestFusedAttentionOp(OpTest):
...
@@ -219,8 +259,12 @@ class TestFusedAttentionOp(OpTest):
final_out
=
incubate_f
.
fused_multi_head_attention
(
final_out
=
incubate_f
.
fused_multi_head_attention
(
x
,
qkv_weight_tensor
,
out_linear_weight
,
self
.
pre_layer_norm
,
x
,
qkv_weight_tensor
,
out_linear_weight
,
self
.
pre_layer_norm
,
ln1_scale
,
ln1_bias
,
ln2_scale
,
ln2_bias
,
epsilon
,
qkv_bias_tensor
,
ln1_scale
,
ln1_bias
,
ln2_scale
,
ln2_bias
,
epsilon
,
qkv_bias_tensor
,
out_linear_bias
,
attn_mask
,
self
.
dropout_prob
,
out_linear_bias
,
cache_kv
,
attn_mask
,
self
.
dropout_prob
,
self
.
attn_dropout_prob
,
ln2_epsilon
)
self
.
attn_dropout_prob
,
ln2_epsilon
)
if
self
.
has_cache_kv
:
return
final_out
[
0
],
final_out
[
1
]
paddle
.
autograd
.
backward
(
paddle
.
autograd
.
backward
(
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
return
final_out
,
x
.
grad
return
final_out
,
x
.
grad
...
@@ -236,114 +280,27 @@ class TestFusedAttentionOp(OpTest):
...
@@ -236,114 +280,27 @@ class TestFusedAttentionOp(OpTest):
class
TestFusedAttentionOpBiasIsNone
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpBiasIsNone
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
super
().
config
()
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
False
self
.
bias_attr
=
False
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
super
().
config
()
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
super
().
config
()
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
False
self
.
has_attn_mask
=
False
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
x_type
=
np
.
float16
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
...
@@ -354,5 +311,21 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
...
@@ -354,5 +311,21 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
class
TestFusedAttentionOpCacheKV
(
TestFusedAttentionOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
training
=
False
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
def
test_fused_attention_op
(
self
):
with
paddle
.
no_grad
():
final_out_ref
=
self
.
GetBaselineOut
()
final_out
,
cache_kv_out
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py
0 → 100644
浏览文件 @
1882c496
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
test_dist_base
import
TestDistBase
import
os
import
paddle
paddle
.
enable_static
()
flag_name
=
os
.
path
.
splitext
(
__file__
)[
0
]
class
TestStaticModelParallel
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_use_reduce
=
False
self
.
_use_reader_alloc
=
False
self
.
_nccl_comm_num
=
1
self
.
_pipeline_mode
=
True
def
test_dist_static_model_parallel_fused_feedforward
(
self
):
import
paddle.fluid
as
fluid
if
fluid
.
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
"static_model_parallel_fused_attention.py"
,
delta
=
1e-5
,
check_error_log
=
True
,
log_name
=
flag_name
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/incubate/nn/functional/fused_transformer.py
浏览文件 @
1882c496
...
@@ -223,12 +223,14 @@ def fused_multi_head_attention(x,
...
@@ -223,12 +223,14 @@ def fused_multi_head_attention(x,
pre_ln_epsilon
=
1e-05
,
pre_ln_epsilon
=
1e-05
,
qkv_bias
=
None
,
qkv_bias
=
None
,
linear_bias
=
None
,
linear_bias
=
None
,
cache_kv
=
None
,
attn_mask
=
None
,
attn_mask
=
None
,
dropout_rate
=
0.5
,
dropout_rate
=
0.5
,
attn_dropout_rate
=
0.5
,
attn_dropout_rate
=
0.5
,
ln_epsilon
=
1e-05
,
ln_epsilon
=
1e-05
,
training
=
True
,
training
=
True
,
mode
=
'upscale_in_train'
,
mode
=
'upscale_in_train'
,
ring_id
=-
1
,
name
=
None
):
name
=
None
):
r
"""
r
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Attention mapps queries and a set of key-value pairs to outputs, and
...
@@ -242,8 +244,8 @@ def fused_multi_head_attention(x,
...
@@ -242,8 +244,8 @@ def fused_multi_head_attention(x,
out = layer_norm(x)
out = layer_norm(x)
out = linear(out) + qkv) + bias
out = linear(out) + qkv) + bias
else:
else:
out = linear(x) + bias
out = linear(x) + bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
# extract q, k and v from out.
q = out[0:1,::]
q = out[0:1,::]
k = out[1:2,::]
k = out[1:2,::]
...
@@ -257,8 +259,8 @@ def fused_multi_head_attention(x,
...
@@ -257,8 +259,8 @@ def fused_multi_head_attention(x,
out = out_linear(out)
out = out_linear(out)
if pre_layer_norm:
if pre_layer_norm:
out = x + dropout(linear_bias + out)
out = x + dropout(linear_bias + out)
else:
else:
out = layer_norm(x + dropout(linear_bias + out))
out = layer_norm(x + dropout(linear_bias + out))
Parameters:
Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
...
@@ -276,6 +278,7 @@ def fused_multi_head_attention(x,
...
@@ -276,6 +278,7 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
...
@@ -303,6 +306,7 @@ def fused_multi_head_attention(x,
...
@@ -303,6 +306,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask
- train: out = input * mask
- inference: out = input * (1.0 - p)
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Returns:
...
@@ -333,7 +337,7 @@ def fused_multi_head_attention(x,
...
@@ -333,7 +337,7 @@ def fused_multi_head_attention(x,
output = F.fused_multi_head_attention(
output = F.fused_multi_head_attention(
x, qkv_weight, linear_weight, False,
x, qkv_weight, linear_weight, False,
None, None, None, None, 1e-5, qkv_bias,
None, None, None, None, 1e-5, qkv_bias,
linear_bias, attn_mask)
linear_bias,
None,
attn_mask)
# [2, 4, 128]
# [2, 4, 128]
print(output.shape)
print(output.shape)
"""
"""
...
@@ -359,17 +363,20 @@ def fused_multi_head_attention(x,
...
@@ -359,17 +363,20 @@ def fused_multi_head_attention(x,
assert
qkv_weight
.
shape
[
1
]
*
qkv_weight
.
shape
[
2
]
==
qkv_weight
.
shape
[
assert
qkv_weight
.
shape
[
1
]
*
qkv_weight
.
shape
[
2
]
==
qkv_weight
.
shape
[
3
],
"embed_dim must be divisible by num_heads."
3
],
"embed_dim must be divisible by num_heads."
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
final_out
=
_C_ops
.
fused_attention
(
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
cache_kv_out
,
final_out
=
_C_ops
.
fused_attention
(
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
attn_mask
,
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
cache_kv
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
'pre_layer_norm'
,
attn_mask
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
pre_layer_norm
,
'epsilon'
,
pre_ln_epsilon
,
'dropout_rate'
,
'pre_layer_norm'
,
pre_layer_norm
,
'epsilon'
,
pre_ln_epsilon
,
dropout_rate
,
'attn_dropout_rate'
,
attn_dropout_rate
,
'ln_epsilon'
,
'dropout_rate'
,
dropout_rate
,
'attn_dropout_rate'
,
ln_epsilon
,
'attn_dropout_is_test'
,
not
training
,
'dropout_is_test'
,
attn_dropout_rate
,
'ln_epsilon'
,
ln_epsilon
,
'attn_dropout_is_test'
,
not
training
,
'attn_dropout_fix_seed'
,
seed
is
not
None
,
not
training
,
'dropout_is_test'
,
not
training
,
'dropout_fix_seed'
,
seed
is
not
None
,
'attn_dropout_seed'
,
seed
'attn_dropout_fix_seed'
,
seed
is
not
None
,
'dropout_fix_seed'
,
seed
is
not
None
,
'attn_dropout_seed'
,
seed
if
seed
is
not
None
else
0
,
'dropout_seed'
,
seed
if
seed
is
not
None
else
0
,
'dropout_seed'
,
seed
if
seed
is
not
None
else
0
,
'attn_dropout_implementation'
,
mode
,
if
seed
is
not
None
else
0
,
'attn_dropout_implementation'
,
mode
,
'dropout_implementation'
,
mode
)
'dropout_implementation'
,
mode
,
'ring_id'
,
ring_id
)
if
cache_kv
is
not
None
:
return
final_out
,
cache_kv_out
return
final_out
return
final_out
else
:
else
:
helper
=
LayerHelper
(
'fused_multi_head_attention'
,
**
locals
())
helper
=
LayerHelper
(
'fused_multi_head_attention'
,
**
locals
())
...
@@ -398,6 +405,7 @@ def fused_multi_head_attention(x,
...
@@ -398,6 +405,7 @@ def fused_multi_head_attention(x,
inputs
[
'Ln2Scale'
]
=
[
ln_scale
]
inputs
[
'Ln2Scale'
]
=
[
ln_scale
]
if
ln_bias
:
if
ln_bias
:
inputs
[
'Ln2Bias'
]
=
[
ln_bias
]
inputs
[
'Ln2Bias'
]
=
[
ln_bias
]
if
cache_kv
:
inputs
[
'CacheKV'
]
=
[
cache_kv
]
if
(
seed
is
None
or
seed
==
0
)
and
helper
.
main_program
.
random_seed
!=
0
:
if
(
seed
is
None
or
seed
==
0
)
and
helper
.
main_program
.
random_seed
!=
0
:
seed
=
helper
.
main_program
.
random_seed
seed
=
helper
.
main_program
.
random_seed
...
@@ -417,6 +425,7 @@ def fused_multi_head_attention(x,
...
@@ -417,6 +425,7 @@ def fused_multi_head_attention(x,
'dropout_seed'
:
seed
if
seed
is
not
None
else
0
,
'dropout_seed'
:
seed
if
seed
is
not
None
else
0
,
'attn_dropout_implementation'
:
mode
,
'attn_dropout_implementation'
:
mode
,
'dropout_implementation'
:
mode
,
'dropout_implementation'
:
mode
,
'ring_id'
:
ring_id
}
}
# set outputs
# set outputs
...
@@ -449,6 +458,7 @@ def fused_multi_head_attention(x,
...
@@ -449,6 +458,7 @@ def fused_multi_head_attention(x,
bias_dropout_residual_out
=
helper
.
create_variable_for_type_inference
(
bias_dropout_residual_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
dtype
=
dtype
)
final_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
final_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
cache_kv_out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'fused_attention'
,
type
=
'fused_attention'
,
...
@@ -472,7 +482,9 @@ def fused_multi_head_attention(x,
...
@@ -472,7 +482,9 @@ def fused_multi_head_attention(x,
"Ln2Mean"
:
ln_mean_out
,
"Ln2Mean"
:
ln_mean_out
,
"Ln2Variance"
:
ln_variance_out
,
"Ln2Variance"
:
ln_variance_out
,
"BiasDropoutResidualOut"
:
bias_dropout_residual_out
,
"BiasDropoutResidualOut"
:
bias_dropout_residual_out
,
'Y'
:
final_out
'Y'
:
final_out
,
'CacheKVOut'
:
cache_kv_out
},
},
attrs
=
attrs
)
attrs
=
attrs
)
return
final_out
return
(
final_out
,
cache_kv_out
)
if
cache_kv
else
final_out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录