Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
875c1f6d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
875c1f6d
编写于
9月 05, 2023
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add chunk timer
上级
26a83ed1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
102 addition
and
2 deletion
+102
-2
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+61
-1
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+40
-1
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
875c1f6d
...
...
@@ -63,6 +63,7 @@ message PpConfig {
optional
bool
delay_scale_loss
=
2
[
default
=
false
];
optional
bool
enable_timer
=
3
[
default
=
false
];
optional
bool
sharding_comm_overlap
=
4
[
default
=
false
];
optional
bool
enable_chunk_timer
=
5
[
default
=
false
];
}
message
HybridConfig
{
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
875c1f6d
...
...
@@ -13,6 +13,7 @@
import
os
import
sys
import
time
from
collections
import
defaultdict
import
paddle
...
...
@@ -36,6 +37,38 @@ __all__ = []
g_shard_use_reduce
=
int
(
os
.
environ
.
get
(
"FLAGS_shard_use_reduce"
,
1
))
class
ChunkTimer
:
def
__init__
(
self
,
group
):
self
.
rank
=
group
.
get_group_rank
(
paddle
.
distributed
.
get_rank
())
self
.
group
=
group
self
.
reset
()
def
begin
(
self
):
paddle
.
distributed
.
barrier
(
self
.
group
)
self
.
reset
()
def
reset
(
self
):
self
.
begin
=
time
.
time
()
self
.
records
.
clear
()
def
start
(
self
,
name
):
paddle
.
device
.
cuda
.
synchronize
()
t
=
time
.
time
()
self
.
records
.
append
([
name
,
t
,
None
])
def
end
(
self
,
name
):
paddle
.
device
.
cuda
.
synchronize
()
t
=
time
.
time
()
self
.
records
[
-
1
][
-
1
]
=
t
def
export_info
(
self
):
return
{
"rank"
:
self
.
rank
,
"begin"
:
self
.
begin
,
"records"
:
self
.
records
,
}
# assume only the first stage and last stage need data, and data consumption are ordred;
# to be replaced by real micro dataset from reader
class
FakeMicroDataset
:
...
...
@@ -151,6 +184,14 @@ class PipelineParallel(MetaParallelBase):
]
self
.
_using_cache
=
self
.
_strategy
.
pipeline_configs
[
'p2p_cache_shape'
]
self
.
_enable_chunk_timer
=
self
.
_strategy
.
pipeline_configs
[
'enable_chunk_timer'
]
if
self
.
_enable_chunk_timer
:
self
.
_chunk_timer
=
ChunkTimer
()
else
:
self
.
_chunk_timer
=
None
self
.
num_stages
=
self
.
_hcg
.
get_pipe_parallel_world_size
()
self
.
stage_id
=
self
.
_hcg
.
get_stage_id
()
self
.
pp_group
=
self
.
_hcg
.
get_pipe_parallel_group
()
...
...
@@ -204,7 +245,7 @@ class PipelineParallel(MetaParallelBase):
)
# construct pipeline meta info
self
.
_p2p_helper
=
p2p
.
P2pHelper
(
self
.
_using_cache
)
self
.
_p2p_helper
=
p2p
.
P2pHelper
(
self
.
_using_cache
,
self
.
_chunk_timer
)
self
.
global_rank
=
self
.
_hcg
.
get_global_rank
()
self
.
micro_batch_id
=
0
...
...
@@ -234,6 +275,12 @@ class PipelineParallel(MetaParallelBase):
self
.
_layers
,
self
.
dp_group
,
self
.
accumulate_steps
,
True
)
def
_export_chunk_timer_info
(
self
):
if
self
.
_chunk_timer
is
not
None
:
return
self
.
_chunk_timer
.
export_info
()
else
:
return
None
def
is_pipeline_first_stage
(
self
,
ignore_virtual
=
False
):
if
not
ignore_virtual
:
if
self
.
_virtual_pp_world_size
is
not
None
:
...
...
@@ -333,6 +380,8 @@ class PipelineParallel(MetaParallelBase):
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
begin
()
self
.
scaler
=
scaler
# store total loss of entire batch
...
...
@@ -564,6 +613,8 @@ class PipelineParallel(MetaParallelBase):
return
self
.
train_loss
def
_forward_step
(
self
,
input_tensor
,
micro_dataset
,
chunk_id
=
None
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"forward_step"
)
if
self
.
_enable_timer
:
self
.
timers
(
"forward_step"
).
start
()
if
self
.
is_pipeline_first_stage
():
...
...
@@ -601,6 +652,8 @@ class PipelineParallel(MetaParallelBase):
self
.
micro_batch_id
+=
1
if
self
.
_enable_timer
:
self
.
timers
(
"forward_step"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
output_tensor
def
_check_micro_batch_data_valid
(
self
,
micro_batch_data
):
...
...
@@ -614,6 +667,8 @@ class PipelineParallel(MetaParallelBase):
),
f
"expected micro_batch_size
{
self
.
micro_batch_size
}
but get
{
micro_batch_size
}
"
def
_backward_step
(
self
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"backward_step"
)
if
self
.
_enable_timer
:
self
.
timers
(
"backward_step"
).
start
()
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
...
...
@@ -647,6 +702,8 @@ class PipelineParallel(MetaParallelBase):
input_tensor_grad
=
input_tensor
.
grad
if
self
.
_enable_timer
:
self
.
timers
(
"backward_step"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
input_tensor_grad
def
_broadcast_final_loss
(
self
):
...
...
@@ -884,6 +941,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self
.
_using_cache
),
"cache should be enabled for pipeline with interleave"
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
begin
()
# init some attributes for this batch run
self
.
scaler
=
scaler
self
.
total_loss
=
None
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
875c1f6d
...
...
@@ -455,9 +455,10 @@ def _p2p_helper(
class
P2pHelper
:
def
__init__
(
self
,
use_cache
=
True
):
def
__init__
(
self
,
use_cache
=
True
,
chunk_timer
=
None
):
self
.
_send_recv_meta
=
SendRecvMeta
()
self
.
_use_cache
=
use_cache
self
.
_chunk_timer
=
None
def
_send_meta
(
self
,
output_tensor
):
if
not
self
.
_send_recv_meta
.
has_send_meta
:
...
...
@@ -473,6 +474,8 @@ class P2pHelper:
self
.
_send_recv_meta
.
has_recv_meta
=
self
.
_use_cache
def
recv_forward
(
self
,
pp_first_stage
,
sync_recv
=
True
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"recv_forward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"recv_forward"
).
start
()
...
...
@@ -491,9 +494,13 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"recv_forward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
input_tensor
def
recv_backward
(
self
,
pp_last_stage
,
sync_recv
=
True
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"recv_backward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"recv_backward"
).
start
()
...
...
@@ -510,9 +517,13 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"recv_backward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
output_tensor_grad
def
send_forward
(
self
,
output_tensor
,
pp_last_stage
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_forward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_forward"
).
start
()
...
...
@@ -528,8 +539,12 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_forward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
def
send_backward
(
self
,
input_tensor_grad
,
pp_first_stage
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_backward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_backward"
).
start
()
...
...
@@ -543,8 +558,12 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_backward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
def
send_forward_recv_backward
(
self
,
output_tensor
,
pp_last_stage
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_forward_recv_backward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_forward_recv_backward"
).
start
()
...
...
@@ -560,9 +579,13 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_forward_recv_backward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
output_tensor_grad
def
send_backward_recv_forward
(
self
,
input_tensor_grad
,
pp_first_stage
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_backward_recv_forward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_backward_recv_forward"
).
start
()
...
...
@@ -578,11 +601,17 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_backward_recv_forward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
input_tensor
def
send_forward_backward_recv_forward_backward
(
self
,
output_tensor
,
input_tensor_grad
,
recv_prev
,
recv_next
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_forward_backward_recv_forward_backward"
)
# always have to send dytpe info to downstream
global
_timers
if
_timers
is
not
None
:
...
...
@@ -602,10 +631,14 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_forward_backward_recv_forward_backward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
input_tensor
,
output_tensor_grad
def
send_forward_recv_forward
(
self
,
output_tensor
,
recv_prev
):
# always have to send dytpe info to downstream
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_forward_recv_forward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_forward_recv_forward"
).
start
()
...
...
@@ -624,9 +657,13 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_forward_recv_forward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
input_tensor
def
send_backward_recv_backward
(
self
,
input_tensor_grad
,
recv_next
):
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
start
(
"send_backward_recv_backward"
)
global
_timers
if
_timers
is
not
None
:
_timers
(
"send_backward_recv_backward"
).
start
()
...
...
@@ -640,6 +677,8 @@ class P2pHelper:
)
if
_timers
is
not
None
:
_timers
(
"send_backward_recv_backward"
).
stop
()
if
self
.
_chunk_timer
is
not
None
:
self
.
_chunk_timer
.
end
()
return
output_tensor_grad
def
__repr__
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录