Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4640f4be
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4640f4be
编写于
3月 21, 2023
作者:
S
ShenLiang
提交者:
GitHub
3月 21, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[OPT] FlashAttention && ModelParallel (#51617)
* fix flash_attention * Update mp_layers.py
上级
f82da79c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
55 addition
and
46 deletion
+55
-46
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+7
-10
paddle/phi/kernels/gpu/flash_attn_kernel.cu
paddle/phi/kernels/gpu/flash_attn_kernel.cu
+9
-8
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+37
-10
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+2
-18
未找到文件。
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
浏览文件 @
4640f4be
...
...
@@ -67,10 +67,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
bool
zero_tensors
=
false
;
std
::
vector
<
int64_t
>
seed_offset_vec
;
phi
::
TensorToVector
<
int64_t
>
(
seed_offset
,
ctx
,
&
seed_offset_vec
);
uint64_t
seed
=
seed_offset_vec
[
0
];
uint64_t
offset
=
seed_offset_vec
[
1
];
const
int64_t
*
seed_offset_data
=
seed_offset
.
data
<
int64_t
>
();
uint64_t
seed
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
0
]);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
1
]);
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
DenseTensor
dsoftmax
=
Empty
<
float
>
(
ctx
,
{
batch_size
,
num_heads
,
seq_len_q
});
...
...
@@ -188,12 +187,10 @@ void FlashAttnGradKernel(const Context& ctx,
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
v_t_s
.
ShareDataWith
(
v
).
Resize
({
total_k
,
num_heads
,
head_size
});
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
...
...
paddle/phi/kernels/gpu/flash_attn_kernel.cu
浏览文件 @
4640f4be
...
...
@@ -75,11 +75,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
auto
gen
=
ctx
.
GetGenerator
();
uint64_t
inc
=
batch_size
*
num_heads
*
32
;
auto
seed_offset_pair
=
gen
->
IncrementOffset
(
inc
);
uint64_t
seed
=
seed_offset_pair
.
first
;
uint64_t
offset
=
seed_offset_pair
.
second
;
std
::
vector
<
int64_t
>
seed_offset_vec
{
int64_t
(
seed
),
int64_t
(
offset
)};
phi
::
TensorFromVector
<
int64_t
>
(
seed_offset_vec
,
ctx
,
seed_offset
);
seed_offset
->
Resize
({
2
});
auto
*
seed_offset_data
=
ctx
.
template
HostAlloc
<
int64_t
>(
seed_offset
);
seed_offset_data
[
0
]
=
static_cast
<
int64_t
>
(
seed
);
seed_offset_data
[
1
]
=
static_cast
<
int64_t
>
(
offset
);
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
...
...
@@ -210,12 +213,10 @@ void FlashAttnKernel(const Context& ctx,
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
v_t_s
.
ShareDataWith
(
v
).
Resize
({
total_k
,
num_heads
,
head_size
});
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
...
...
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
浏览文件 @
4640f4be
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
paddle
from
paddle.autograd
import
PyLayer
from
paddle.fluid
import
core
from
paddle.nn
import
functional
as
F
...
...
@@ -328,6 +329,17 @@ class ColumnParallelLinear(paddle.nn.Layer):
return
output
class
MPScale
(
PyLayer
):
@
staticmethod
def
forward
(
ctx
,
x
,
mp_degree
):
out
=
paddle
.
scale
(
x
,
1.0
/
mp_degree
)
return
out
@
staticmethod
def
backward
(
ctx
,
dout
):
return
dout
class
RowParallelLinear
(
paddle
.
nn
.
Layer
):
"""Linear layer with mp parallelized(row).
this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer.
...
...
@@ -467,6 +479,7 @@ class RowParallelLinear(paddle.nn.Layer):
from
paddle.incubate.nn.functional
import
fused_linear
self
.
linear
=
fused_linear
self
.
fuse_matmul_bias
=
fuse_matmul_bias
def
forward
(
self
,
x
):
if
self
.
input_is_parallel
or
(
not
self
.
is_mp
):
...
...
@@ -476,16 +489,30 @@ class RowParallelLinear(paddle.nn.Layer):
input_parallel
=
mp_ops
.
_c_split
(
x
,
group
=
self
.
model_parallel_group
)
if
self
.
is_mp
:
output_parallel
=
self
.
linear
(
input_parallel
,
self
.
weight
,
name
=
self
.
_name
)
output_
=
mp_ops
.
_mp_allreduce
(
output_parallel
,
group
=
self
.
model_parallel_group
,
use_calc_stream
=
True
,
use_model_parallel
=
True
,
)
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
if
self
.
fuse_matmul_bias
:
bias
=
MPScale
.
apply
(
self
.
bias
,
self
.
world_size
)
output_parallel
=
self
.
linear
(
input_parallel
,
self
.
weight
,
bias
,
name
=
self
.
_name
)
output
=
mp_ops
.
_mp_allreduce
(
output_parallel
,
group
=
self
.
model_parallel_group
,
use_calc_stream
=
True
,
use_model_parallel
=
True
,
)
else
:
output_parallel
=
self
.
linear
(
input_parallel
,
self
.
weight
,
name
=
self
.
_name
)
output_
=
mp_ops
.
_mp_allreduce
(
output_parallel
,
group
=
self
.
model_parallel_group
,
use_calc_stream
=
True
,
use_model_parallel
=
True
,
)
output
=
(
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
)
else
:
output
=
self
.
linear
(
input_parallel
,
self
.
weight
,
self
.
bias
,
name
=
self
.
_name
...
...
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
浏览文件 @
4640f4be
...
...
@@ -46,15 +46,7 @@ def _c_identity(tensor, group=None):
class
c_identity_eager
(
PyLayer
):
@
staticmethod
def
forward
(
ctx
,
tensor
):
return
_legacy_C_ops
.
c_identity
(
tensor
,
'use_calc_stream'
,
True
,
'ring_id'
,
group
.
id
,
'use_model_parallel'
,
True
,
)
return
tensor
@
staticmethod
def
backward
(
ctx
,
dy
):
...
...
@@ -257,15 +249,7 @@ def _mp_allreduce(
@
staticmethod
def
backward
(
ctx
,
dy
):
return
_legacy_C_ops
.
c_identity
(
dy
,
'use_calc_stream'
,
True
,
'ring_id'
,
ctx
.
ring_id
,
'use_model_parallel'
,
True
,
)
return
dy
return
mp_allreduce_eager
.
apply
(
tensor
,
group
,
use_calc_stream
,
use_model_parallel
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录