Commit 5ccc49e7 (unverified), authored on Jun 02, 2022 by Haohongxiang, committed by GitHub on Jun 02, 2022.
support eager dygraph in moe_layer (#43168)
Parent: 0fbf815c
Showing 4 changed files with 195 additions and 28 deletions (+195 -28).
python/paddle/distributed/collective.py (+1 -3)
python/paddle/distributed/parallel.py (+3 -0)
python/paddle/incubate/distributed/models/moe/moe_layer.py (+178 -21)
python/paddle/incubate/distributed/models/moe/utils.py (+13 -4)
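Taken together, the patch follows one pattern: every collective wrapper used by the MoE layer gains an EagerPyLayer twin, and each call site branches on in_dygraph_mode() to pick the eager or the legacy implementation. A minimal sketch of that dispatch, assuming nothing beyond the two imports the patch itself adds (select_and_apply is a made-up helper, not a name from the patch):

# Illustrative sketch of the dispatch pattern used throughout this commit.
# `legacy_op` / `eager_op` stand for the PyLayer / EagerPyLayer pairs below.
from paddle.autograd import PyLayer, EagerPyLayer  # noqa: F401 (both imported by the patch)
from paddle.fluid.framework import in_dygraph_mode


def select_and_apply(legacy_op, eager_op, *args):
    # eager dygraph runs the EagerPyLayer variant; old dygraph keeps PyLayer
    op = eager_op if in_dygraph_mode() else legacy_op
    return op.apply(*args)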
python/paddle/distributed/collective.py
@@ -405,9 +405,7 @@ def new_group(ranks=None, backend=None):
         # TODO(shenliang03): This is a temporary solution to solve the problem of
         # hang caused by tcp
-        tmp = paddle.to_tensor([1], dtype="int32")
-        paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True)
-        paddle.distributed.wait(tmp)
+        paddle.distributed.barrier(group=group)
         return group

     if not backend:
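The hunk above drops the temporary all_reduce-plus-wait workaround and synchronizes the freshly created group with paddle.distributed.barrier instead. A hedged usage sketch (assumes a multi-process job started with paddle.distributed.launch; nothing here comes from the patch beyond new_group and barrier):

# Hypothetical usage: create a group covering all ranks and synchronize it
# with the same barrier call the patched new_group now uses internally.
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=list(range(dist.get_world_size())))
dist.barrier(group=group)  # replaces all_reduce + wait on a dummy tensor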
python/paddle/distributed/parallel.py
@@ -19,6 +19,7 @@ from multiprocessing import Process # noqa: F401
 from multiprocessing import Manager  # noqa: F401
 import time
 import sys
+import paddle
 from paddle import compat as cpt

@@ -259,6 +260,8 @@ def init_parallel_env():
         _set_group_map_by_name(_default_group_name, group)
         _set_group_map(0, group)
         parallel_helper._set_parallel_ctx(True)
+
+        paddle.distributed.barrier(group=group)
         return group

     node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints])
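With these two hunks, init_parallel_env imports paddle at module level and issues a barrier on the default group before returning it, so all ranks leave initialization together. A minimal launch sketch (train.py is a hypothetical script; run it with python -m paddle.distributed.launch train.py on two or more devices):

# train.py -- illustrative only; assumes a multi-device environment.
import paddle
import paddle.distributed as dist


def main():
    dist.init_parallel_env()  # now ends with a barrier on the default group
    x = paddle.ones([2, 2]) * dist.get_rank()
    dist.all_reduce(x)        # safe: every rank finished initialization
    print(dist.get_rank(), x.numpy().tolist())


if __name__ == "__main__":
    main()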
python/paddle/incubate/distributed/models/moe/moe_layer.py
@@ -31,11 +31,12 @@ from paddle.distributed import alltoall, all_gather
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed import fleet
-from paddle.autograd import PyLayer
+from paddle.autograd import PyLayer, EagerPyLayer
 from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
 from .utils import count_by_gate
 from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
 from paddle import fluid
+from paddle.fluid.framework import in_dygraph_mode


 def _local_scatter(inp, pos):
@@ -63,17 +64,26 @@ def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
 def _all_gather(tensor, group=None, use_calc_stream=True):
     """
     The main difference with paddle.distributed.all_gather:
     no need to pass in tensor_list, the returned tensor is spliced
     """
     if group is not None and not group.is_member():
         return
-    ring_id = 0 if group is None else group.id
-    nranks = paddle.distributed.collective._get_global_group(
-    ).nranks if group is None else group.nranks
-    return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
-                                     'ring_id', ring_id, 'nranks', nranks)
+
+    if in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        tensor_shape = list(tensor.shape)
+        tensor_shape[0] *= group.nranks
+        out = paddle.empty(tensor_shape, tensor.dtype)
+        task = group.process_group.all_gather(tensor, out)
+        task.wait()
+        return out
+    else:
+        ring_id = 0 if group is None else group.id
+        nranks = paddle.distributed.collective._get_global_group(
+        ).nranks if group is None else group.nranks
+        return paddle._C_ops.c_allgather(tensor, 'use_calc_stream',
+                                         use_calc_stream, 'ring_id', ring_id,
+                                         'nranks', nranks)


 class MoEScatter(PyLayer):
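As the docstring in the hunk says, _all_gather differs from paddle.distributed.all_gather in that no tensor_list is passed in and the result comes back already spliced along dim 0. A rough equivalence written against the public API only (the shapes and the multi-rank setup are assumptions for illustration):

# What the helper returns, expressed with the public all_gather API.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
t = paddle.full([4, 8], float(dist.get_rank()))

tensor_list = []
dist.all_gather(tensor_list, t)               # public API fills a list
spliced = paddle.concat(tensor_list, axis=0)  # shape [4 * world_size, 8]
# `spliced` matches what the patched _all_gather(t) produces; in eager mode the
# helper instead pre-allocates the output with paddle.empty and waits on the
# task returned by group.process_group.all_gather.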
@@ -122,6 +132,52 @@ class MoEScatter(PyLayer):
         return grad_in, None, None, None


+class EagerMoEScatter(EagerPyLayer):
+    r"""
+    Scatter input samples from [batch x sequences] to contiguous alone experts.
+    If `world_size` is greater than 1, the samples will first be locally
+    scattered, and then exchanged across workers.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                inp,
+                pos,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+                group=None):
+        local_input_buf = _local_scatter(inp, pos)
+        if world_size > 1:
+            global_input_buf = global_scatter(local_input_buf,
+                                              local_expert_count,
+                                              global_expert_count,
+                                              group=group)
+        else:
+            global_input_buf = local_input_buf
+
+        ctx.moe_args = inp.shape[0], world_size, group
+
+        variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
+        return global_input_buf
+
+    @staticmethod
+    def backward(ctx, grad):
+        (pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
+        (inp_batch_size, world_size, group) = ctx.moe_args
+
+        if world_size > 1:
+            local_grad_in = global_gather(grad,
+                                          local_expert_count,
+                                          global_expert_count,
+                                          group=group)
+        else:
+            local_grad_in = grad
+        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
+        return grad_in, None, None, None
+
+
 class MoEGather(PyLayer):
     r"""
     Gather output samples from contiguous alone experts back to [batch x
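EagerMoEScatter mirrors the existing MoEScatter but derives from EagerPyLayer, using the same context protocol: tensors go through ctx.save_for_backward / ctx.saved_tensor, and non-tensor state rides on an attribute such as ctx.moe_args. A toy layer showing only that protocol (ScaleByCount is a made-up example, not part of the patch):

# Toy custom layer illustrating the save_for_backward / saved_tensor protocol
# shared by MoEScatter and EagerMoEScatter.
import paddle
from paddle.autograd import PyLayer


class ScaleByCount(PyLayer):

    @staticmethod
    def forward(ctx, x, count):
        ctx.save_for_backward(x)    # tensor state
        ctx.count = count           # non-tensor state, like ctx.moe_args above
        return x * count

    @staticmethod
    def backward(ctx, grad):
        (x, ) = ctx.saved_tensor()  # same accessor the Eager* classes call
        return grad * ctx.count     # one gradient per differentiable input


x = paddle.rand([3])
x.stop_gradient = False
y = ScaleByCount.apply(x, 2.0)
y.sum().backward()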
@@ -169,6 +225,53 @@ class MoEGather(PyLayer):
         return global_grad_out_buf, None, None, None


+class EagerMoEGather(EagerPyLayer):
+    r"""
+    Gather output samples from contiguous alone experts back to [batch x
+    sequences]. Works symmetrically with MoEScatter.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                global_output_buf,
+                pos,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+                group=None):
+        if world_size > 1:
+            local_output_buf = global_gather(global_output_buf,
+                                             local_expert_count,
+                                             global_expert_count,
+                                             group=group)
+        else:
+            local_output_buf = global_output_buf
+        output = _local_gather(local_output_buf,
+                               pos,
+                               local_batch_size,
+                               maybe_overlap=False)
+
+        ctx.moe_args = (global_output_buf.shape[0], world_size, group)
+        variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        pos, local_expert_count, global_expert_count = ctx.saved_tensor()
+        fwd_batch_size, world_size, group = ctx.moe_args
+        grad_out_buf = _local_scatter(grad_out, pos)
+        if world_size > 1:
+            global_grad_out_buf = global_scatter(grad_out_buf,
+                                                 local_expert_count,
+                                                 global_expert_count,
+                                                 group=group)
+        else:
+            global_grad_out_buf = grad_out_buf
+        return global_grad_out_buf, None, None, None
+
+
 class AllGather(PyLayer):
     r"""
     A wrapper for the All-Gather function to support auto-differentiation.
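EagerMoEGather is the inverse of EagerMoEScatter: the positions that _local_scatter used to reorder the batch are used by _local_gather to write rows back where they came from, so a gather after a scatter restores the original order. A self-contained, single-device sketch of that round trip (pos is an invented permutation; the real helpers also handle the multi-worker exchange):

# Conceptual round trip behind the local scatter/gather pair.
import paddle

x = paddle.arange(12, dtype="float32").reshape([4, 3])  # 4 tokens, d_model = 3
pos = paddle.to_tensor([2, 0, 3, 1])                     # expert-sorted order

scattered = paddle.index_select(x, pos, axis=0)          # roughly _local_scatter
restored = paddle.scatter(paddle.zeros_like(x), pos, scattered)  # roughly _local_gather
assert bool(paddle.all(restored == x))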
@@ -189,6 +292,26 @@ class AllGather(PyLayer):
                             grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0])


+class EagerAllGather(EagerPyLayer):
+    r"""
+    A wrapper for the All-Gather function to support auto-differentiation.
+    """
+
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        tensor_list = []
+        paddle.distributed.all_gather(tensor_list, inp, group=group)
+        output = paddle.concat(tensor_list, axis=0)
+        ctx.args = rank, inp.shape[0]
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        rank, dim0 = ctx.args
+        return paddle.slice(grad_out,
+                            axes=[0],
+                            starts=[rank * dim0],
+                            ends=[(rank + 1) * dim0])
+
+
 class Slice(PyLayer):
     r"""
     A wrapper for the Slice function to support auto-differentiation.
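EagerAllGather's forward concatenates the per-rank chunks along dim 0, so its backward only has to slice this rank's rows back out of the incoming gradient. The arithmetic in isolation (single process; world_size, rank and the sizes are made up):

# Single-process view of the backward: keep rows [rank*dim0, (rank+1)*dim0).
import paddle

world_size, rank, dim0, d_model = 4, 2, 3, 5
grad_out = paddle.rand([world_size * dim0, d_model])
local_grad = paddle.slice(grad_out,
                          axes=[0],
                          starts=[rank * dim0],
                          ends=[(rank + 1) * dim0])
assert local_grad.shape == [dim0, d_model]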
@@ -208,11 +331,29 @@ class Slice(PyLayer):
     @staticmethod
     def backward(ctx, grad_out):
         world_size, group = ctx.args
-        # tensor_list = []
-        # paddle.distributed.all_gather(tensor_list, grad_out, group=group)
-        # grad_out = paddle.concat(tensor_list, axis=0)
         return _all_gather(grad_out, group=group)
-        # return grad_out
+
+
+class EagerSlice(EagerPyLayer):
+    r"""
+    A wrapper for the Slice function to support auto-differentiation.
+    """
+
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        B = inp.shape[0]
+        local_batch_size = B // world_size
+        batch_start = local_batch_size * rank
+        batch_end = min(batch_start + local_batch_size, B)
+        inp = paddle.slice(inp,
+                           axes=[0],
+                           starts=[batch_start],
+                           ends=[batch_end])
+        ctx.args = world_size, group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        world_size, group = ctx.args
+        return _all_gather(grad_out, group=group)


 def prepare_forward(gate, num_expert, world_size, moe_group):
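EagerSlice keeps the same batch-splitting arithmetic as Slice: each model-parallel rank takes a contiguous B // world_size slice of the batch, and the backward all-gathers the gradient back. The index arithmetic on its own (plain Python; B and world_size are invented):

# Index arithmetic used by Slice/EagerSlice.forward.
B, world_size = 10, 4
for rank in range(world_size):
    local_batch_size = B // world_size                 # 2
    batch_start = local_batch_size * rank
    batch_end = min(batch_start + local_batch_size, B)
    print(rank, (batch_start, batch_end))
# prints (0, 2), (2, 4), (4, 6), (6, 8); when B is not divisible by
# world_size, the trailing remainder rows are not assigned to any rank.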
@@ -369,7 +510,10 @@ class MoELayer(nn.Layer):
         mp_rank = self.mp_group.rank
         mp_size = self.mp_group.nranks
         if mp_size > 1:
-            inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
+            if in_dygraph_mode():
+                inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
+            else:
+                inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
         value, gate = self.gate(inp)
         (
@@ -390,9 +534,14 @@ class MoELayer(nn.Layer):
             temp_pos = pos
         assert topk == self.top_k

-        x = MoEScatter.apply(inp, temp_pos, local_expert_count,
-                             global_expert_count, fwd_batch_size,
-                             self.world_size, self.group)
+        if in_dygraph_mode():
+            x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
+                                      global_expert_count, fwd_batch_size,
+                                      self.world_size, self.group)
+        else:
+            x = MoEScatter.apply(inp, temp_pos, local_expert_count,
+                                 global_expert_count, fwd_batch_size,
+                                 self.world_size, self.group)

         d_model = self.d_model
@@ -421,15 +570,23 @@ class MoELayer(nn.Layer):
         if len(gate.shape) == 2:
             out_batch_size *= gate.shape[1]

-        x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
-                            out_batch_size, self.world_size, self.group)
+        if in_dygraph_mode():
+            x = EagerMoEGather.apply(x, pos, local_expert_count,
+                                     global_expert_count, out_batch_size,
+                                     self.world_size, self.group)
+        else:
+            x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
+                                out_batch_size, self.world_size, self.group)

         x = x.reshape([-1, self.top_k, d_model])
         value = value.reshape([x.shape[0], 1, self.top_k])
         x = paddle.bmm(value, x).reshape([-1, d_model])

         if mp_size > 1:
-            x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
+            if in_dygraph_mode():
+                x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
+            else:
+                x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)

         x = paddle.reshape_(x, origin_shape)
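Right after the gather, the forward pass mixes each token's k expert outputs with its gate values through a batched matmul; the reshape/bmm sequence from the context lines above is replayed here on random data so the shapes are visible (N, top_k and d_model are invented):

# Shape walk-through of the top-k combine in MoELayer.forward.
import paddle

N, top_k, d_model = 4, 2, 8
x = paddle.rand([N * top_k, d_model])           # gathered expert outputs
value = paddle.rand([N, top_k])                 # gate weights per token

x = x.reshape([-1, top_k, d_model])             # [N, top_k, d_model]
value = value.reshape([x.shape[0], 1, top_k])   # [N, 1, top_k]
out = paddle.bmm(value, x).reshape([-1, d_model])
assert out.shape == [N, d_model]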
python/paddle/incubate/distributed/models/moe/utils.py
@@ -21,15 +21,24 @@
 from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos
 import paddle
+from paddle.fluid.framework import in_dygraph_mode


 def _alltoall(in_tensor_list, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
-    ring_id = 0 if group is None else group.id
-    nranks = len(in_tensor_list)
-    return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream',
-                                  use_calc_stream, 'ring_id', ring_id)
+
+    if in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        out = paddle.empty(in_tensor_list.shape, in_tensor_list.dtype)
+        task = group.process_group.alltoall(in_tensor_list, out)
+        task.wait()
+        return out
+    else:
+        ring_id = 0 if group is None else group.id
+        return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream',
+                                      use_calc_stream, 'ring_id', ring_id)


 def count_by_gate(gate, num_expert, world_size, require_pos=True, group=None):
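_alltoall gets the same dual path as _all_gather: in eager mode it pre-allocates the output with paddle.empty and drives group.process_group.alltoall directly, otherwise it falls back to the C op. For reference, the exchange itself in terms of the public API (assumes a 2-rank job under paddle.distributed.launch; the numbers are only there to make the shuffle visible):

# All-to-all exchange behind _alltoall, written with the public API.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# two row-chunks per rank, tagged as 10 * rank + chunk_index
data = paddle.to_tensor([[rank * 10 + 0], [rank * 10 + 1]], dtype="float32")
out_tensor_list = []
dist.alltoall(list(paddle.split(data, 2, axis=0)), out_tensor_list)
# rank 0 receives [[0.], [10.]]; rank 1 receives [[1.], [11.]]
print(rank, [t.numpy().tolist() for t in out_tensor_list])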