Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
aac91e82
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aac91e82
编写于
6月 16, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
6月 16, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Two kinds of profiler to pp/vp (#54586)
上级
9b2bcfd6
变更
3
隐藏空白更改
内联
并排
Showing 3 changed files with 323 additions and 7 deletions
+323
-7
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+280
-7
python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py
...stributed/fleet/meta_parallel/pp_utils/profiler_helper.py
+42
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
aac91e82
...
...
@@ -63,6 +63,7 @@ message PpConfig {
optional
bool
delay_scale_loss
=
2
[
default
=
false
];
optional
bool
enable_timer
=
3
[
default
=
false
];
optional
bool
sharding_comm_overlap
=
4
[
default
=
false
];
optional
bool
profiling
=
5
[
default
=
false
];
}
message
HybridConfig
{
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
aac91e82
...
...
@@ -10,6 +10,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import
time
import
warnings
import
paddle
from
paddle
import
framework
...
...
@@ -140,6 +142,7 @@ class PipelineParallel(MetaParallelBase):
self
.
num_stages
=
self
.
_hcg
.
get_pipe_parallel_world_size
()
self
.
stage_id
=
self
.
_hcg
.
get_stage_id
()
self
.
global_rank
=
self
.
_hcg
.
get_global_rank
()
self
.
pp_group
=
self
.
_hcg
.
get_pipe_parallel_group
()
self
.
dp_group
=
self
.
_hcg
.
get_data_parallel_group
()
self
.
sharding_group
=
self
.
_hcg
.
get_sharding_parallel_group
()
...
...
@@ -163,6 +166,30 @@ class PipelineParallel(MetaParallelBase):
self
.
_enable_timer
=
self
.
_strategy
.
hybrid_configs
[
"pp_configs"
].
enable_timer
self
.
_profiling
=
self
.
_strategy
.
hybrid_configs
[
"pp_configs"
].
profiling
self
.
_records
=
[]
self
.
_record_format
=
(
'"name": "{}{}", "cat": "pipeline timeline", "ph": {}, "pid": 0, "tid": '
+
str
(
self
.
stage_id
+
1
)
+
', "ts": {}, "cname": "{}"'
)
self
.
_forward_color
=
"thread_state_running"
# RGB: 126, 200, 148
self
.
_backward_color
=
"rail_idle"
# RGB: 238, 142, 0
if
self
.
_profiling
:
logger
.
info
(
"If enable pp profiling, the max training steps should be restricted "
"to a reasonable value (such as 5) to avoid generating large profile files. "
"The profiler will generate a profile file 'profile_record_tmp_file_for_rank_*' "
"for each rank. Users should gather all profile files for one entire pipeline "
"to one node (rank 0 is recommended) to get the full view of the pipeline profile. "
"[DONT CHANGE THE NAME OF THE PROFILE FILES!]. "
"Then get the profile parser from this url: "
"https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py "
"and save the script to the same directory of all profile files."
"Parse those files by this command: `python profiler_helper.py`. "
"After parsing, a new file 'pipeline_profile.json' will be generated. "
"Users can inspect this file by chrome://tracing website."
)
if
self
.
_dp_comm_overlap
:
assert
self
.
use_data_parallel
and
self
.
num_stages
>
1
...
...
@@ -306,11 +333,51 @@ class PipelineParallel(MetaParallelBase):
all_flag_names
=
self
.
timers
.
timers
.
keys
()
self
.
timers
.
log
(
all_flag_names
)
def
forward_backward_pipeline
(
self
,
data
,
scaler
=
None
):
def
_record_stamp
(
self
,
name
,
step
,
phase
,
color
):
if
self
.
_profiling
:
paddle
.
device
.
synchronize
()
self
.
_records
.
append
(
'{'
+
self
.
_record_format
.
format
(
name
,
step
,
phase
,
int
(
time
.
time
()
*
1000
),
color
,
)
+
'}'
)
def
_flush_records
(
self
):
if
self
.
_profiling
:
with
open
(
f
'./profile_record_tmp_file_for_rank_
{
self
.
global_rank
}
'
,
'a+'
,
)
as
f
:
for
record
in
self
.
_records
:
f
.
write
(
record
+
'
\n
'
)
self
.
_records
=
[]
def
forward_backward_pipeline
(
self
,
data
,
scaler
=
None
,
static_scheduler
=
False
):
# use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
if
static_scheduler
:
assert
(
not
self
.
_profiling
),
"While _profiling, static scheduler is not available"
if
data
is
not
None
:
warnings
.
warn
(
"Static scheduler run won't real run the model, but data has been provided"
)
logger
.
info
(
"enable static_scheduler will return the pp schedule instead of the loss"
)
schedule
=
""
self
.
scaler
=
scaler
# store total loss of entire batch
...
...
@@ -329,9 +396,15 @@ class PipelineParallel(MetaParallelBase):
micro_dataset
=
self
.
_wrap_data
(
data
)
for
step_id
in
range
(
startup_steps
):
if
static_scheduler
:
schedule
+=
f
"f
{
step_id
}
;"
logger
.
info
(
f
"forward step for micro step
{
step_id
}
"
)
continue
input_tensor
=
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
())
self
.
_record_stamp
(
"F"
,
step_id
,
'"B"'
,
self
.
_forward_color
)
output_tensor
=
self
.
_forward_step
(
input_tensor
,
micro_dataset
)
self
.
_record_stamp
(
"F"
,
step_id
,
'"E"'
,
self
.
_forward_color
)
p2p
.
send_forward
(
output_tensor
,
self
.
is_pipeline_last_stage
())
input_buffers
.
append
(
input_tensor
)
...
...
@@ -340,13 +413,25 @@ class PipelineParallel(MetaParallelBase):
if
not
self
.
is_pipeline_last_stage
():
self
.
_release_output
(
output_tensor
)
if
steady_steps
>
0
:
if
steady_steps
>
0
and
not
static_scheduler
:
input_tensor
=
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
())
for
i
in
range
(
steady_steps
):
if
static_scheduler
:
schedule
+=
f
"f
{
startup_steps
+
i
}
;"
schedule
+=
f
"b
{
i
}
;"
logger
.
info
(
f
"forward step for micro step
{
startup_steps
+
i
}
"
)
logger
.
info
(
f
"backward step for micro step
{
i
}
"
)
continue
last_iter
=
i
==
(
steady_steps
-
1
)
self
.
_record_stamp
(
"F"
,
startup_steps
+
i
,
'"B"'
,
self
.
_forward_color
)
output_tensor
=
self
.
_forward_step
(
input_tensor
,
micro_dataset
)
self
.
_record_stamp
(
"F"
,
startup_steps
+
i
,
'"E"'
,
self
.
_forward_color
)
output_tensor_grad
=
p2p
.
send_forward_recv_backward
(
output_tensor
,
self
.
is_pipeline_last_stage
()
...
...
@@ -362,9 +447,11 @@ class PipelineParallel(MetaParallelBase):
0
),
output_buffers
.
pop
(
0
)
self
.
_record_stamp
(
"B"
,
i
,
'"B"'
,
self
.
_backward_color
)
input_tensor_grad
=
self
.
_backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
)
self
.
_record_stamp
(
"B"
,
i
,
'"E"'
,
self
.
_backward_color
)
if
last_iter
:
input_tensor
=
None
...
...
@@ -377,6 +464,10 @@ class PipelineParallel(MetaParallelBase):
)
for
i
in
range
(
startup_steps
):
if
static_scheduler
:
schedule
+=
f
"b
{
steady_steps
+
i
}
;"
logger
.
info
(
f
"backward step for micro step
{
steady_steps
+
i
}
"
)
continue
input_tensor
=
input_buffers
.
pop
(
0
)
output_tensor
=
output_buffers
.
pop
(
0
)
...
...
@@ -384,11 +475,22 @@ class PipelineParallel(MetaParallelBase):
self
.
is_pipeline_last_stage
()
)
self
.
_record_stamp
(
"B"
,
steady_steps
+
i
,
'"B"'
,
self
.
_backward_color
)
input_tensor_grad
=
self
.
_backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
)
self
.
_record_stamp
(
"B"
,
steady_steps
+
i
,
'"E"'
,
self
.
_backward_color
)
p2p
.
send_backward
(
input_tensor_grad
,
self
.
is_pipeline_first_stage
())
if
static_scheduler
:
return
schedule
self
.
_flush_records
()
if
self
.
_comm_overlap
:
assert
len
(
self
.
_comm_buffers
)
>
0
for
buffer
in
self
.
_comm_buffers
:
...
...
@@ -687,12 +789,32 @@ class PipelineParallel(MetaParallelBase):
elif
can_free
(
output
):
output
.
_clear_dataptr
()
def get_static_scheduler(self):
    """Return the 1F1B pipeline schedule string without executing the model."""
    return self.forward_backward_pipeline(
        data=None, static_scheduler=True
    )
