Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4cc3d9a2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4cc3d9a2
编写于
8月 05, 2021
作者:
S
ShenLiang
提交者:
GitHub
8月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[HybridParallel]Fix bug of p2p for partial_send/recv (#34615)
* fix bug of p2p for partial * fix error
上级
090c863a
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
75 addition
and
48 deletion
+75
-48
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+0
-14
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+64
-29
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
...e/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
+11
-5
未找到文件。
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
4cc3d9a2
...
@@ -64,18 +64,6 @@ class PipelineParallel(MetaParallelBase):
...
@@ -64,18 +64,6 @@ class PipelineParallel(MetaParallelBase):
logger
.
info
(
"start broadcast dp parameters"
)
logger
.
info
(
"start broadcast dp parameters"
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
def
_set_tensor_trainable
(
self
,
tensor
):
if
tensor
is
None
:
return
if
isinstance
(
tensor
,
tuple
):
for
t
in
tensor
:
if
is_float_tensor
(
t
):
t
.
stop_gradient
=
False
else
:
if
is_float_tensor
(
tensor
):
tensor
.
stop_gradient
=
False
def
train_batch
(
self
,
data
,
optimizer
,
lr_scheduler
=
None
,
scaler
=
None
):
def
train_batch
(
self
,
data
,
optimizer
,
lr_scheduler
=
None
,
scaler
=
None
):
assert
isinstance
(
optimizer
,
HybridParallelOptimizer
),
(
assert
isinstance
(
optimizer
,
HybridParallelOptimizer
),
(
'optimizer should be HybridParallelOptimizer subclass.'
)
'optimizer should be HybridParallelOptimizer subclass.'
)
...
@@ -117,7 +105,6 @@ class PipelineParallel(MetaParallelBase):
...
@@ -117,7 +105,6 @@ class PipelineParallel(MetaParallelBase):
for
step_id
in
range
(
startup_steps
):
for
step_id
in
range
(
startup_steps
):
input_tensor
=
p2p
.
recv_forward
()
input_tensor
=
p2p
.
recv_forward
()
self
.
_set_tensor_trainable
(
input_tensor
)
output_tensor
=
self
.
_forward_step
(
input_tensor
)
output_tensor
=
self
.
_forward_step
(
input_tensor
)
p2p
.
send_forward
(
output_tensor
)
p2p
.
send_forward
(
output_tensor
)
...
@@ -131,7 +118,6 @@ class PipelineParallel(MetaParallelBase):
...
@@ -131,7 +118,6 @@ class PipelineParallel(MetaParallelBase):
for
i
in
range
(
steady_steps
):
for
i
in
range
(
steady_steps
):
last_iter
=
(
i
==
(
steady_steps
-
1
))
last_iter
=
(
i
==
(
steady_steps
-
1
))
self
.
_set_tensor_trainable
(
input_tensor
)
output_tensor
=
self
.
_forward_step
(
input_tensor
)
output_tensor
=
self
.
_forward_step
(
input_tensor
)
output_tensor_grad
=
p2p
.
send_forward_recv_backward
(
output_tensor
)
output_tensor_grad
=
p2p
.
send_forward_recv_backward
(
output_tensor
)
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
4cc3d9a2
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
import
paddle
import
paddle
from
.utils
import
paddle_2_number
,
number_2_dtype
from
.utils
import
paddle_2_number
,
number_2_dtype
from
...utils.log_util
import
logger
from
...utils.log_util
import
logger
import
numpy
as
np
from
paddle
import
_C_ops
_hcg
=
None
_hcg
=
None
...
@@ -40,6 +42,7 @@ class SendRecvMeta:
...
@@ -40,6 +42,7 @@ class SendRecvMeta:
self
.
recv_shape_message
=
None
self
.
recv_shape_message
=
None
self
.
recv_dtype_message
=
None
self
.
recv_dtype_message
=
None
self
.
recv_stop_gradient
=
None
self
.
has_send_meta
=
False
self
.
has_send_meta
=
False
self
.
has_recv_meta
=
False
self
.
has_recv_meta
=
False
...
@@ -57,7 +60,11 @@ class SendRecvMeta:
...
@@ -57,7 +60,11 @@ class SendRecvMeta:
# recv dtype
# recv dtype
dtype
=
paddle
.
to_tensor
([
0
])
dtype
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
dtype
,
src
=
0
,
group
=
group
)
paddle
.
distributed
.
recv
(
dtype
,
src
=
0
,
group
=
group
)
return
shape
.
numpy
().
tolist
(),
dtype
.
item
()
# recv stop_gradient
stop_grad
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
stop_grad
,
src
=
0
,
group
=
group
)
return
shape
.
numpy
().
tolist
(),
dtype
.
item
(),
stop_grad
.
item
()
def
recv_meta
(
self
,
group
):
def
recv_meta
(
self
,
group
):
tensor_type
=
paddle
.
to_tensor
([
0
])
tensor_type
=
paddle
.
to_tensor
([
0
])
...
@@ -65,9 +72,10 @@ class SendRecvMeta:
...
@@ -65,9 +72,10 @@ class SendRecvMeta:
tensor_type
=
tensor_type
.
item
()
tensor_type
=
tensor_type
.
item
()
if
tensor_type
==
0
:
if
tensor_type
==
0
:
shape
,
dtype
=
self
.
_recv_shape_dtype
(
group
)
shape
,
dtype
,
stop_grad
=
self
.
_recv_shape_dtype
(
group
)
self
.
recv_shape_message
=
shape
self
.
recv_shape_message
=
shape
self
.
recv_dtype_message
=
dtype
self
.
recv_dtype_message
=
dtype
self
.
recv_stop_gradient
=
bool
(
stop_grad
)
elif
tensor_type
==
1
:
elif
tensor_type
==
1
:
num
=
paddle
.
to_tensor
([
0
])
num
=
paddle
.
to_tensor
([
0
])
...
@@ -75,13 +83,16 @@ class SendRecvMeta:
...
@@ -75,13 +83,16 @@ class SendRecvMeta:
num
=
num
.
item
()
num
=
num
.
item
()
shapes
=
[]
shapes
=
[]
dtypes
=
[]
dtypes
=
[]
stop_grads
=
[]
for
i
in
range
(
num
):
for
i
in
range
(
num
):
shape
,
dtype
=
self
.
_recv_shape_dtype
(
group
)
shape
,
dtype
,
stop_grad
=
self
.
_recv_shape_dtype
(
group
)
shapes
.
append
(
shape
)
shapes
.
append
(
shape
)
dtypes
.
append
(
dtype
)
dtypes
.
append
(
dtype
)
stop_grads
.
append
(
bool
(
stop_grad
))
self
.
recv_shape_message
=
tuple
(
shapes
)
self
.
recv_shape_message
=
tuple
(
shapes
)
self
.
recv_dtype_message
=
tuple
(
dtypes
)
self
.
recv_dtype_message
=
tuple
(
dtypes
)
self
.
recv_stop_gradient
=
tuple
(
stop_grads
)
def
_send_dims_shape_dtype
(
self
,
tensor
,
group
):
def
_send_dims_shape_dtype
(
self
,
tensor
,
group
):
# send len(shape)
# send len(shape)
...
@@ -96,6 +107,10 @@ class SendRecvMeta:
...
@@ -96,6 +107,10 @@ class SendRecvMeta:
dtype
=
paddle
.
to_tensor
(
paddle_2_number
(
tensor
.
dtype
))
dtype
=
paddle
.
to_tensor
(
paddle_2_number
(
tensor
.
dtype
))
paddle
.
distributed
.
send
(
dtype
,
dst
=
1
,
group
=
group
)
paddle
.
distributed
.
send
(
dtype
,
dst
=
1
,
group
=
group
)
# send trainable
stop_grad
=
paddle
.
to_tensor
(
int
(
tensor
.
stop_gradient
))
paddle
.
distributed
.
send
(
stop_grad
,
dst
=
1
,
group
=
group
)
def
send_meta
(
self
,
tensor
,
group
):
def
send_meta
(
self
,
tensor
,
group
):
if
isinstance
(
tensor
,
paddle
.
Tensor
):
if
isinstance
(
tensor
,
paddle
.
Tensor
):
tensor_type
=
paddle
.
to_tensor
([
0
])
tensor_type
=
paddle
.
to_tensor
([
0
])
...
@@ -129,6 +144,12 @@ class SendRecvMeta:
...
@@ -129,6 +144,12 @@ class SendRecvMeta:
_send_recv_meta
=
SendRecvMeta
()
_send_recv_meta
=
SendRecvMeta
()
def
_is_valid_send_recv_partial
(
tensor
,
mp_degree
):
tensor_numel
=
np
.
prod
(
tensor
.
shape
)
assert
tensor_numel
!=
0
,
"can't send/recv zero element"
return
mp_degree
>
1
and
tensor_numel
%
mp_degree
==
0
def
send_partial
(
tensor
,
def
send_partial
(
tensor
,
dst
=
0
,
dst
=
0
,
nranks
=
1
,
nranks
=
1
,
...
@@ -138,9 +159,14 @@ def send_partial(tensor,
...
@@ -138,9 +159,14 @@ def send_partial(tensor,
if
group
is
not
None
and
not
group
.
is_member
():
if
group
is
not
None
and
not
group
.
is_member
():
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
ring_id
=
0
if
group
is
None
else
group
.
id
return
paddle
.
fluid
.
core
.
ops
.
partial_send
(
tensor
,
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
dst
,
'num'
,
nranks
,
'id'
,
rank_id
)
return
_C_ops
.
partial_send
(
tensor
,
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
dst
,
'num'
,
nranks
,
'id'
,
rank_id
)
else
:
return
paddle
.
distributed
.
send
(
tensor
,
dst
=
dst
,
group
=
group
,
use_calc_stream
=
use_calc_stream
)
def
recv_partial
(
tensor
,
def
recv_partial
(
tensor
,
...
@@ -153,10 +179,14 @@ def recv_partial(tensor,
...
@@ -153,10 +179,14 @@ def recv_partial(tensor,
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
ring_id
=
0
if
group
is
None
else
group
.
id
paddle
.
fluid
.
core
.
ops
.
partial_recv
(
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
tensor
,
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
_C_ops
.
partial_recv
(
tensor
,
'use_calc_stream'
,
use_calc_stream
,
src
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
'ring_id'
,
ring_id
,
'peer'
,
src
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
tensor
.
shape
)
tensor
.
shape
)
else
:
paddle
.
distributed
.
recv
(
tensor
,
src
=
src
,
group
=
group
,
use_calc_stream
=
use_calc_stream
)
def
allgather_partial
(
tensor
,
def
allgather_partial
(
tensor
,
...
@@ -164,15 +194,15 @@ def allgather_partial(tensor,
...
@@ -164,15 +194,15 @@ def allgather_partial(tensor,
rank_id
=
0
,
rank_id
=
0
,
group
=
None
,
group
=
None
,
use_calc_stream
=
True
):
use_calc_stream
=
True
):
if
n
ranks
==
1
:
if
n
ot
_is_valid_send_recv_partial
(
tensor
,
nranks
)
:
return
tensor
return
tensor
if
group
is
not
None
and
not
group
.
is_member
():
if
group
is
not
None
and
not
group
.
is_member
():
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
ring_id
=
0
if
group
is
None
else
group
.
id
return
paddle
.
fluid
.
core
.
ops
.
partial_allgather_
(
return
_C_ops
.
partial_allgather_
(
tensor
,
'use_calc_stream'
,
use_calc_stream
,
tensor
,
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'ring_id'
,
ring_id
,
'nranks'
,
nranks
,
'nranks'
,
nranks
,
'rank'
,
rank_id
)
'rank'
,
rank_id
)
def
_p2p_helper
(
tensor_send_next
,
tensor_send_prev
,
recv_prev
,
recv_next
):
def
_p2p_helper
(
tensor_send_next
,
tensor_send_prev
,
recv_prev
,
recv_next
):
...
@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
# send / recv message
# send / recv message
recv_shape_msg
=
_send_recv_meta
.
recv_shape_message
recv_shape_msg
=
_send_recv_meta
.
recv_shape_message
recv_dtype_msg
=
_send_recv_meta
.
recv_dtype_message
recv_dtype_msg
=
_send_recv_meta
.
recv_dtype_message
recv_stop_gradient
=
_send_recv_meta
.
recv_stop_gradient
send_shape_msg
=
_send_recv_meta
.
send_shape_message
send_shape_msg
=
_send_recv_meta
.
send_shape_message
send_dtype_msg
=
_send_recv_meta
.
send_dtype_message
send_dtype_msg
=
_send_recv_meta
.
send_dtype_message
...
@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if
isinstance
(
recv_shape_msg
,
tuple
):
if
isinstance
(
recv_shape_msg
,
tuple
):
tensor_recv_prev
=
[]
tensor_recv_prev
=
[]
for
idx
,
shape
in
enumerate
(
recv_shape_msg
):
for
idx
,
shape
in
enumerate
(
recv_shape_msg
):
tensor_recv_prev
.
append
(
tmp
=
paddle
.
empty
(
paddle
.
empty
(
shape
=
shape
,
dtype
=
number_2_dtype
(
recv_dtype_msg
[
idx
]))
shape
=
shape
,
dtype
=
number_2_dtype
(
recv_dtype_msg
[
idx
])))
tmp
.
stop_gradient
=
recv_stop_gradient
[
idx
]
tensor_recv_prev
.
append
(
tmp
)
tensor_recv_prev
=
tuple
(
tensor_recv_prev
)
tensor_recv_prev
=
tuple
(
tensor_recv_prev
)
else
:
else
:
tensor_recv_prev
=
paddle
.
empty
(
tensor_recv_prev
=
paddle
.
empty
(
shape
=
recv_shape_msg
,
dtype
=
number_2_dtype
(
recv_dtype_msg
))
shape
=
recv_shape_msg
,
dtype
=
number_2_dtype
(
recv_dtype_msg
))
tensor_recv_prev
.
stop_gradient
=
recv_stop_gradient
if
recv_next
:
if
recv_next
:
if
isinstance
(
send_shape_msg
,
tuple
):
if
isinstance
(
send_shape_msg
,
tuple
):
...
@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for
d
in
tensor_send_prev
:
for
d
in
tensor_send_prev
:
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
send_partial
(
send_partial
(
d
,
d
.
detach
()
,
dst
=
0
,
dst
=
0
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
...
@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else
:
else
:
paddle
.
distributed
.
wait
(
tensor_send_prev
,
use_calc_stream
=
True
)
paddle
.
distributed
.
wait
(
tensor_send_prev
,
use_calc_stream
=
True
)
send_partial
(
send_partial
(
tensor_send_prev
,
tensor_send_prev
.
detach
()
,
dst
=
0
,
dst
=
0
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
...
@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if
isinstance
(
tensor_recv_prev
,
tuple
):
if
isinstance
(
tensor_recv_prev
,
tuple
):
for
d
in
tensor_recv_prev
:
for
d
in
tensor_recv_prev
:
recv_partial
(
recv_partial
(
d
,
d
.
detach
()
,
src
=
0
,
src
=
0
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_prev_group
,
group
=
_hcg
.
recv_prev_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
)
allgather_partial
(
allgather_partial
(
d
,
d
.
detach
()
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
mp_group
,
group
=
mp_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
)
else
:
else
:
recv_partial
(
recv_partial
(
tensor_recv_prev
,
tensor_recv_prev
.
detach
()
,
src
=
0
,
src
=
0
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_prev_group
,
group
=
_hcg
.
recv_prev_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
)
allgather_partial
(
allgather_partial
(
tensor_recv_prev
,
tensor_recv_prev
.
detach
()
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
mp_group
,
group
=
mp_group
,
...
@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for
d
in
tensor_send_next
:
for
d
in
tensor_send_next
:
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
send_partial
(
send_partial
(
d
,
d
.
detach
()
,
dst
=
1
,
dst
=
1
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
...
@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else
:
else
:
paddle
.
distributed
.
wait
(
tensor_send_next
,
use_calc_stream
=
True
)
paddle
.
distributed
.
wait
(
tensor_send_next
,
use_calc_stream
=
True
)
send_partial
(
send_partial
(
tensor_send_next
,
tensor_send_next
.
detach
()
,
dst
=
1
,
dst
=
1
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
...
@@ -294,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -294,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if
isinstance
(
tensor_recv_next
,
tuple
):
if
isinstance
(
tensor_recv_next
,
tuple
):
for
d
in
tensor_recv_next
:
for
d
in
tensor_recv_next
:
recv_partial
(
recv_partial
(
d
,
d
.
detach
()
,
src
=
1
,
src
=
1
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_next_group
,
group
=
_hcg
.
recv_next_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
)
allgather_partial
(
allgather_partial
(
d
,
d
.
detach
()
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
mp_group
,
group
=
mp_group
,
...
@@ -309,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -309,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else
:
else
:
recv_partial
(
recv_partial
(
tensor_recv_next
,
tensor_recv_next
.
detach
()
,
src
=
1
,
src
=
1
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
...
@@ -317,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
@@ -317,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
use_calc_stream
=
True
)
use_calc_stream
=
True
)
allgather_partial
(
allgather_partial
(
tensor_recv_next
,
tensor_recv_next
.
detach
()
,
nranks
=
mp_degree
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
rank_id
=
mp_rank
,
group
=
mp_group
,
group
=
mp_group
,
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
浏览文件 @
4cc3d9a2
...
@@ -54,13 +54,17 @@ class EmbeddingNet(Layer):
...
@@ -54,13 +54,17 @@ class EmbeddingNet(Layer):
attention_mask
=
paddle
.
tensor
.
triu
(
attention_mask
=
paddle
.
tensor
.
triu
(
(
paddle
.
ones
(
(
paddle
.
ones
(
(
length
,
length
),
dtype
=
"float32"
)
*
-
1e9
),
1
)
(
length
,
length
),
dtype
=
"float32"
)
*
-
1e9
),
1
)
attention_mask
.
stop_gradient
=
True
no_used
=
paddle
.
ones
((
3
,
3
),
dtype
=
"int32"
)
w_emb
=
self
.
word_embeddings
(
x
)
w_emb
=
self
.
word_embeddings
(
x
)
p_emb
=
self
.
position_embeddings
(
x
)
p_emb
=
self
.
position_embeddings
(
x
)
w_emb
=
w_emb
+
p_emb
w_emb
=
w_emb
+
p_emb
attention_mask
.
stop_gradient
=
True
no_used
.
stop_gradient
=
True
# need to fix bug of backward()
# need to fix bug of backward()
return
w_emb
,
attention_mask
return
w_emb
,
attention_mask
,
no_used
,
p_emb
class
TransformerNet
(
Layer
):
class
TransformerNet
(
Layer
):
...
@@ -99,12 +103,12 @@ class EmbeddingPipe(EmbeddingNet):
...
@@ -99,12 +103,12 @@ class EmbeddingPipe(EmbeddingNet):
class
TransformerNetPipe
(
TransformerNet
):
class
TransformerNetPipe
(
TransformerNet
):
def
forward
(
self
,
args
):
def
forward
(
self
,
args
):
x
,
mask
=
args
[
0
],
args
[
1
]
x
,
mask
,
no_used
,
p_emb
=
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]
output
=
super
().
forward
(
x
,
mask
)
output
=
super
().
forward
(
x
,
mask
)
output
=
output
output
=
output
+
p_emb
mask
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
output
,
mask
return
output
,
mask
,
no_used
,
p_emb
class
CriterionPipe
(
Layer
):
class
CriterionPipe
(
Layer
):
...
@@ -175,6 +179,8 @@ class TestDistPPTraning(unittest.TestCase):
...
@@ -175,6 +179,8 @@ class TestDistPPTraning(unittest.TestCase):
loss
=
model
.
train_batch
([
x
,
x
],
optimizer
,
scheduler
)
loss
=
model
.
train_batch
([
x
,
x
],
optimizer
,
scheduler
)
# TODO(shenliang03) add utest for loss
# TODO(shenliang03) add utest for loss
print
(
"loss: "
,
loss
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录