BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle, in sync with the fork source)
Unverified commit 3a014783, authored on Nov 07, 2022 by Yuang Liu; committed via GitHub on Nov 07, 2022.
update the split logic for uniform (#47670) (#47705)
* code format change
* update the split logic for uniform (#47670)
Parent: d5809836
Showing 1 changed file with 221 additions and 130 deletions (+221 −130)
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @ 3a014783 (+221 −130)

Each hunk below is shown as the updated code under its @@ header. Per the commit message, every hunk except the one touching SegmentLayers.uniform is a formatting-only change; in that hunk (@@ -156,21 +166,23 @@) the removed and added lines are marked with - and +.
@@ -57,7 +57,6 @@ __all__ = []

class LayerDesc(object):
    def __init__(self, layer_func, *inputs, **kwargs):
        self.layer_func = layer_func
        self.inputs = inputs
@@ -65,25 +64,28 @@ class LayerDesc(object):

        if not issubclass(layer_func, Layer):
            raise TypeError(
                "The input(layer_func) should be a derived class of Layer."
            )

    def build_layer(self):
        return self.layer_func(*self.inputs, **self.kwargs)

    def __repr__(self):
        return layer_to_str(
            self.layer_func.__name__, *self.inputs, **self.kwargs
        )


class SharedLayerDesc(LayerDesc):
    def __init__(
        self,
        key,
        layer_func,
        forward_func=None,
        shared_weight_attr='weight',
        *inputs,
        **kwargs
    ):
        super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
        self.layer_name = key
        self.forward_func = forward_func
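A LayerDesc records the layer class plus its constructor arguments so that the actual (parameter-allocating) construction can be deferred until segmentation decides which rank owns the layer. A minimal standalone sketch of that deferred-construction pattern; Layer and Linear here are stand-ins for illustration, not Paddle's classes:

class Layer:  # stand-in for paddle.nn.Layer
    pass

class Linear(Layer):  # hypothetical layer, illustration only
    def __init__(self, in_dim, out_dim):
        self.shape = (in_dim, out_dim)

class LayerDesc:
    def __init__(self, layer_func, *inputs, **kwargs):
        if not issubclass(layer_func, Layer):
            raise TypeError(
                "The input(layer_func) should be a derived class of Layer."
            )
        self.layer_func = layer_func
        self.inputs = inputs
        self.kwargs = kwargs

    def build_layer(self):
        # Only here is the layer really constructed.
        return self.layer_func(*self.inputs, **self.kwargs)

desc = LayerDesc(Linear, 768, 3072)  # cheap: nothing allocated yet
layer = desc.build_layer()           # done only on the owning pipeline stage
print(layer.shape)                   # (768, 3072)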
@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc):

class SegmentLayers(object):
    def __init__(
        self,
        layers_desc,
        num_parts,
        method="uniform",
        num_virtual_pipeline_stage=None,
    ):
        self._layers_desc = layers_desc
        self.method = method
        self.num_parts = num_parts
@@ -104,7 +107,9 @@ class SegmentLayers(object):

        self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
        if self.num_virtual_pipeline_stage is not None:
            self.total_parts = num_parts * self.num_virtual_pipeline_stage
        assert (
            self.num_items >= self.num_parts
        ), "layer number should be greater than number of segments"

    def do_segment(self):
        if self.method == "uniform":
@@ -118,12 +123,17 @@ class SegmentLayers(object):

            for idx in weight_idxs:
                weights[idx] = 1

            actual_num_parts = (
                self.num_parts
                if self.num_virtual_pipeline_stage is None
                else self.total_parts
            )

            assert (
                sum(weights) % actual_num_parts == 0
            ), "number of layers ({}) should be divided by part number({})".format(
                sum(weights), actual_num_parts
            )
            part_size = sum(weights) // actual_num_parts
            result = [0 for _ in range(actual_num_parts + 1)]
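With virtual pipeline stages enabled, the divisibility check above runs against total_parts rather than num_parts. A quick arithmetic sketch with illustrative values:

num_parts = 4                   # real pipeline stages
num_virtual_pipeline_stage = 2  # chunks per real stage
total_parts = num_parts * num_virtual_pipeline_stage  # 8

weights = [1] * 16              # e.g. 16 transformer layers, weight 1 each
actual_num_parts = total_parts  # virtual stages are in use

assert sum(weights) % actual_num_parts == 0   # 16 % 8 == 0
part_size = sum(weights) // actual_num_parts  # 2 layers per virtual part
print(part_size)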
@@ -156,21 +166,23 @@ class SegmentLayers(object):

            if regex.search(name):
                weight_idxs.append(idx)

        assert (
            len(weight_idxs) > 0
        ), "weight_idxs' length should be greater than 0"
        return weight_idxs

    def uniform(self, num_items, num_parts):
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
-        for i in range(num_parts):
-            result[i] = int(min(part_size * i, num_items))
+        extra_layers = num_items % num_parts
+        for i in range(1, num_parts):
+            offset = 1 if i > (num_parts - extra_layers) else 0
+            result[i] = int(min(result[i - 1] + part_size + offset, num_items))
        result[num_parts] = num_items
        return result


class PipelineLayerChunk(Layer):
    def __init__(self):
        super(PipelineLayerChunk, self).__init__()
        self.run_function = []
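The behavioral change is easiest to see on a concrete split. Below is a minimal standalone sketch (plain Python, no Paddle; the two functions simply replay the removed and added loop bodies) comparing the cut points for 10 layers over 4 parts:

import math

def uniform_old(num_items, num_parts):
    # Removed logic: every boundary is a multiple of floor(num_items / num_parts),
    # so all leftover layers pile up in the last part.
    result = [0 for _ in range(num_parts + 1)]
    part_size = math.floor(num_items / num_parts)
    for i in range(num_parts):
        result[i] = int(min(part_size * i, num_items))
    result[num_parts] = num_items
    return result

def uniform_new(num_items, num_parts):
    # Added logic: the num_items % num_parts leftover layers are spread
    # one per part over the last extra_layers parts.
    result = [0 for _ in range(num_parts + 1)]
    part_size = math.floor(num_items / num_parts)
    extra_layers = num_items % num_parts
    for i in range(1, num_parts):
        offset = 1 if i > (num_parts - extra_layers) else 0
        result[i] = int(min(result[i - 1] + part_size + offset, num_items))
    result[num_parts] = num_items
    return result

print(uniform_old(10, 4))  # [0, 2, 4, 6, 10] -> part sizes 2, 2, 2, 4
print(uniform_new(10, 4))  # [0, 2, 4, 7, 10] -> part sizes 2, 2, 3, 3

The new boundaries keep every part within one layer of the average instead of dumping all remainder layers on the last pipeline stage.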
@@ -192,7 +204,8 @@ class PipelineLayerChunk(Layer):

            # behavior under recompute circumstance.
            raise PermissionError(
                "The forward function of PipelineLayerChunk cannot be called directly. "
                "Please call forward function of PipelineLayer."
            )


class PipelineLayer(Layer):
@@ -274,7 +287,8 @@ class PipelineLayer(Layer):

    """

    def __init__(
        self,
        layers,
        num_stages=None,
        topology=None,
@@ -282,24 +296,32 @@ class PipelineLayer(Layer):

        seg_method="uniform",
        recompute_interval=0,
        recompute_ctx=None,
        num_virtual_pipeline_stages=None,
    ):
        super(PipelineLayer, self).__init__()
        if num_stages is None and topology is None:
            raise ValueError("should provide num_stages or topology")

        if num_virtual_pipeline_stages:
            assert isinstance(
                num_virtual_pipeline_stages, int
            ), "virtual_pipeline_stage should be None or an int"
            if num_virtual_pipeline_stages > 1:
                logger.info(
                    "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
                )
                assert isinstance(
                    seg_method, str
                ), "seg_method should be a str for interleave scheduler"
                assert seg_method.startswith(
                    'layer:'
                ), "seg_method shoud be start with layer: for interleave scheduler"

        self._num_virtual_pipeline_stages = (
            1
            if num_virtual_pipeline_stages is None
            else num_virtual_pipeline_stages
        )

        # lazy import
        import paddle.distributed as dist
@@ -313,13 +335,17 @@ class PipelineLayer(Layer):

        self.recompute_ctx = recompute_ctx

        if recompute_interval > 0:
            assert (
                recompute_ctx is not None
            ), "recompute_ctx must be not None for recompute."

            offload = recompute_ctx.get('offload', False)
            partition = recompute_ctx.get('partition', False)
            logger.info(
                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format(
                    offload, partition
                )
            )

        world_size = dist.get_world_size()
        self.global_rank = dist.get_rank()
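recompute_ctx is a plain dict; only its 'offload' and 'partition' keys are read here, both defaulting to False. A hedged sketch of what a caller might pass (the keys mirror the .get() calls above; the comments are an interpretation, not spelled out in this hunk):

recompute_ctx = {
    'offload': False,   # whether to offload saved activations off the device
    'partition': True,  # whether to split saved activations across mp ranks
}
offload = recompute_ctx.get('offload', False)
partition = recompute_ctx.get('partition', False)
print(offload, partition)  # False True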
@@ -328,22 +354,28 @@ class PipelineLayer(Layer):

            self._stage_id = self._topo.get_coord(self.global_rank).pipe
            self._num_stages = self._topo.get_dim_size("pipe")
            if num_stages:
                assert (
                    self._num_stages == num_stages
                ), "num_stages should be equal to be %d" % (self._num_stages)
        else:
            # construct default topology
            if world_size % num_stages != 0:
                raise ValueError(
                    "should provide correct num_stages({}) "
                    "which can be divided by world_size({})".format(
                        num_stages, world_size
                    )
                )
            dp_num = world_size // num_stages
            self._topo = fleet.CommunicateTopology(
                ["data", "pipe", "model"], [dp_num, num_stages, 1]
            )
            self._stage_id = self._topo.get_coord(self.global_rank).pipe
            self._num_stages = self._topo.get_dim_size("pipe")

        self._total_stages_with_virtual_stages = (
            self._num_stages * self._num_virtual_pipeline_stages
        )

        # initialize segment
        self._layers_desc = list(self.layers)
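When only num_stages is given, the default topology is derived from the world size. A small sketch of the arithmetic (plain Python, illustrative values):

world_size = 8
num_stages = 2
num_virtual_pipeline_stages = 2

if world_size % num_stages != 0:
    raise ValueError(
        "should provide correct num_stages({}) "
        "which can be divided by world_size({})".format(num_stages, world_size)
    )
dp_num = world_size // num_stages  # 4 data-parallel replicas
dims = (["data", "pipe", "model"], [dp_num, num_stages, 1])
total_stages_with_virtual_stages = num_stages * num_virtual_pipeline_stages  # 4
print(dims, total_stages_with_virtual_stages)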
@@ -381,16 +413,22 @@ class PipelineLayer(Layer):

            start_idx = virtual_pp_rank * self._num_stages
            for stage in range(self._num_stages):
                # stage mark the real pp stage
                if (
                    self.segment_parts[start_idx + stage]
                    <= layer_idx
                    < self.segment_parts[start_idx + stage + 1]
                ):
                    return stage

    def get_num_virtual_stages(self):
        return self._num_virtual_pipeline_stages

    def get_model_chunks(self):
        return (
            None
            if self._num_virtual_pipeline_stages == 1
            else self._model_chunks
        )

    def _construct_shared_comm(self):
        shared_comm = {}
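get_stage_from_index walks each virtual chunk's window of segment boundaries and returns the real stage whose half-open interval contains the layer index. A standalone sketch of the same lookup (the segment_parts values are illustrative):

def get_stage_from_index(layer_idx, segment_parts, num_stages, num_virtual):
    # segment_parts holds num_stages * num_virtual + 1 boundaries.
    for virtual_pp_rank in range(num_virtual):
        start_idx = virtual_pp_rank * num_stages
        for stage in range(num_stages):
            if (
                segment_parts[start_idx + stage]
                <= layer_idx
                < segment_parts[start_idx + stage + 1]
            ):
                return stage

# 8 layers, 2 real stages, 2 virtual stages -> boundaries [0, 2, 4, 6, 8]
print(get_stage_from_index(5, [0, 2, 4, 6, 8], 2, 2))  # layer 5 -> stage 0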
@@ -398,17 +436,21 @@ class PipelineLayer(Layer):

            return

        layers_desc = self._layers_desc
        shared_layer_names = set(
            s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
        )
        for key in shared_layer_names:
            shared_layers = []
            for idx, layer in enumerate(layers_desc):
                if (
                    isinstance(layer, SharedLayerDesc)
                    and layer.layer_name == key
                ):
                    shared_layers.append(idx)

            shared_stages = set(
                self.get_stage_from_index(idx) for idx in shared_layers
            )
            self._dp_degree = self._topo.get_dim('data')
            self._mp_degree = self._topo.get_dim('model')
            self._sharding_degree = self._topo.get_dim('sharding')
@@ -425,7 +467,9 @@ class PipelineLayer(Layer):

                        pipe=s,
                        data=dp,
                        sharding=sharding,
                        model=mp,
                    )
                )
            group = paddle.distributed.new_group(ranks=shared_ranks)
            if self.global_rank in shared_ranks:
@@ -434,8 +478,9 @@ class PipelineLayer(Layer):

                shared_comm[key] = {
                    'ranks': shared_ranks,
                    'group': group,
                    'weight_attr': self.shared_weight_attrs[key],
                    'layer': self.shared_layers[key],
                }
        return shared_comm
@@ -443,10 +488,11 @@ class PipelineLayer(Layer):

    def _synchronize_shared_weights(self):
        for key, comm in self.shared_comm.items():
            with paddle.framework.no_grad():
                paddle.distributed.broadcast(
                    getattr(comm['layer'], comm['weight_attr']),
                    src=min(comm['ranks']),
                    group=comm['group'],
                )

            for param in comm['layer'].parameters():
                if self.global_rank != min(comm['ranks']):
@@ -458,8 +504,9 @@ class PipelineLayer(Layer):

            # need use trace_op to allreduce weight
            if in_dygraph_mode():
                with paddle.framework.no_grad():
                    paddle.distributed.all_reduce(
                        param.grad, group=comm['group']
                    )
            else:
                with paddle.framework.no_grad():
                    paddle.fluid.framework._dygraph_tracer().trace_op(
@@ -468,8 +515,9 @@ class PipelineLayer(Layer):

                        outputs={'Out': param._grad_ivar()},
                        attrs={
                            'ring_id': comm['group'].id,
                            'use_calc_stream': True,
                        },
                    )

    def _segment_network_for_interleave(self, seg_method):
        logger.info("start segment network for interleave scheduler")
@@ -477,14 +525,20 @@ class PipelineLayer(Layer):

            self._layers_desc,
            num_parts=self._num_stages,
            method=seg_method,
            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages,
        )
        self.segment_parts = seg.do_segment()

        logger.info(
            "segment result:"
            + ", ".join(str(arg) for arg in self.segment_parts)
        )

        for i in range(
            self._stage_id,
            self._total_stages_with_virtual_stages,
            self._num_stages,
        ):
            # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
            # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
            # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
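The strided range above is exactly what produces the interleaved assignment described in the comment; the 2-stage, 2-virtual-stage, 8-layer example can be replayed in plain Python:

num_stages = 2
num_virtual = 2
total = num_stages * num_virtual
segment_parts = [0, 2, 4, 6, 8]  # 8 layers cut into 4 virtual parts

for stage_id in range(num_stages):
    owned = [
        list(range(segment_parts[i], segment_parts[i + 1]))
        for i in range(stage_id, total, num_stages)  # same stride as above
    ]
    print("real stage", stage_id, "owns layers", owned)
# real stage 0 owns layers [[0, 1], [4, 5]]
# real stage 1 owns layers [[2, 3], [6, 7]]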
@@ -500,13 +554,15 @@ class PipelineLayer(Layer):

    def _segment_network(self, seg_method):
        logger.info("start segment network..")
        seg = SegmentLayers(
            self._layers_desc, num_parts=self._num_stages, method=seg_method
        )
        self.segment_parts = seg.do_segment()

        logger.info(
            "segment result:"
            + ", ".join(str(arg) for arg in self.segment_parts)
        )

        self._start_pos = self.segment_parts[self._stage_id]
        self._end_pos = self.segment_parts[self._stage_id + 1]
@@ -514,22 +570,30 @@ class PipelineLayer(Layer):

    def _print_segmentation_for_debug(self):
        # print information for debug
        for stage in range(
            self._num_stages * self._num_virtual_pipeline_stages
        ):
            start = self.segment_parts[stage]
            end = self.segment_parts[stage + 1]
            logger.info(
                "stage={}, global_rank={} ,layer_number={}".format(
                    stage, self.global_rank, end - start
                )
            )

            for index, layer in enumerate(self._layers_desc[start:end]):
                logger.info("{}: {}".format(index + start, str(layer)))

        if self._num_virtual_pipeline_stages > 1:
            for stage in range(self._num_stages):
                stage_to_virtual_stage_info = (
                    "stage {} contains virtual stages: ".format(stage)
                )
                for i in range(
                    stage,
                    self._total_stages_with_virtual_stages,
                    self._num_stages,
                ):
                    stage_to_virtual_stage_info += " {},".format(i)
                logger.info(stage_to_virtual_stage_info)
@@ -575,9 +639,11 @@ class PipelineLayer(Layer):

                if layer.layer_name not in self.shared_layers:
                    self.shared_layers[layer.layer_name] = layer.build_layer()
                    self.shared_weight_attrs[
                        layer.layer_name
                    ] = layer.shared_weight_attr
                    for param in self.shared_layers[
                        layer.layer_name
                    ].parameters():
                        setattr(param, "is_firstly_shared", True)

                if layer.forward_func is None:
@@ -585,8 +651,11 @@ class PipelineLayer(Layer):

                else:
                    run_function.append(
                        partial(
                            layer.forward_func,
                            self.shared_layers[layer.layer_name],
                        )
                    )

            elif isinstance(layer, LayerDesc):
                model = layer.build_layer()
@@ -615,11 +684,15 @@ class PipelineLayer(Layer):

    def forward(self, input, chunk_id=None):
        if chunk_id is not None:
            assert isinstance(chunk_id, int), "chunk_id should be an int"
            assert (
                self._num_virtual_pipeline_stages > 1
            ), "chunk_id is only valid when using virtual pipeline stage"
            assert chunk_id < len(self._model_chunks), (
                "The virtual pipeline only has {} chunks, "
                "but received chunk_id {}.".format(
                    len(self._model_chunks), chunk_id
                )
            )
            # Get the target model chunk.
            model_chunk = self._model_chunks[chunk_id]
            # Update the self.run_function to the target run functions.
@@ -637,20 +710,25 @@ class PipelineLayer(Layer):

        funcs = self.run_function[start_idx:end_idx]

        if not isinstance(input, tuple):
            input = (input,)

        if self._need_recompute(funcs, input):
            input = recompute_hybrid(
                self.recompute_ctx,
                self.forward_function(start_idx, end_idx),
                *input
            )
        else:
            input = self.forward_function(start_idx, end_idx)(*input)

        return input

    def _need_recompute(self, funcs, inputs):
        if not any(
            input_.stop_gradient == False
            for input_ in inputs
            if isinstance(input_, paddle.Tensor)
        ):
            return False

        params = [f.parameters() for f in funcs if isinstance(f, Layer)]
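_need_recompute first bails out when no tensor input actually requires a gradient. A pure-Python sketch of that predicate, with a stand-in class in place of paddle.Tensor:

class FakeTensor:  # stand-in for paddle.Tensor, illustration only
    def __init__(self, stop_gradient):
        self.stop_gradient = stop_gradient

def any_input_needs_grad(inputs):
    # Mirrors the check above: some tensor input has stop_gradient == False.
    return any(
        input_.stop_gradient == False
        for input_ in inputs
        if isinstance(input_, FakeTensor)
    )

print(any_input_needs_grad((FakeTensor(True), "attn_mask")))  # False -> skip recompute
print(any_input_needs_grad((FakeTensor(False),)))             # True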
@@ -674,11 +752,18 @@ class PipelineLayer(Layer):

            if self._num_virtual_pipeline_stages > 1:
                # add virtual pipeline info to the save path
                assert local_chunk_id is not None
                virtual_pipeline_stage_message = (
                    "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
                )
            layer_save_path = os.path.join(
                ckpt_dir, 'layer_{:0>2d}'.format(idx)
            )
            layer_save_path = (
                layer_save_path
                + virtual_pipeline_stage_message
                + rank_message
                + '-model_states.pdparams'
            )
            return layer_save_path

        def _save_model(run_functions, local_chunk_id=None):
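The pieces above concatenate into one checkpoint filename per layer and per virtual chunk. A sketch of the resulting path; rank_message is built elsewhere in the file, so its value here is purely illustrative:

import os

ckpt_dir = "./ckpt"
idx = 3            # layer index within this rank's run functions
local_chunk_id = 1
rank_message = "-tensor_00-pipeline_00"  # illustrative placeholder

virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
layer_save_path = os.path.join(ckpt_dir, 'layer_{:0>2d}'.format(idx))
layer_save_path = (
    layer_save_path
    + virtual_pipeline_stage_message
    + rank_message
    + '-model_states.pdparams'
)
print(layer_save_path)
# ./ckpt/layer_03-virtual_pp_stage_01-tensor_00-pipeline_00-model_states.pdparams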
@@ -701,7 +786,8 @@ class PipelineLayer(Layer):

    def set_state_dir(self, path):
        assert os.path.exists(
            path
        ), "{} not found, please check the path".format(path)

        def _load_model(run_functions, local_chunk_id=None):
            for idx, layer in enumerate(run_functions):
@@ -715,21 +801,26 @@ class PipelineLayer(Layer):

                    pos_offset = self._start_poss[local_chunk_id]
                layer_idx = idx + pos_offset
                layer_save_path = os.path.join(
                    path, 'layer_{0:0>2d}'.format(layer_idx)
                )
                if self._num_virtual_pipeline_stages > 1:
                    # add virtual pipeline info to the path
                    assert local_chunk_id is not None
                    layer_save_path = (
                        layer_save_path
                        + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
                    )
                model_files = glob.glob(
                    layer_save_path + "*model_states.pdparams"
                )
                model_files.sort()
                mp_rank = self._topo.get_coord(self.global_rank).model
                mp_world_size = self._topo.get_dim('model')
                num_files = len(model_files)

                load_param_path = model_files[
                    mp_rank * num_files // mp_world_size
                ]
                model_state_dict = paddle.load(load_param_path)
                layer.set_state_dict(model_state_dict)
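On load, each model-parallel rank picks a proportional slice of the sorted checkpoint files, which also covers loading with a smaller mp degree than the one used for saving. A quick sketch of the index arithmetic (file names are illustrative):

model_files = sorted([
    "layer_00-tensor_00-model_states.pdparams",
    "layer_00-tensor_01-model_states.pdparams",
    "layer_00-tensor_02-model_states.pdparams",
    "layer_00-tensor_03-model_states.pdparams",
])
num_files = len(model_files)  # 4 files, saved with mp_world_size == 4

mp_world_size = 2             # now loading with a smaller mp degree
for mp_rank in range(mp_world_size):
    load_param_path = model_files[mp_rank * num_files // mp_world_size]
    print(mp_rank, "->", load_param_path)
# 0 -> layer_00-tensor_00-model_states.pdparams
# 1 -> layer_00-tensor_02-model_states.pdparams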