Commit 3a014783 (unverified)
Authored by Yuang Liu on Nov 07, 2022; committed via GitHub on Nov 07, 2022.
update the split logic for uniform (#47670) (#47705)

* code format change
* update the split logic for uniform (#47670)
Parent: d5809836
Showing 1 changed file with 221 additions and 130 deletions.
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py (+221, -130) @ 3a014783
@@ -57,7 +57,6 @@ __all__ = []
 class LayerDesc(object):
-
     def __init__(self, layer_func, *inputs, **kwargs):
         self.layer_func = layer_func
         self.inputs = inputs
@@ -65,25 +64,28 @@ class LayerDesc(object):
         if not issubclass(layer_func, Layer):
             raise TypeError(
-                "The input(layer_func) should be a derived class of Layer.")
+                "The input(layer_func) should be a derived class of Layer."
+            )

     def build_layer(self):
         return self.layer_func(*self.inputs, **self.kwargs)

     def __repr__(self):
-        return layer_to_str(self.layer_func.__name__, *self.inputs,
-                            **self.kwargs)
+        return layer_to_str(
+            self.layer_func.__name__, *self.inputs, **self.kwargs
+        )


 class SharedLayerDesc(LayerDesc):
-
-    def __init__(self,
-                 key,
-                 layer_func,
-                 forward_func=None,
-                 shared_weight_attr='weight',
-                 *inputs,
-                 **kwargs):
+    def __init__(
+        self,
+        key,
+        layer_func,
+        forward_func=None,
+        shared_weight_attr='weight',
+        *inputs,
+        **kwargs
+    ):
         super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
         self.layer_name = key
         self.forward_func = forward_func
@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc):

 class SegmentLayers(object):
-
-    def __init__(self,
-                 layers_desc,
-                 num_parts,
-                 method="uniform",
-                 num_virtual_pipeline_stage=None):
+    def __init__(
+        self,
+        layers_desc,
+        num_parts,
+        method="uniform",
+        num_virtual_pipeline_stage=None,
+    ):
         self._layers_desc = layers_desc
         self.method = method
         self.num_parts = num_parts
@@ -104,7 +107,9 @@ class SegmentLayers(object):
         self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
         if self.num_virtual_pipeline_stage is not None:
             self.total_parts = num_parts * self.num_virtual_pipeline_stage
-        assert self.num_items >= self.num_parts, "layer number should be greater than number of segments"
+        assert (
+            self.num_items >= self.num_parts
+        ), "layer number should be greater than number of segments"

     def do_segment(self):
         if self.method == "uniform":
@@ -118,12 +123,17 @@ class SegmentLayers(object):
             for idx in weight_idxs:
                 weights[idx] = 1
-            actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts
-            assert sum(
-                weights
-            ) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format(
-                sum(weights), actual_num_parts)
+            actual_num_parts = (
+                self.num_parts
+                if self.num_virtual_pipeline_stage is None
+                else self.total_parts
+            )
+            assert (
+                sum(weights) % actual_num_parts == 0
+            ), "number of layers ({}) should be divided by part number({})".format(
+                sum(weights), actual_num_parts
+            )
             part_size = sum(weights) // actual_num_parts
             result = [0 for _ in range(actual_num_parts + 1)]
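The hunk above decides how many parts the weighted layers must divide into: with virtual pipeline stages enabled, the divisor is num_parts * num_virtual_pipeline_stage rather than num_parts alone. A quick numeric sketch of that arithmetic (the values are illustrative, not from the source):

    # With 2 real pipeline stages and 2 virtual stages per rank, the weighted
    # layers are spread over total_parts = 2 * 2 = 4 parts, and the total
    # weight must divide evenly by that count.
    num_parts, num_virtual_pipeline_stage = 2, 2
    total_parts = num_parts * num_virtual_pipeline_stage

    weights = [0, 1, 1, 1, 1, 1, 1, 1, 1, 0]  # e.g. embedding/head weighted 0

    actual_num_parts = total_parts  # num_virtual_pipeline_stage is not None
    assert sum(weights) % actual_num_parts == 0  # 8 % 4 == 0 holds
    part_size = sum(weights) // actual_num_parts
    print(part_size)  # 2 weighted layers per (virtual) part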
@@ -156,21 +166,23 @@ class SegmentLayers(object):
            if regex.search(name):
                weight_idxs.append(idx)
-        assert len(
-            weight_idxs) > 0, "weight_idxs' length should be greater than 0"
+        assert (
+            len(weight_idxs) > 0
+        ), "weight_idxs' length should be greater than 0"
         return weight_idxs

     def uniform(self, num_items, num_parts):
         result = [0 for _ in range(num_parts + 1)]
         part_size = math.floor(num_items / num_parts)
-        for i in range(num_parts):
-            result[i] = int(min(part_size * i, num_items))
+        extra_layers = num_items % num_parts
+        for i in range(1, num_parts):
+            offset = 1 if i > (num_parts - extra_layers) else 0
+            result[i] = int(min(result[i - 1] + part_size + offset, num_items))
         result[num_parts] = num_items
         return result


 class PipelineLayerChunk(Layer):
-
     def __init__(self):
         super(PipelineLayerChunk, self).__init__()
         self.run_function = []
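This hunk is the behavioral change named in the commit title. The old uniform placed every boundary at a multiple of floor(num_items / num_parts), so when num_items was not divisible by num_parts all leftover layers piled onto the last stage. The new logic spreads the num_items % num_parts leftover layers one per part across the trailing parts. A self-contained sketch of both versions (old_uniform and new_uniform are illustrative names, not from the source):

    import math

    def old_uniform(num_items, num_parts):
        # Pre-#47670: every boundary is a multiple of part_size, so all
        # leftover layers land in the last part.
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
        for i in range(num_parts):
            result[i] = int(min(part_size * i, num_items))
        result[num_parts] = num_items
        return result

    def new_uniform(num_items, num_parts):
        # Post-#47670: the num_items % num_parts leftover layers are spread
        # one per part over the trailing parts via the offset term.
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
        extra_layers = num_items % num_parts
        for i in range(1, num_parts):
            offset = 1 if i > (num_parts - extra_layers) else 0
            result[i] = int(min(result[i - 1] + part_size + offset, num_items))
        result[num_parts] = num_items
        return result

    print(old_uniform(10, 4))  # [0, 2, 4, 6, 10] -> stages of 2, 2, 2, 4 layers
    print(new_uniform(10, 4))  # [0, 2, 4, 7, 10] -> stages of 2, 2, 3, 3 layers

With 10 layers over 4 stages the old boundaries leave 4 layers on the last stage, while the new ones keep every stage within one layer of the rest, so the pipeline is better balanced.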
@@ -192,7 +204,8 @@ class PipelineLayerChunk(Layer):
             # behavior under recompute circumstance.
             raise PermissionError(
                 "The forward function of PipelineLayerChunk cannot be called directly. "
-                "Please call forward function of PipelineLayer.")
+                "Please call forward function of PipelineLayer."
+            )


 class PipelineLayer(Layer):
@@ -274,7 +287,8 @@ class PipelineLayer(Layer):
     """

-    def __init__(self,
-                 layers,
-                 num_stages=None,
-                 topology=None,
+    def __init__(
+        self,
+        layers,
+        num_stages=None,
+        topology=None,
@@ -282,24 +296,32 @@ class PipelineLayer(Layer):
-                 seg_method="uniform",
-                 recompute_interval=0,
-                 recompute_ctx=None,
-                 num_virtual_pipeline_stages=None):
+        seg_method="uniform",
+        recompute_interval=0,
+        recompute_ctx=None,
+        num_virtual_pipeline_stages=None,
+    ):
         super(PipelineLayer, self).__init__()
         if num_stages is None and topology is None:
             raise ValueError("should provide num_stages or topology")

         if num_virtual_pipeline_stages:
-            assert isinstance(num_virtual_pipeline_stages, int), \
-                "virtual_pipeline_stage should be None or an int"
+            assert isinstance(
+                num_virtual_pipeline_stages, int
+            ), "virtual_pipeline_stage should be None or an int"
             if num_virtual_pipeline_stages > 1:
                 logger.info(
                     "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
                 )
-                assert isinstance(seg_method, str), \
-                    "seg_method should be a str for interleave scheduler"
-                assert seg_method.startswith('layer:'), \
-                    "seg_method shoud be start with layer: for interleave scheduler"
+                assert isinstance(
+                    seg_method, str
+                ), "seg_method should be a str for interleave scheduler"
+                assert seg_method.startswith(
+                    'layer:'
+                ), "seg_method shoud be start with layer: for interleave scheduler"

-        self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages
+        self._num_virtual_pipeline_stages = (
+            1
+            if num_virtual_pipeline_stages is None
+            else num_virtual_pipeline_stages
+        )

         # lazy import
         import paddle.distributed as dist
@@ -313,13 +335,17 @@ class PipelineLayer(Layer):
         self.recompute_ctx = recompute_ctx

         if recompute_interval > 0:
-            assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."
+            assert (
+                recompute_ctx is not None
+            ), "recompute_ctx must be not None for recompute."

             offload = recompute_ctx.get('offload', False)
             partition = recompute_ctx.get('partition', False)
             logger.info(
-                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
-                .format(offload, partition))
+                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format(
+                    offload, partition
+                )
+            )

         world_size = dist.get_world_size()
         self.global_rank = dist.get_rank()
@@ -328,22 +354,28 @@ class PipelineLayer(Layer):
             self._stage_id = self._topo.get_coord(self.global_rank).pipe
             self._num_stages = self._topo.get_dim_size("pipe")
             if num_stages:
-                assert self._num_stages == num_stages, "num_stages should be equal to be %d" % (
-                    self._num_stages)
+                assert (
+                    self._num_stages == num_stages
+                ), "num_stages should be equal to be %d" % (self._num_stages)
         else:
             # construct default topology
             if world_size % num_stages != 0:
                 raise ValueError(
                     "should provide correct num_stages({}) "
-                    "which can be divided by world_size({})".format(
-                        num_stages, world_size))
+                    "which can be divided by world_size({})".format(
+                        num_stages, world_size
+                    )
+                )
             dp_num = world_size // num_stages
-            self._topo = fleet.CommunicateTopology(["data", "pipe", "model"],
-                                                   [dp_num, num_stages, 1])
+            self._topo = fleet.CommunicateTopology(
+                ["data", "pipe", "model"], [dp_num, num_stages, 1]
+            )
             self._stage_id = self._topo.get_coord(self.global_rank).pipe
             self._num_stages = self._topo.get_dim_size("pipe")

-        self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages
+        self._total_stages_with_virtual_stages = (
+            self._num_stages * self._num_virtual_pipeline_stages
+        )

         # initialize segment
         self._layers_desc = list(self.layers)
@@ -381,16 +413,22 @@ class PipelineLayer(Layer):
             start_idx = virtual_pp_rank * self._num_stages
             for stage in range(self._num_stages):
                 # stage mark the real pp stage
-                if self.segment_parts[start_idx +
-                                      stage] <= layer_idx < self.segment_parts[
-                                          start_idx + stage + 1]:
+                if (
+                    self.segment_parts[start_idx + stage]
+                    <= layer_idx
+                    < self.segment_parts[start_idx + stage + 1]
+                ):
                     return stage

     def get_num_virtual_stages(self):
         return self._num_virtual_pipeline_stages

     def get_model_chunks(self):
-        return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
+        return (
+            None
+            if self._num_virtual_pipeline_stages == 1
+            else self._model_chunks
+        )

     def _construct_shared_comm(self):
         shared_comm = {}
@@ -398,17 +436,21 @@ class PipelineLayer(Layer):
             return

         layers_desc = self._layers_desc
-        shared_layer_names = set(s.layer_name for s in layers_desc
-                                 if isinstance(s, SharedLayerDesc))
+        shared_layer_names = set(
+            s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
+        )
         for key in shared_layer_names:
             shared_layers = []
             for idx, layer in enumerate(layers_desc):
-                if isinstance(layer,
-                              SharedLayerDesc) and layer.layer_name == key:
+                if (
+                    isinstance(layer, SharedLayerDesc)
+                    and layer.layer_name == key
+                ):
                     shared_layers.append(idx)

-            shared_stages = set(
-                self.get_stage_from_index(idx) for idx in shared_layers)
+            shared_stages = set(
+                self.get_stage_from_index(idx) for idx in shared_layers
+            )
             self._dp_degree = self._topo.get_dim('data')
             self._mp_degree = self._topo.get_dim('model')
             self._sharding_degree = self._topo.get_dim('sharding')
@@ -425,7 +467,9 @@ class PipelineLayer(Layer):
                                     pipe=s,
                                     data=dp,
                                     sharding=sharding,
-                                    model=mp))
+                                    model=mp,
+                                )
+                            )

             group = paddle.distributed.new_group(ranks=shared_ranks)
             if self.global_rank in shared_ranks:
@@ -434,8 +478,9 @@ class PipelineLayer(Layer):
                 shared_comm[key] = {
                     'ranks': shared_ranks,
                     'group': group,
-                    'weight_attr':
-                    self.shared_weight_attrs[key],
+                    'weight_attr': self.shared_weight_attrs[
+                        key
+                    ],
                     'layer': self.shared_layers[key],
                 }
         return shared_comm
@@ -443,10 +488,11 @@ class PipelineLayer(Layer):
     def _synchronize_shared_weights(self):
         for key, comm in self.shared_comm.items():
             with paddle.framework.no_grad():
-                paddle.distributed.broadcast(getattr(comm['layer'],
-                                                     comm['weight_attr']),
-                                             src=min(comm['ranks']),
-                                             group=comm['group'])
+                paddle.distributed.broadcast(
+                    getattr(comm['layer'], comm['weight_attr']),
+                    src=min(comm['ranks']),
+                    group=comm['group'],
+                )

             for param in comm['layer'].parameters():
                 if self.global_rank != min(comm['ranks']):
@@ -458,8 +504,9 @@ class PipelineLayer(Layer):
                     # need use trace_op to allreduce weight
                     if in_dygraph_mode():
                         with paddle.framework.no_grad():
-                            paddle.distributed.all_reduce(param.grad,
-                                                          group=comm['group'])
+                            paddle.distributed.all_reduce(
+                                param.grad, group=comm['group']
+                            )
                     else:
                         with paddle.framework.no_grad():
                             paddle.fluid.framework._dygraph_tracer().trace_op(
@@ -468,8 +515,9 @@ class PipelineLayer(Layer):
                                 outputs={'Out': param._grad_ivar()},
                                 attrs={
                                     'ring_id': comm['group'].id,
-                                    'use_calc_stream': True
-                                })
+                                    'use_calc_stream': True,
+                                },
+                            )

     def _segment_network_for_interleave(self, seg_method):
         logger.info("start segment network for interleave scheduler")
@@ -477,14 +525,20 @@ class PipelineLayer(Layer):
             self._layers_desc,
             num_parts=self._num_stages,
             method=seg_method,
-            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages)
+            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages,
+        )
         self.segment_parts = seg.do_segment()

-        logger.info("segment result:" +
-                    ", ".join(str(arg) for arg in self.segment_parts))
+        logger.info(
+            "segment result:"
+            + ", ".join(str(arg) for arg in self.segment_parts)
+        )

-        for i in range(self._stage_id, self._total_stages_with_virtual_stages,
-                       self._num_stages):
+        for i in range(
+            self._stage_id,
+            self._total_stages_with_virtual_stages,
+            self._num_stages,
+        ):
             # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
             # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
             # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
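The comments above describe the interleaved assignment. A small sketch that reproduces the claimed mapping for 2 real stages, 2 virtual stages, and 8 layers by mirroring the boundary test in get_stage_from_index (stage_of is an illustrative helper name):

    # Boundaries from SegmentLayers for 8 layers split into
    # num_stages * num_virtual_pipeline_stages = 4 parts.
    segment_parts = [0, 2, 4, 6, 8]
    num_stages = 2
    num_virtual_pipeline_stages = 2

    def stage_of(layer_idx):
        # Virtual chunk v of real stage s owns the layers in
        # [segment_parts[v * num_stages + s], segment_parts[v * num_stages + s + 1]).
        for virtual_pp_rank in range(num_virtual_pipeline_stages):
            start_idx = virtual_pp_rank * num_stages
            for stage in range(num_stages):
                if (
                    segment_parts[start_idx + stage]
                    <= layer_idx
                    < segment_parts[start_idx + stage + 1]
                ):
                    return stage

    print([stage_of(i) for i in range(8)])  # [0, 0, 1, 1, 0, 0, 1, 1]

Layers 0-1 and 4-5 map to stage 0, layers 2-3 and 6-7 to stage 1, matching the comment.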
@@ -500,13 +554,15 @@ class PipelineLayer(Layer):
     def _segment_network(self, seg_method):
         logger.info("start segment network..")
-        seg = SegmentLayers(self._layers_desc,
-                            num_parts=self._num_stages,
-                            method=seg_method)
+        seg = SegmentLayers(
+            self._layers_desc, num_parts=self._num_stages, method=seg_method
+        )
         self.segment_parts = seg.do_segment()

-        logger.info("segment result:" +
-                    ", ".join(str(arg) for arg in self.segment_parts))
+        logger.info(
+            "segment result:"
+            + ", ".join(str(arg) for arg in self.segment_parts)
+        )

         self._start_pos = self.segment_parts[self._stage_id]
         self._end_pos = self.segment_parts[self._stage_id + 1]
@@ -514,22 +570,30 @@ class PipelineLayer(Layer):
     def _print_segmentation_for_debug(self):
         # print information for debug
-        for stage in range(self._num_stages *
-                           self._num_virtual_pipeline_stages):
+        for stage in range(
+            self._num_stages * self._num_virtual_pipeline_stages
+        ):
             start = self.segment_parts[stage]
             end = self.segment_parts[stage + 1]
-            logger.info("stage={}, global_rank={} ,layer_number={}".format(
-                stage, self.global_rank, end - start))
+            logger.info(
+                "stage={}, global_rank={} ,layer_number={}".format(
+                    stage, self.global_rank, end - start
+                )
+            )

             for index, layer in enumerate(self._layers_desc[start:end]):
                 logger.info("{}: {}".format(index + start, str(layer)))

         if self._num_virtual_pipeline_stages > 1:
             for stage in range(self._num_stages):
-                stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
-                    stage)
-                for i in range(stage, self._total_stages_with_virtual_stages,
-                               self._num_stages):
+                stage_to_virtual_stage_info = (
+                    "stage {} contains virtual stages: ".format(stage)
+                )
+                for i in range(
+                    stage,
+                    self._total_stages_with_virtual_stages,
+                    self._num_stages,
+                ):
                     stage_to_virtual_stage_info += " {},".format(i)
                 logger.info(stage_to_virtual_stage_info)
@@ -575,9 +639,11 @@ class PipelineLayer(Layer):
                 if layer.layer_name not in self.shared_layers:
                     self.shared_layers[layer.layer_name] = layer.build_layer()
-                    self.shared_weight_attrs[
-                        layer.layer_name] = layer.shared_weight_attr
-                    for param in self.shared_layers[
-                            layer.layer_name].parameters():
+                    self.shared_weight_attrs[
+                        layer.layer_name
+                    ] = layer.shared_weight_attr
+                    for param in self.shared_layers[
+                        layer.layer_name
+                    ].parameters():
                         setattr(param, "is_firstly_shared", True)

                 if layer.forward_func is None:
@@ -585,8 +651,11 @@ class PipelineLayer(Layer):
                 else:
                     run_function.append(
-                        partial(layer.forward_func,
-                                self.shared_layers[layer.layer_name]))
+                        partial(
+                            layer.forward_func,
+                            self.shared_layers[layer.layer_name],
+                        )
+                    )

             elif isinstance(layer, LayerDesc):
                 model = layer.build_layer()
@@ -615,11 +684,15 @@ class PipelineLayer(Layer):
     def forward(self, input, chunk_id=None):
         if chunk_id is not None:
             assert isinstance(chunk_id, int), "chunk_id should be an int"
-            assert self._num_virtual_pipeline_stages > 1, \
-                "chunk_id is only valid when using virtual pipeline stage"
-            assert chunk_id < len(self._model_chunks), \
-                "The virtual pipeline only has {} chunks, " \
-                "but received chunk_id {}.".format(
-                    len(self._model_chunks), chunk_id)
+            assert (
+                self._num_virtual_pipeline_stages > 1
+            ), "chunk_id is only valid when using virtual pipeline stage"
+            assert chunk_id < len(self._model_chunks), (
+                "The virtual pipeline only has {} chunks, "
+                "but received chunk_id {}.".format(
+                    len(self._model_chunks), chunk_id
+                )
+            )

             # Get the target model chunk.
             model_chunk = self._model_chunks[chunk_id]
             # Update the self.run_function to the target run functions.
@@ -637,20 +710,25 @@ class PipelineLayer(Layer):
         funcs = self.run_function[start_idx:end_idx]

         if not isinstance(input, tuple):
-            input = (input, )
+            input = (input,)

         if self._need_recompute(funcs, input):
-            input = recompute_hybrid(self.recompute_ctx,
-                                     self.forward_function(start_idx, end_idx),
-                                     *input)
+            input = recompute_hybrid(
+                self.recompute_ctx,
+                self.forward_function(start_idx, end_idx),
+                *input
+            )
         else:
             input = self.forward_function(start_idx, end_idx)(*input)

         return input

     def _need_recompute(self, funcs, inputs):
-        if not any(input_.stop_gradient == False
-                   for input_ in inputs
-                   if isinstance(input_, paddle.Tensor)):
+        if not any(
+            input_.stop_gradient == False
+            for input_ in inputs
+            if isinstance(input_, paddle.Tensor)
+        ):
             return False

         params = [f.parameters() for f in funcs if isinstance(f, Layer)]
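The rewrapped condition in _need_recompute is the gate for activation recomputation: if no tensor input requires a gradient, the segment runs without recompute. A minimal sketch of that predicate, with a hypothetical FakeTensor standing in for paddle.Tensor:

    class FakeTensor:
        # Stand-in for paddle.Tensor; only the stop_gradient flag matters here.
        def __init__(self, stop_gradient):
            self.stop_gradient = stop_gradient

    def needs_recompute(inputs):
        # Mirrors the any(...) test above; the real method additionally
        # requires that the segment's layers own parameters.
        return any(
            input_.stop_gradient == False
            for input_ in inputs
            if isinstance(input_, FakeTensor)
        )

    print(needs_recompute((FakeTensor(True), "label")))   # False
    print(needs_recompute((FakeTensor(False), "label")))  # True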
@@ -674,11 +752,18 @@ class PipelineLayer(Layer):
             if self._num_virtual_pipeline_stages > 1:
                 # add virtual pipeline info to the save path
                 assert local_chunk_id is not None
-                virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(
-                    local_chunk_id)
-            layer_save_path = os.path.join(ckpt_dir,
-                                           'layer_{:0>2d}'.format(idx))
-            layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams'
+                virtual_pipeline_stage_message = (
+                    "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
+                )
+            layer_save_path = os.path.join(
+                ckpt_dir, 'layer_{:0>2d}'.format(idx)
+            )
+            layer_save_path = (
+                layer_save_path
+                + virtual_pipeline_stage_message
+                + rank_message
+                + '-model_states.pdparams'
+            )
             return layer_save_path

         def _save_model(run_functions, local_chunk_id=None):
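The hunk above only rewraps how the checkpoint file name is assembled, but the scheme is easier to see end to end. A sketch with made-up values (rank_message is produced elsewhere in the method; the suffix below is assumed purely for illustration):

    import os

    ckpt_dir, idx, local_chunk_id = "./ckpt", 3, 1
    rank_message = "-tensor_00"  # assumed stand-in for the real rank suffix

    layer_save_path = os.path.join(ckpt_dir, 'layer_{:0>2d}'.format(idx))
    layer_save_path = (
        layer_save_path
        + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
        + rank_message
        + '-model_states.pdparams'
    )
    print(layer_save_path)
    # ./ckpt/layer_03-virtual_pp_stage_01-tensor_00-model_states.pdparams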
@@ -701,7 +786,8 @@ class PipelineLayer(Layer):
     def set_state_dir(self, path):
         assert os.path.exists(
-            path), "{} not found, please check the path".format(path)
+            path
+        ), "{} not found, please check the path".format(path)

         def _load_model(run_functions, local_chunk_id=None):
             for idx, layer in enumerate(run_functions):

@@ -715,21 +801,26 @@ class PipelineLayer(Layer):
                     pos_offset = self._start_poss[local_chunk_id]
                 layer_idx = idx + pos_offset
-                layer_save_path = os.path.join(
-                    path, 'layer_{0:0>2d}'.format(layer_idx))
+                layer_save_path = os.path.join(
+                    path, 'layer_{0:0>2d}'.format(layer_idx)
+                )
                 if self._num_virtual_pipeline_stages > 1:
                     # add virtual pipeline info to the path
                     assert local_chunk_id is not None
-                    layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format(
-                        local_chunk_id)
-                model_files = glob.glob(layer_save_path +
-                                        "*model_states.pdparams")
+                    layer_save_path = (
+                        layer_save_path
+                        + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
+                    )
+                model_files = glob.glob(
+                    layer_save_path + "*model_states.pdparams"
+                )
                 model_files.sort()
                 mp_rank = self._topo.get_coord(self.global_rank).model
                 mp_world_size = self._topo.get_dim('model')
                 num_files = len(model_files)
-                load_param_path = model_files[mp_rank * num_files //
-                                              mp_world_size]
+                load_param_path = model_files[
+                    mp_rank * num_files // mp_world_size
+                ]
                 model_state_dict = paddle.load(load_param_path)
                 layer.set_state_dict(model_state_dict)