class
PipelineParallelWithInterleave
(
PipelineParallel
):
# pipeline parallel with interleave scheduler
def
__init__
(
self
,
layers
,
hcg
,
strategy
):
super
().
__init__
(
layers
=
layers
,
hcg
=
hcg
,
strategy
=
strategy
)
self
.
_record_format
=
(
'"name": "{}{}_VP{}", "cat": "virtual pipeline timeline", "ph": {}, "pid": 0, "tid": '
+
str
(
self
.
stage_id
+
1
)
+
', "ts": {}, "cname": "{}"'
)
self
.
_forward_colors
=
[
"thread_state_running"
,
# RGB: 126, 200, 148
"thread_state_unknown"
,
# RGB: 199, 155, 125
]
self
.
_backward_colors
=
[
"rail_load"
,
# RGB: 13, 168, 97
"rail_idle"
,
# RGB: 238, 142, 0
]
# Structures to record the micro step for each layer chunk
self
.
_forward_micro_step_counter
=
{}
self
.
_backward_micro_step_counter
=
{}
assert
layers
.
get_num_virtual_stages
()
>
1
assert
(
self
.
num_stages
>
2
...
...
@@ -710,6 +832,52 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert
len
(
self
.
model_chunks
)
==
self
.
num_model_chunks
self
.
_virtual_pp_world_size
=
self
.
num_model_chunks
self
.
_virtual_pp_rank
=
0
self
.
_reset_counter
()
def
_reset_counter
(
self
):
for
i
in
range
(
self
.
num_model_chunks
):
self
.
_forward_micro_step_counter
[
i
]
=
0
self
.
_backward_micro_step_counter
[
i
]
=
0
def
_record_stamp
(
self
,
name
,
step
,
phase
,
forward
=
True
):
if
self
.
_profiling
:
paddle
.
device
.
synchronize
()
virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
step
,
forward
=
forward
)
color_idx
=
virtual_pp_rank
%
2
# Get the profile color and micro step for current layer chunk
if
forward
:
color
=
self
.
_forward_colors
[
color_idx
]
micro_step
=
self
.
_forward_micro_step_counter
[
virtual_pp_rank
]
if
phase
==
'"E"'
:
self
.
_forward_micro_step_counter
[
virtual_pp_rank
]
+=
1
else
:
color
=
self
.
_backward_colors
[
color_idx
]
micro_step
=
self
.
_backward_micro_step_counter
[
virtual_pp_rank
]
if
phase
==
'"E"'
:
self
.
_backward_micro_step_counter
[
virtual_pp_rank
]
+=
1
self
.
_records
.
append
(
'{'
+
self
.
_record_format
.
format
(
name
,
micro_step
,
virtual_pp_rank
,
phase
,
int
(
time
.
time
()
*
1000
),
color
,
)
+
'}'
)
def
_flush_records
(
self
):
if
self
.
_profiling
:
with
open
(
f
'./profile_record_tmp_file_for_rank_
{
self
.
global_rank
}
'
,
'a+'
,
)
as
f
:
for
record
in
self
.
_records
:
f
.
write
(
record
+
'
\n
'
)
self
.
_records
=
[]
self
.
_reset_counter
()
def
_get_virtual_pp_rank
(
self
,
micro_step
,
forward
):
virtual_pp_stage
=
micro_step
%
(
...
...
@@ -771,7 +939,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
return
input_tensor_grad
def
forward_backward_pipeline
(
self
,
data
,
scaler
,
forward_only
=
False
,
compute_loss
=
True
self
,
data
,
scaler
,
forward_only
=
False
,
compute_loss
=
True
,
static_scheduler
=
False
,
):
# use interleave scheduling strategy.
# this strategy is inspired by:
...
...
@@ -781,6 +954,22 @@ class PipelineParallelWithInterleave(PipelineParallel):
not
forward_only
),
"compute_loss can only be set to False when forward_only is set to True"
if
static_scheduler
:
assert
(
not
forward_only
),
"static_scheduler only for training not for eval"
assert
(
not
self
.
_profiling
),
"While _profiling, static scheduler is not available"
if
data
is
not
None
:
warnings
.
warn
(
"Static scheduler run won't real run the model, but data has been provided"
)
logger
.
info
(
"enable static_scheduler will return the pp schedule instead of the loss"
)
schedule
=
""
# init some attributes for this batch run
self
.
scaler
=
scaler
self
.
total_loss
=
None
...
...
@@ -810,13 +999,32 @@ class PipelineParallelWithInterleave(PipelineParallel):
steady_steps
=
num_steps
-
startup_steps
self
.
set_virtual_pipeline_rank
(
0
)
self
.
input_tensors
[
0
].
append
(
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
(),
sync_recv
=
False
)
)
if
not
static_scheduler
:
self
.
input_tensors
[
0
].
append
(
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
(),
sync_recv
=
False
)
)
# run startup steps
for
micro_step
in
range
(
startup_steps
):
if
static_scheduler
:
virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
micro_step
,
forward
=
True
)
real_micro_step
=
self
.
_forward_micro_step_counter
[
virtual_pp_rank
]
self
.
_forward_micro_step_counter
[
virtual_pp_rank
]
+=
1
schedule
+=
f
"f
{
real_micro_step
}
_vp
{
virtual_pp_rank
}
;"
logger
.
info
(
f
"forward step for
{
real_micro_step
}
with virtual pp rank
{
virtual_pp_rank
}
"
)
continue
self
.
_record_stamp
(
"F"
,
micro_step
,
'"B"'
,
forward
=
True
)
output_tensor
=
self
.
_forward_step_helper
(
micro_dataset
,
micro_step
)
self
.
_record_stamp
(
"F"
,
micro_step
,
'"E"'
,
forward
=
True
)
# determine whether recv forward tensor or not
next_virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
...
...
@@ -867,17 +1075,55 @@ class PipelineParallelWithInterleave(PipelineParallel):
# run 1f1b steady steps
for
micro_step
in
range
(
steady_steps
):
if
static_scheduler
:
forward_micro_step_id
=
micro_step
+
startup_steps
forward_virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
forward_micro_step_id
,
forward
=
True
)
backward_micro_step_id
=
micro_step
backward_virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
backward_micro_step_id
,
forward
=
False
)
real_forward_micro_step
=
self
.
_forward_micro_step_counter
[
forward_virtual_pp_rank
]
self
.
_forward_micro_step_counter
[
forward_virtual_pp_rank
]
+=
1
real_backward_micro_step
=
self
.
_backward_micro_step_counter
[
backward_virtual_pp_rank
]
self
.
_backward_micro_step_counter
[
backward_virtual_pp_rank
]
+=
1
schedule
+=
(
f
"f
{
real_forward_micro_step
}
_vp
{
forward_virtual_pp_rank
}
;"
)
schedule
+=
(
f
"b
{
real_backward_micro_step
}
_vp
{
backward_virtual_pp_rank
}
;"
)
logger
.
info
(
f
"forward step for
{
real_forward_micro_step
}
with virtual pp rank
{
forward_virtual_pp_rank
}
"
)
logger
.
info
(
f
"backward step for
{
real_backward_micro_step
}
with virtual pp rank
{
backward_virtual_pp_rank
}
"
)
continue
# forward
forward_micro_step_id
=
micro_step
+
startup_steps
self
.
_record_stamp
(
"F"
,
forward_micro_step_id
,
'"B"'
,
forward
=
True
)
output_tensor
=
self
.
_forward_step_helper
(
micro_dataset
,
forward_micro_step_id
)
self
.
_record_stamp
(
"F"
,
forward_micro_step_id
,
'"E"'
,
forward
=
True
)
# backward
backward_micro_step_id
=
micro_step
self
.
_record_stamp
(
"B"
,
backward_micro_step_id
,
'"B"'
,
forward
=
False
)
input_tensor_grad
=
self
.
_backward_step_helper
(
backward_micro_step_id
)
self
.
_record_stamp
(
"B"
,
backward_micro_step_id
,
'"E"'
,
forward
=
False
)
# four directions comm
# send output tensor to downstream
...
...
@@ -946,13 +1192,29 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
self
.
_release_output
(
output_tensor
)
self
.
_release_output
(
output_tensor
)
if
not
static_scheduler
:
self
.
_release_output
(
output_tensor
)
# remaining backward steps
if
not
forward_only
:
for
micro_step
in
range
(
steady_steps
,
num_steps
):
if
static_scheduler
:
virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
micro_step
,
forward
=
False
)
real_micro_step
=
self
.
_backward_micro_step_counter
[
virtual_pp_rank
]
self
.
_backward_micro_step_counter
[
virtual_pp_rank
]
+=
1
schedule
+=
f
"b
{
real_micro_step
}
_vp
{
virtual_pp_rank
}
;"
logger
.
info
(
f
"backward step for
{
real_micro_step
}
with virtual pp rank
{
virtual_pp_rank
}
"
)
continue
# cooldown loop
self
.
_record_stamp
(
"B"
,
micro_step
,
'"B"'
,
forward
=
False
)
input_tensor_grad
=
self
.
_backward_step_helper
(
micro_step
)
self
.
_record_stamp
(
"B"
,
micro_step
,
'"E"'
,
forward
=
False
)
next_backward_virtual_pp_rank
=
self
.
_get_virtual_pp_rank
(
micro_step
+
1
,
forward
=
False
)
...
...
@@ -978,12 +1240,18 @@ class PipelineParallelWithInterleave(PipelineParallel):
for
buffer
in
self
.
_comm_buffers
:
buffer
.
scale_and_split_grads
()
if
static_scheduler
:
self
.
_reset_counter
()
return
schedule
if
self
.
_enable_timer
:
self
.
timers
(
"allreduce_shared_weight_gradients"
).
start
()
self
.
_layers
.
allreduce_shared_weight_gradients
()
if
self
.
_enable_timer
:
self
.
timers
(
"allreduce_shared_weight_gradients"
).
stop
()
self
.
_flush_records
()
if
compute_loss
:
# return loss if compute loss
if
self
.
_enable_timer
:
...
...
@@ -1018,3 +1286,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
self
.
_compute_loss
=
compute_loss
return
self
.
forward_backward_pipeline
(
data
,
None
,
forward_only
=
True
)
def get_static_scheduler(self):
    """Return the interleaved pipeline schedule string without running the model."""
    return self.forward_backward_pipeline(
        data=None, scaler=None, static_scheduler=True
    )
python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py
0 → 100644
浏览文件 @
aac91e82
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
def main():
    """Merge per-rank pipeline profile files into 'pipeline_profile.json'.

    Scans the current directory for files named
    'profile_record_tmp_file_for_rank_*' (one per rank, sorted by name),
    concatenates their lines into a single JSON array, and writes it to
    'pipeline_profile.json' for inspection with chrome://tracing.
    """
    all_record = []
    # Sort so the output is deterministic across runs regardless of
    # os.listdir ordering.
    all_files = sorted(
        filter(
            lambda file: file.startswith("profile_record_tmp_file_for_rank_"),
            os.listdir('./'),
        )
    )
    for files in all_files:
        with open(files, 'r') as f:
            for line in f:
                all_record.append(line.strip())
    with open('pipeline_profile.json', 'w') as f:
        f.write('[ ')
        # join handles the empty case too; the original per-index loop
        # raised IndexError on all_record[-1] when no profile files existed.
        f.write(',\n'.join(all_record))
        f.write(' ]\n')


if __name__ == "__main__":
    main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录