PaddlePaddle / Paddle · commit d4bf8b1a (unverified)
Authored on Nov 03, 2022 by ShenLiang; committed by GitHub on Nov 03, 2022.
support unbalanced data for pipeline (#47199) (#47569)
* add unbalanced data
* fix utest
Parent: ba4fbe71
Showing 3 changed files with 294 additions and 139 deletions (+294, -139).
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py  (+210, -132)
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py  (+67, -0)
python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py  (+17, -7)
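The behavioral change (exercised by the new hybrid_parallel_pp_transformer_unbalanced_data.py test listed above) is that, for the first and last pipeline stages, each element of the [inputs, labels] pair passed to train_batch/eval_batch may now be a Python list with one entry per micro batch, instead of a single pre-batched tensor that gets sliced evenly. A minimal sketch of the new calling pattern, with illustrative sizes (the real values come from hybrid_parallel_pp_transformer); the train_batch call is left commented out because it needs a launched pipeline-parallel job:

import numpy as np
import paddle

micro_batch_size, accumulate_steps, length, vocab_size = 2, 4, 8, 128  # illustrative

# one tensor per micro batch instead of one [batch_size, length] tensor
x = [
    paddle.to_tensor(np.random.randint(0, vocab_size, size=[micro_batch_size, length]))
    for _ in range(accumulate_steps)
]

# mirrors model.train_batch([x, x], optimizer, scheduler) in the new unit test
# loss = model.train_batch([x, x], optimizer, scheduler)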
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -20,7 +20,10 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
 from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
 from ..utils.log_util import logger
-from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
+from ..meta_optimizers.dygraph_optimizer import (
+    HybridParallelOptimizer,
+    HybridParallelGradScaler,
+)
 import paddle.fluid.framework as framework
 from .pp_utils import p2p_communication as p2p
 import paddle.fluid.core as core
@@ -29,27 +32,31 @@ __all__ = []
 class PipelineParallel(MetaParallelBase):

     def __init__(self, layers, hcg, strategy):
         if not isinstance(layers, PipelineLayer):
             raise TypeError(
-                "The Layer should be a derived class of PipelineLayer.")
+                "The Layer should be a derived class of PipelineLayer."
+            )
         super(PipelineParallel, self).__init__(layers, hcg, strategy)
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
-        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
-        ) > 1
+        self.use_sharding_parallel = (
+            self._hcg.get_sharding_parallel_world_size() > 1
+        )

         self.total_loss = None

         self.micro_batch_size = self._strategy.pipeline_configs[
-            'micro_batch_size']
+            'micro_batch_size'
+        ]
         self.accumulate_steps = self._strategy.pipeline_configs[
-            'accumulate_steps']
+            'accumulate_steps'
+        ]
         # If sent tensor are not the same from different hosts,
         # they shouldn't been sent partially and then concated as a whole tensor.
         self._enable_partial_send_recv = self._strategy.pipeline_configs[
-            'enable_partial_send_recv']
+            'enable_partial_send_recv'
+        ]
         self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

         self.num_stages = self._hcg.get_pipe_parallel_world_size()
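For context, the pipeline_configs entries read in __init__ above come from the fleet distributed strategy the trainer was initialized with. A minimal sketch of such a setup (degrees and step counts are illustrative, not taken from this commit):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
# for plain tensor inputs, batch_size must equal micro_batch_size * accumulate_steps
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
# fleet.init(is_collective=True, strategy=strategy)  # requires a multi-process launch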
@@ -61,16 +68,20 @@ class PipelineParallel(MetaParallelBase):
         self._real_pp_world_size = self.num_stages
         self._real_pp_rank = self.stage_id

-        p2p.initialize_p2p_groups(hcg, self._using_cache,
-                                  self._enable_partial_send_recv)
+        p2p.initialize_p2p_groups(
+            hcg, self._using_cache, self._enable_partial_send_recv
+        )

         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0

         self._compute_loss = True

-        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
-            self.num_stages, self.stage_id))
+        logger.info(
+            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
+                self.num_stages, self.stage_id
+            )
+        )

         if self.use_model_parallel:
             logger.info("start broadcast mp parameters")
@@ -122,7 +133,7 @@ class PipelineParallel(MetaParallelBase):
         # store data id for micro_batch
         self.micro_batch_id = 0

-        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = self.num_stages - self.stage_id - 1
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps
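As a worked example of the (unchanged) warm-up arithmetic above, assuming 4 pipeline stages and 8 accumulate steps:

num_stages, accumulate_steps = 4, 8  # illustrative
for stage_id in range(num_stages):
    startup_steps = min(num_stages - stage_id - 1, accumulate_steps)
    steady_steps = accumulate_steps - startup_steps
    print(stage_id, startup_steps, steady_steps)
# stage 0 -> (3, 5), stage 1 -> (2, 6), stage 2 -> (1, 7), stage 3 -> (0, 8);
# every stage still runs accumulate_steps forward/backward pairs in total.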
@@ -142,39 +153,46 @@ class PipelineParallel(MetaParallelBase):
         input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

         for i in range(steady_steps):
-            last_iter = (i == (steady_steps - 1))
+            last_iter = i == (steady_steps - 1)

             output_tensor = self._forward_step(input_tensor)

             output_tensor_grad = p2p.send_forward_recv_backward(
-                output_tensor, self.is_pipeline_last_stage())
+                output_tensor, self.is_pipeline_last_stage()
+            )

             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

             input_tensor, output_tensor = input_buffers.pop(
-                0), output_buffers.pop(0)
+                0
+            ), output_buffers.pop(0)

-            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
-                                                    output_tensor_grad)
+            input_tensor_grad = self._backward_step(
+                input_tensor, output_tensor, output_tensor_grad
+            )

             if last_iter:
                 input_tensor = None
-                p2p.send_backward(input_tensor_grad,
-                                  self.is_pipeline_first_stage())
+                p2p.send_backward(
+                    input_tensor_grad, self.is_pipeline_first_stage()
+                )
             else:
                 input_tensor = p2p.send_backward_recv_forward(
-                    input_tensor_grad, self.is_pipeline_first_stage())
+                    input_tensor_grad, self.is_pipeline_first_stage()
+                )

         for i in range(startup_steps):
             input_tensor = input_buffers.pop(0)
             output_tensor = output_buffers.pop(0)

-            output_tensor_grad = p2p.recv_backward(self.is_pipeline_last_stage())
+            output_tensor_grad = p2p.recv_backward(
+                self.is_pipeline_last_stage()
+            )

-            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
-                                                    output_tensor_grad)
+            input_tensor_grad = self._backward_step(
+                input_tensor, output_tensor, output_tensor_grad
+            )
             p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

         self._layers.allreduce_shared_weight_gradients()
@@ -186,17 +204,20 @@ class PipelineParallel(MetaParallelBase):
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)

-        assert isinstance(optimizer, HybridParallelOptimizer), (
-            'optimizer should be HybridParallelOptimizer subclass.')
+        assert isinstance(
+            optimizer, HybridParallelOptimizer
+        ), 'optimizer should be HybridParallelOptimizer subclass.'

-        assert fluid.framework._dygraph_tracer()._has_grad, (
-            'Please enable the generation of gradients.')
+        assert (
+            fluid.framework._dygraph_tracer()._has_grad
+        ), 'Please enable the generation of gradients.'

-        if self.is_pipeline_first_stage(
-                ignore_virtual=True) or self.is_pipeline_last_stage(
-                    ignore_virtual=True):
-            assert data is not None, (
-                "For the first and the last stage, the data must be set.")
+        if self.is_pipeline_first_stage(
+            ignore_virtual=True
+        ) or self.is_pipeline_last_stage(ignore_virtual=True):
+            assert (
+                data is not None
+            ), "For the first and the last stage, the data must be set."
         else:
             data = None
@@ -233,7 +254,7 @@ class PipelineParallel(MetaParallelBase):
         # store total loss of entire batch
         self.total_loss = None

-        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = self.num_stages - self.stage_id - 1
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps
@@ -253,7 +274,7 @@ class PipelineParallel(MetaParallelBase):
         input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

         for i in range(steady_steps):
-            last_iter = (i == (steady_steps - 1))
+            last_iter = i == (steady_steps - 1)

             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -282,13 +303,14 @@ class PipelineParallel(MetaParallelBase):
         if self.is_pipeline_last_stage():
             # train calculate loss for train
             if self._compute_loss:
-                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
+                assert (
+                    self._layers._loss_fn is not None
+                ), "loss function should exist to compute loss"
                 labels = self._load_micro_batch(self.micro_batch_id)
                 output_tensor = self._layers._loss_fn(output_tensor, labels)
                 assert isinstance(
-                    output_tensor, (paddle.Tensor, core.eager.Tensor)), \
-                    "Currently, loss_fn should obtain Paddle.Tensor dtype"
+                    output_tensor, (paddle.Tensor, core.eager.Tensor)
+                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                 with paddle.amp.auto_cast(enable=False):
                     if self.accumulate_steps > 1:
@@ -318,91 +340,113 @@ class PipelineParallel(MetaParallelBase):
                 assert len(outputs) == len(output_tensor_grad)
                 paddle.autograd.backward(
                     tensors=outputs,
-                    grad_tensors=[t for t in output_tensor_grad])
+                    grad_tensors=[t for t in output_tensor_grad],
+                )
             else:
-                paddle.autograd.backward(tensors=[output_tensor],
-                                         grad_tensors=[output_tensor_grad])
+                paddle.autograd.backward(
+                    tensors=[output_tensor],
+                    grad_tensors=[output_tensor_grad],
+                )

             input_tensor_grad = None
             if input_tensor is not None:
                 if isinstance(input_tensor, tuple):
                     input_tensor_grad = tuple(
-                        [t.grad for t in input_tensor if not t.stop_gradient])
+                        [t.grad for t in input_tensor if not t.stop_gradient]
+                    )
                 else:
                     input_tensor_grad = input_tensor.grad
             return input_tensor_grad

-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
-        begin = cache_id * self.micro_batch_size
-        end = begin + self.micro_batch_size
-
-        # The virtual first and last pipeline stage need data, all others don't need.
-        if self.is_pipeline_first_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[0], tuple):
-                assert len(
-                    inputs[0]
-                ) > 1, "If you use tuple for input data, it should have at least two inputs."
-                batch_size = inputs[0][0].shape[0]
-                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
-                    "batch_size needs to be divisible by micro_batch_size. Currently, "
-                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-                    % (batch_size, self.micro_batch_size, self.accumulate_steps))
-                data = [input[begin:end, :].detach() for input in inputs[0]]
-                return tuple(data)
-            else:
-                batch_size = inputs[0].shape[0]
-                assert self.micro_batch_size * self.accumulate_steps == batch_size
-                return inputs[0][begin:end, :].detach()
-        elif self.is_pipeline_last_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[1], tuple):
-                batch_size = inputs[1][0].shape[0]
-                assert self.micro_batch_size * self.accumulate_steps == batch_size
-                data = [input[begin:end, :].detach() for input in inputs[1]]
-                return tuple(data)
-            else:
-                batch_size = inputs[1].shape[0]
-                assert self.micro_batch_size * self.accumulate_steps == batch_size
-                return inputs[1][begin:end, :].detach()
-        else:
-            # No data input is required for other stages
-            inputs = None
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self.micro_batch_size, self.accumulate_steps)
+        )
+
+    def _load_micro_batch_impl(self, inputs, cache_id):
+        begin = cache_id * self.micro_batch_size
+        end = begin + self.micro_batch_size
+
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self.accumulate_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self.accumulate_steps,
+                        len(data),
+                    )
+                    output.append(data[cache_id].detach())
+                else:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+            return tuple(output)
+
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self.accumulate_steps
+            ), "length of data should be %d, but it is %d" % (
+                self.accumulate_steps,
+                len(inputs),
+            )
+            return inputs[cache_id].detach()
+        else:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+
+    def _load_micro_batch(self, cache_id):
+        inputs = self.data
+        # The virtual first and last pipeline stage need data, all others don't need.
+        if self.is_pipeline_first_stage():
+            assert len(inputs) == 2, "length of input should be 2"
+            return self._load_micro_batch_impl(inputs[0], cache_id)
+        elif self.is_pipeline_last_stage():
+            assert len(inputs) == 2, "length of input should be 2"
+            return self._load_micro_batch_impl(inputs[1], cache_id)
+        else:
+            # No data input is required for other stages
+            inputs = None

     def _broadcast_final_loss(self):
         # Since the last backward run in interleave will set the virtual rank to 0,
         # here we need to check last stage ignoring virtual stage.
         if self.is_pipeline_last_stage(ignore_virtual=True):
-            assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
+            assert (
+                self.total_loss is not None
+            ), "train_batch() in last stage should obtain vaild loss"
             loss = self.total_loss.detach()
-            is_fp32 = paddle.to_tensor(1) if loss.dtype == paddle.float32 \
-                else paddle.to_tensor(0)
-            paddle.distributed.broadcast(is_fp32,
-                                         src=self.global_rank,
-                                         sync_op=True,
-                                         group=self.pp_group)
-            paddle.distributed.broadcast(loss,
-                                         src=self.global_rank,
-                                         sync_op=True,
-                                         group=self.pp_group)
+            is_fp32 = (
+                paddle.to_tensor(1)
+                if loss.dtype == paddle.float32
+                else paddle.to_tensor(0)
+            )
+            paddle.distributed.broadcast(
+                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
+            )
+            paddle.distributed.broadcast(
+                loss, src=self.global_rank, sync_op=True, group=self.pp_group
+            )
         else:
             is_fp32 = paddle.to_tensor(1)
             paddle.distributed.broadcast(
                 is_fp32,
                 src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                 sync_op=True,
-                group=self.pp_group)
-            loss = paddle.zeros(shape=[1],
-                                dtype="float32") if is_fp32.numpy()[0] else \
-                paddle.zeros(shape=[1], dtype="float16")
+                group=self.pp_group,
+            )
+            loss = (
+                paddle.zeros(shape=[1], dtype="float32")
+                if is_fp32.numpy()[0]
+                else paddle.zeros(shape=[1], dtype="float16")
+            )
             paddle.distributed.broadcast(
                 loss,
                 src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                 sync_op=True,
-                group=self.pp_group)
+                group=self.pp_group,
+            )
         return loss

     def _optimizer_step(self):
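To make the new data handling concrete, here is a small sketch of the three input layouts that _load_micro_batch_impl distinguishes (shapes and sizes are illustrative):

import paddle

micro_batch_size, accumulate_steps = 2, 4  # illustrative

# 1) a single tensor: validated by _check_data_vaild and sliced as
#    data[cache_id * micro_batch_size : (cache_id + 1) * micro_batch_size]
whole_batch = paddle.randn([micro_batch_size * accumulate_steps, 16])

# 2) a list with accumulate_steps entries: entry cache_id is used directly,
#    so each micro batch is an independent tensor rather than a slice
per_micro_batch = [paddle.randn([micro_batch_size, 16]) for _ in range(accumulate_steps)]

# 3) a tuple of fields: each field is handled as case 1) or 2) independently
features = (whole_batch, per_micro_batch)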
@@ -421,11 +465,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
     # pipeline parallel with interleave scheduler

     def __init__(self, layers, hcg, strategy):
-        super(PipelineParallelWithInterleave, self).__init__(layers=layers,
-                                                             hcg=hcg,
-                                                             strategy=strategy)
+        super(PipelineParallelWithInterleave, self).__init__(
+            layers=layers, hcg=hcg, strategy=strategy
+        )
         assert layers.get_num_virtual_stages() > 1
-        assert framework.in_dygraph_mode(
-        ), "virtual pipeline stage with interleave only support eager dygraph mode"
+        assert (
+            framework.in_dygraph_mode()
+        ), "virtual pipeline stage with interleave only support eager dygraph mode"
         # setup for interleave scheduler
         self.num_model_chunks = layers.get_num_virtual_stages()
@@ -436,11 +481,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self._virtual_pp_rank = 0

     def _get_virtual_pp_rank(self, micro_step, forward):
-        virtual_pp_stage = micro_step % (self.num_stages *
-                                         self.num_model_chunks)
+        virtual_pp_stage = micro_step % (
+            self.num_stages * self.num_model_chunks
+        )
         virtual_pp_stage = virtual_pp_stage // self.num_stages
         if not forward:
-            virtual_pp_stage = (self.num_model_chunks - virtual_pp_stage - 1)
+            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
         return virtual_pp_stage

     def _forward_step_helper(self, micro_step):
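A worked example of the chunk selection above, assuming 2 pipeline stages and 2 virtual chunks per stage:

num_stages, num_model_chunks = 2, 2  # illustrative

def get_virtual_pp_rank(micro_step, forward):
    virtual_pp_stage = micro_step % (num_stages * num_model_chunks)
    virtual_pp_stage = virtual_pp_stage // num_stages
    if not forward:
        virtual_pp_stage = num_model_chunks - virtual_pp_stage - 1
    return virtual_pp_stage

print([get_virtual_pp_rank(i, True) for i in range(6)])   # [0, 0, 1, 1, 0, 0]
print([get_virtual_pp_rank(i, False) for i in range(6)])  # [1, 1, 0, 0, 1, 1]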
@@ -455,7 +501,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if self.is_pipeline_first_stage():
             if len(self.input_tensors[virtual_pp_rank]) == len(
-                    self.output_tensors[virtual_pp_rank]):
+                self.output_tensors[virtual_pp_rank]
+            ):
                 self.input_tensors[virtual_pp_rank].append(None)
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
         output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
@@ -484,21 +531,22 @@ class PipelineParallelWithInterleave(PipelineParallel):
         input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
         output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
         output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
-        input_tensor_grad = self._backward_step(input_tensor, output_tensor,
-                                                output_tensor_grad)
+        input_tensor_grad = self._backward_step(
+            input_tensor, output_tensor, output_tensor_grad
+        )

         return input_tensor_grad

-    def interleave_pipeline(self,
-                            data,
-                            scaler,
-                            forward_only=False,
-                            compute_loss=True):
+    def interleave_pipeline(
+        self, data, scaler, forward_only=False, compute_loss=True
+    ):
         # use interleave scheduling strategy.
         # this strategy is inspired by:
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
-            assert not forward_only, "compute_loss can only be set to False when forward_only is set to True"
+            assert (
+                not forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"

         # init some attributes for this batch run
         self.scaler = scaler
@@ -530,15 +578,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.set_virtual_pipeline_rank(0)
         self.input_tensors[0].append(
-            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False))
+            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
+        )

         # run startup steps
         for micro_step in range(startup_steps):
             output_tensor = self._forward_step_helper(micro_step)

             # determine whether recv forward tensor or not
-            next_virtual_pp_rank = self._get_virtual_pp_rank(micro_step + 1,
-                                                             forward=True)
+            next_virtual_pp_rank = self._get_virtual_pp_rank(
+                micro_step + 1, forward=True
+            )
             recv_prev = True
             if self.is_pipeline_first_stage(ignore_virtual=True):
                 if next_virtual_pp_rank == 0:
@@ -552,24 +602,33 @@ class PipelineParallelWithInterleave(PipelineParallel):
             if self.is_pipeline_last_stage():
                 output_tensor = None

-            if micro_step == (startup_steps -
-                              1) and not forward_only and not all_startup_steps:
+            if (
+                micro_step == (startup_steps - 1)
+                and not forward_only
+                and not all_startup_steps
+            ):
                 input_tensor_grad = None
                 recv_next = True
                 if self.is_pipeline_last_stage(ignore_virtual=True):
                     recv_next = False

                 # the last startup step needs on four direction comm to set up for steady 1f1b
-                input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
-                    output_tensor,
-                    input_tensor_grad,
-                    recv_prev=recv_prev,
-                    recv_next=recv_next)
-                self.output_tensor_grads[self.num_model_chunks -
-                                         1].append(output_tensor_grad)
+                (
+                    input_tensor,
+                    output_tensor_grad,
+                ) = p2p.send_forward_backward_recv_forward_backward(
+                    output_tensor,
+                    input_tensor_grad,
+                    recv_prev=recv_prev,
+                    recv_next=recv_next,
+                )
+                self.output_tensor_grads[self.num_model_chunks - 1].append(
+                    output_tensor_grad
+                )
             else:
                 input_tensor = p2p.send_forward_recv_forward(
-                    output_tensor, recv_prev=recv_prev)
+                    output_tensor, recv_prev=recv_prev
+                )
             self.input_tensors[next_virtual_pp_rank].append(input_tensor)

         # run 1f1b steady steps
@@ -581,7 +640,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
             # backward
             backward_micro_step_id = micro_step
-            input_tensor_grad = self._backward_step_helper(
-                backward_micro_step_id)
+            input_tensor_grad = self._backward_step_helper(
+                backward_micro_step_id
+            )

             # four directions comm
             # send output tensor to downstream
@@ -591,14 +651,16 @@ class PipelineParallelWithInterleave(PipelineParallel):
             # last stage doesn't send rst to downstream
             forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                forward_micro_step_id, forward=True)
+                forward_micro_step_id, forward=True
+            )
             self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
             if self.is_pipeline_last_stage():
                 output_tensor = None

             # first stage doesn't send grad to upstream
             backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                backward_micro_step_id, forward=False)
+                backward_micro_step_id, forward=False
+            )
             self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
             if self.is_pipeline_first_stage():
                 input_tensor_grad = None
@@ -607,14 +669,16 @@ class PipelineParallelWithInterleave(PipelineParallel):
             recv_prev = True
             if self.is_pipeline_first_stage(ignore_virtual=True):
                 next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    forward_micro_step_id - (self.num_stages - 1),
-                    forward=True)
+                    forward_micro_step_id - (self.num_stages - 1), forward=True
+                )
                 if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                     # first pp stage and first virtual stage
                     recv_prev = False
                 next_forward_virtual_pp_rank += 1
             else:
                 next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    forward_micro_step_id + 1, forward=True)
+                    forward_micro_step_id + 1, forward=True
+                )

             # last iteration doesn't need recv from upstream
             if micro_step == (steady_steps - 1):
@@ -625,53 +689,67 @@ class PipelineParallelWithInterleave(PipelineParallel):
             if self.is_pipeline_last_stage(ignore_virtual=True):
                 next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                     backward_micro_step_id - (self.num_stages - 1),
-                    forward=False)
+                    forward=False,
+                )
                 if next_backward_virtual_pp_rank == 0:
                     # last pp stage and last virtual stage
                     recv_next = False
                 next_backward_virtual_pp_rank -= 1
             else:
                 next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    backward_micro_step_id + 1, forward=False)
+                    backward_micro_step_id + 1, forward=False
+                )

-            input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
-                output_tensor,
-                input_tensor_grad,
-                recv_prev=recv_prev,
-                recv_next=recv_next)
+            (
+                input_tensor,
+                output_tensor_grad,
+            ) = p2p.send_forward_backward_recv_forward_backward(
+                output_tensor,
+                input_tensor_grad,
+                recv_prev=recv_prev,
+                recv_next=recv_next,
+            )

             if recv_prev:
                 self.input_tensors[next_forward_virtual_pp_rank].append(
-                    input_tensor)
+                    input_tensor
+                )
             if recv_next:
                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
-                    output_tensor_grad)
+                    output_tensor_grad
+                )

         # remaining backward steps
         if not forward_only:
             if all_startup_steps:
                 self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(self.is_pipeline_last_stage(),
-                                      sync_recv=False))
+                    p2p.recv_backward(
+                        self.is_pipeline_last_stage(), sync_recv=False
+                    )
+                )

             for micro_step in range(steady_steps, num_steps):
                 # cooldown loop
                 input_tensor_grad = self._backward_step_helper(micro_step)
                 next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    micro_step + 1, forward=False)
+                    micro_step + 1, forward=False
+                )

                 recv_next = True
                 if self.is_pipeline_last_stage(ignore_virtual=True):
-                    if next_backward_virtual_pp_rank == (self.num_model_chunks -
-                                                         1):
+                    if next_backward_virtual_pp_rank == (
+                        self.num_model_chunks - 1
+                    ):
                         recv_next = False

                 if micro_step == (num_steps - 1):
                     recv_next = False

                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
-                    p2p.send_backward_recv_backward(input_tensor_grad,
-                                                    recv_next=recv_next))
+                    p2p.send_backward_recv_backward(
+                        input_tensor_grad, recv_next=recv_next
+                    )
+                )

         self._layers.allreduce_shared_weight_gradients()
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import (
    TestDistPPTraning,
    set_random_seed,
    ModelPipe,
    batch_size,
    length,
    micro_batch_size,
    vocab_size,
)


class TestDistPPTraningUnbalancedData(TestDistPPTraning):
    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer = paddle.optimizer.SGD(
            learning_rate=scheduler, parameters=model.parameters()
        )

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for step_id in range(5):
            x = []
            for _ in range(batch_size // micro_batch_size):
                size = micro_batch_size
                x_data = np.random.randint(0, vocab_size, size=[size, length])
                x.append(paddle.to_tensor(x_data))
            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            # TODO(shenliang03) add utest for loss
            if pp_id != 0:
                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
@@ -22,13 +22,14 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
 class TestHybridPipeParallel(TestMultipleGpus):

     def test_hybrid_parallel_pp_layer(self):
-        self.run_mnist_2gpu(
-            os.path.abspath('../../hybrid_parallel_pp_layer.py'))
-        self.run_mnist_2gpu(os.path.abspath('../../hybrid_parallel_pp_layer.py'),
-                            eager_mode=False)
+        self.run_mnist_2gpu(
+            os.path.abspath('../../hybrid_parallel_pp_layer.py')
+        )
+        self.run_mnist_2gpu(
+            os.path.abspath('../../hybrid_parallel_pp_layer.py'),
+            eager_mode=False,
+        )

     def test_hybrid_parallel_pp_tuple_inputs(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_embedding.py')
@@ -36,8 +37,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
     def test_hybrid_parallel_shared_weight(self):
         self.run_mnist_2gpu('hybrid_parallel_shared_weight.py')
-        self.run_mnist_2gpu('hybrid_parallel_shared_weight.py',
-                            eager_mode=False)
+        self.run_mnist_2gpu(
+            'hybrid_parallel_shared_weight.py', eager_mode=False
+        )

     def test_pipeline_parallel_amp(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
@@ -49,8 +51,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
     def test_hybrid_parallel_transformer(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
-        self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py',
-                            eager_mode=False)
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer.py', eager_mode=False
+        )

     def test_hybrid_parallel_save_load(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')
@@ -64,6 +67,13 @@ class TestHybridPipeParallel(TestMultipleGpus):
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py', eager_mode=False)

+    def test_hybrid_parallel_transformer_unbalanced_data(self):
+        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer_unbalanced_data.py',
+            eager_mode=False,
+        )
+

 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"