PaddlePaddle / Paddle · commit 3a014783
Unverified commit 3a014783, authored Nov 07, 2022 by Yuang Liu, committed by GitHub on Nov 07, 2022.
update the split logic for uniform (#47670) (#47705)
* code format change
* update the split logic for uniform (#47670)
Parent: d5809836
1 changed file with 221 additions and 130 deletions (+221 −130).
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @ 3a014783
@@ -57,7 +57,6 @@ __all__ = []
 class LayerDesc(object):
-
     def __init__(self, layer_func, *inputs, **kwargs):
         self.layer_func = layer_func
         self.inputs = inputs
@@ -65,25 +64,28 @@ class LayerDesc(object):
         if not issubclass(layer_func, Layer):
             raise TypeError(
-                "The input(layer_func) should be a derived class of Layer.")
+                "The input(layer_func) should be a derived class of Layer."
+            )

     def build_layer(self):
         return self.layer_func(*self.inputs, **self.kwargs)

     def __repr__(self):
-        return layer_to_str(self.layer_func.__name__, *self.inputs,
-                            **self.kwargs)
+        return layer_to_str(
+            self.layer_func.__name__, *self.inputs, **self.kwargs
+        )


 class SharedLayerDesc(LayerDesc):
-    def __init__(self,
+    def __init__(
+        self,
         key,
         layer_func,
         forward_func=None,
         shared_weight_attr='weight',
         *inputs,
-        **kwargs):
+        **kwargs
+    ):
         super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
         self.layer_name = key
         self.forward_func = forward_func
@@ -91,12 +93,13 @@ class SharedLayerDesc(LayerDesc):
 class SegmentLayers(object):
-    def __init__(self,
+    def __init__(
+        self,
         layers_desc,
         num_parts,
         method="uniform",
-        num_virtual_pipeline_stage=None):
+        num_virtual_pipeline_stage=None,
+    ):
         self._layers_desc = layers_desc
         self.method = method
         self.num_parts = num_parts
@@ -104,7 +107,9 @@ class SegmentLayers(object):
         self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
         if self.num_virtual_pipeline_stage is not None:
             self.total_parts = num_parts * self.num_virtual_pipeline_stage
-        assert self.num_items >= self.num_parts, "layer number should be greater than number of segments"
+        assert (
+            self.num_items >= self.num_parts
+        ), "layer number should be greater than number of segments"

     def do_segment(self):
         if self.method == "uniform":
@@ -118,12 +123,17 @@ class SegmentLayers(object):
             for idx in weight_idxs:
                 weights[idx] = 1
-            actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts
+            actual_num_parts = (
+                self.num_parts
+                if self.num_virtual_pipeline_stage is None
+                else self.total_parts
+            )
-            assert sum(weights) % actual_num_parts == 0, "number of layers ({}) should be divided by part number({})".format(
-                sum(weights), actual_num_parts)
+            assert (
+                sum(weights) % actual_num_parts == 0
+            ), "number of layers ({}) should be divided by part number({})".format(
+                sum(weights), actual_num_parts
+            )
             part_size = sum(weights) // actual_num_parts
             result = [0 for _ in range(actual_num_parts + 1)]
@@ -156,21 +166,23 @@ class SegmentLayers(object):
                 if regex.search(name):
                     weight_idxs.append(idx)
-            assert len(weight_idxs) > 0, "weight_idxs' length should be greater than 0"
+            assert (
+                len(weight_idxs) > 0
+            ), "weight_idxs' length should be greater than 0"
             return weight_idxs

     def uniform(self, num_items, num_parts):
         result = [0 for _ in range(num_parts + 1)]
         part_size = math.floor(num_items / num_parts)
-        for i in range(num_parts):
-            result[i] = int(min(part_size * i, num_items))
+        extra_layers = num_items % num_parts
+        for i in range(1, num_parts):
+            offset = 1 if i > (num_parts - extra_layers) else 0
+            result[i] = int(min(result[i - 1] + part_size + offset, num_items))
         result[num_parts] = num_items
         return result


 class PipelineLayerChunk(Layer):
     def __init__(self):
         super(PipelineLayerChunk, self).__init__()
         self.run_function = []
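Note on the hunk above: this is the substantive change of the commit. The old loop placed boundary i at part_size * i, so all num_items % num_parts leftover layers landed in the final part; the new loop grows each boundary from the previous one and hands one extra layer to each of the trailing parts. A minimal self-contained sketch (both loop bodies copied from the diff; the helper names uniform_old/uniform_new are ours, for comparison only):

    import math

    def uniform_old(num_items, num_parts):
        # Pre-#47670 logic: boundaries at fixed multiples of part_size,
        # so the whole remainder piles up in the last part.
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
        for i in range(num_parts):
            result[i] = int(min(part_size * i, num_items))
        result[num_parts] = num_items
        return result

    def uniform_new(num_items, num_parts):
        # Post-#47670 logic: the num_items % num_parts extra layers are
        # spread one-per-part across the trailing parts.
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
        extra_layers = num_items % num_parts
        for i in range(1, num_parts):
            offset = 1 if i > (num_parts - extra_layers) else 0
            result[i] = int(min(result[i - 1] + part_size + offset, num_items))
        result[num_parts] = num_items
        return result

    print(uniform_old(10, 4))  # [0, 2, 4, 6, 10] -> part sizes 2, 2, 2, 4
    print(uniform_new(10, 4))  # [0, 2, 4, 7, 10] -> part sizes 2, 2, 3, 3

With 10 layers on 4 stages the old split makes the last stage twice as heavy as the others; the new split is never more than one layer away from even, which is what "update the split logic for uniform" buys.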
@@ -192,7 +204,8 @@ class PipelineLayerChunk(Layer):
         # behavior under recompute circumstance.
         raise PermissionError(
             "The forward function of PipelineLayerChunk cannot be called directly. "
-            "Please call forward function of PipelineLayer.")
+            "Please call forward function of PipelineLayer."
+        )


 class PipelineLayer(Layer):
@@ -274,7 +287,8 @@ class PipelineLayer(Layer):
"""
def
__init__
(
self
,
def
__init__
(
self
,
layers
,
num_stages
=
None
,
topology
=
None
,
...
...
@@ -282,24 +296,32 @@ class PipelineLayer(Layer):
         seg_method="uniform",
         recompute_interval=0,
         recompute_ctx=None,
-        num_virtual_pipeline_stages=None):
+        num_virtual_pipeline_stages=None,
+    ):
         super(PipelineLayer, self).__init__()
         if num_stages is None and topology is None:
             raise ValueError("should provide num_stages or topology")

         if num_virtual_pipeline_stages:
-            assert isinstance(num_virtual_pipeline_stages, int), \
-                "virtual_pipeline_stage should be None or an int"
+            assert isinstance(
+                num_virtual_pipeline_stages, int
+            ), "virtual_pipeline_stage should be None or an int"
             if num_virtual_pipeline_stages > 1:
                 logger.info(
                     "set num_virtual_pipeline_stages > 1 means using interleave scheduler instead of 1f1b scheduler"
                 )
-                assert isinstance(seg_method, str), \
-                    "seg_method should be a str for interleave scheduler"
-                assert seg_method.startswith('layer:'), \
-                    "seg_method shoud be start with layer: for interleave scheduler"
-        self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages
+                assert isinstance(
+                    seg_method, str
+                ), "seg_method should be a str for interleave scheduler"
+                assert seg_method.startswith(
+                    'layer:'
+                ), "seg_method shoud be start with layer: for interleave scheduler"
+        self._num_virtual_pipeline_stages = (
+            1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages
+        )

         # lazy import
         import paddle.distributed as dist
@@ -313,13 +335,17 @@ class PipelineLayer(Layer):
         self.recompute_ctx = recompute_ctx

         if recompute_interval > 0:
-            assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."
+            assert (
+                recompute_ctx is not None
+            ), "recompute_ctx must be not None for recompute."

             offload = recompute_ctx.get('offload', False)
             partition = recompute_ctx.get('partition', False)
             logger.info(
-                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
-                .format(offload, partition))
+                "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".format(
+                    offload, partition
+                )
+            )

         world_size = dist.get_world_size()
         self.global_rank = dist.get_rank()
@@ -328,22 +354,28 @@ class PipelineLayer(Layer):
             self._stage_id = self._topo.get_coord(self.global_rank).pipe
             self._num_stages = self._topo.get_dim_size("pipe")
             if num_stages:
-                assert self._num_stages == num_stages, "num_stages should be equal to be %d" % (self._num_stages)
+                assert (
+                    self._num_stages == num_stages
+                ), "num_stages should be equal to be %d" % (self._num_stages)
         else:
             # construct default topology
             if world_size % num_stages != 0:
                 raise ValueError(
                     "should provide correct num_stages({}) "
                     "which can be divided by world_size({})".format(
-                        num_stages, world_size))
+                        num_stages, world_size
+                    )
+                )
             dp_num = world_size // num_stages
-            self._topo = fleet.CommunicateTopology(["data", "pipe", "model"],
-                                                   [dp_num, num_stages, 1])
+            self._topo = fleet.CommunicateTopology(
+                ["data", "pipe", "model"], [dp_num, num_stages, 1]
+            )
             self._stage_id = self._topo.get_coord(self.global_rank).pipe
             self._num_stages = self._topo.get_dim_size("pipe")

-        self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages
+        self._total_stages_with_virtual_stages = (
+            self._num_stages * self._num_virtual_pipeline_stages
+        )

         # initialize segment
         self._layers_desc = list(self.layers)
@@ -381,16 +413,22 @@ class PipelineLayer(Layer):
             start_idx = virtual_pp_rank * self._num_stages
             for stage in range(self._num_stages):
                 # stage mark the real pp stage
-                if self.segment_parts[start_idx + stage] <= layer_idx < self.segment_parts[start_idx + stage + 1]:
+                if (
+                    self.segment_parts[start_idx + stage]
+                    <= layer_idx
+                    < self.segment_parts[start_idx + stage + 1]
+                ):
                     return stage

     def get_num_virtual_stages(self):
         return self._num_virtual_pipeline_stages

     def get_model_chunks(self):
-        return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks
+        return (
+            None
+            if self._num_virtual_pipeline_stages == 1
+            else self._model_chunks
+        )

     def _construct_shared_comm(self):
         shared_comm = {}
@@ -398,17 +436,21 @@ class PipelineLayer(Layer):
             return

         layers_desc = self._layers_desc
-        shared_layer_names = set(s.layer_name for s in layers_desc
-                                 if isinstance(s, SharedLayerDesc))
+        shared_layer_names = set(
+            s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
+        )
         for key in shared_layer_names:
             shared_layers = []
             for idx, layer in enumerate(layers_desc):
-                if isinstance(layer, SharedLayerDesc) and layer.layer_name == key:
+                if (
+                    isinstance(layer, SharedLayerDesc)
+                    and layer.layer_name == key
+                ):
                     shared_layers.append(idx)

             shared_stages = set(
-                self.get_stage_from_index(idx) for idx in shared_layers)
+                self.get_stage_from_index(idx) for idx in shared_layers
+            )
             self._dp_degree = self._topo.get_dim('data')
             self._mp_degree = self._topo.get_dim('model')
             self._sharding_degree = self._topo.get_dim('sharding')
@@ -425,7 +467,9 @@ class PipelineLayer(Layer):
                         pipe=s,
                         data=dp,
                         sharding=sharding,
-                        model=mp))
+                        model=mp,
+                    )
+                )
             group = paddle.distributed.new_group(ranks=shared_ranks)
             if self.global_rank in shared_ranks:
@@ -434,8 +478,9 @@ class PipelineLayer(Layer):
                 shared_comm[key] = {
                     'ranks': shared_ranks,
                     'group': group,
-                    'weight_attr':
-                    self.shared_weight_attrs[key],
+                    'weight_attr': self.shared_weight_attrs[key],
                     'layer': self.shared_layers[key],
                 }
         return shared_comm
@@ -443,10 +488,11 @@ class PipelineLayer(Layer):
     def _synchronize_shared_weights(self):
         for key, comm in self.shared_comm.items():
             with paddle.framework.no_grad():
-                paddle.distributed.broadcast(getattr(comm['layer'],
-                                                     comm['weight_attr']),
+                paddle.distributed.broadcast(
+                    getattr(comm['layer'], comm['weight_attr']),
                     src=min(comm['ranks']),
-                    group=comm['group'])
+                    group=comm['group'],
+                )

             for param in comm['layer'].parameters():
                 if self.global_rank != min(comm['ranks']):
@@ -458,8 +504,9 @@ class PipelineLayer(Layer):
                     # need use trace_op to allreduce weight
                     if in_dygraph_mode():
                         with paddle.framework.no_grad():
-                            paddle.distributed.all_reduce(param.grad,
-                                                          group=comm['group'])
+                            paddle.distributed.all_reduce(
+                                param.grad, group=comm['group']
+                            )
                     else:
                         with paddle.framework.no_grad():
                             paddle.fluid.framework._dygraph_tracer().trace_op(
@@ -468,8 +515,9 @@ class PipelineLayer(Layer):
                                 outputs={'Out': param._grad_ivar()},
                                 attrs={
                                     'ring_id': comm['group'].id,
-                                    'use_calc_stream': True
-                                })
+                                    'use_calc_stream': True,
+                                },
+                            )

     def _segment_network_for_interleave(self, seg_method):
         logger.info("start segment network for interleave scheduler")
@@ -477,14 +525,20 @@ class PipelineLayer(Layer):
             self._layers_desc,
             num_parts=self._num_stages,
             method=seg_method,
-            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages)
+            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages,
+        )
         self.segment_parts = seg.do_segment()
-        logger.info("segment result:" +
-                    ", ".join(str(arg) for arg in self.segment_parts))
+        logger.info(
+            "segment result:"
+            + ", ".join(str(arg) for arg in self.segment_parts)
+        )

-        for i in range(self._stage_id, self._total_stages_with_virtual_stages,
-                       self._num_stages):
+        for i in range(
+            self._stage_id,
+            self._total_stages_with_virtual_stages,
+            self._num_stages,
+        ):
             # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
             # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
             # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
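The three comments above pin down the interleave mapping. A small hedged sketch (plain Python; the boundary list is the [0, 2, 4, 6, 8] that do_segment() yields for 8 layers over 2 stages × 2 virtual stages with the uniform method) reproduces the assignment:

    num_stages, num_virtual_stages = 2, 2
    segment_parts = [0, 2, 4, 6, 8]  # boundaries over the 4 virtual parts
    for stage_id in range(num_stages):
        # Each real stage owns every num_stages-th virtual part, mirroring
        # the range(stage_id, total_stages, num_stages) loop above.
        chunks = [
            list(range(segment_parts[i], segment_parts[i + 1]))
            for i in range(stage_id, num_stages * num_virtual_stages, num_stages)
        ]
        print("real stage", stage_id, "->", chunks)
    # real stage 0 -> [[0, 1], [4, 5]]
    # real stage 1 -> [[2, 3], [6, 7]]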
@@ -500,13 +554,15 @@ class PipelineLayer(Layer):
     def _segment_network(self, seg_method):
         logger.info("start segment network..")
-        seg = SegmentLayers(self._layers_desc,
-                            num_parts=self._num_stages,
-                            method=seg_method)
+        seg = SegmentLayers(
+            self._layers_desc, num_parts=self._num_stages, method=seg_method
+        )
         self.segment_parts = seg.do_segment()
-        logger.info("segment result:" +
-                    ", ".join(str(arg) for arg in self.segment_parts))
+        logger.info(
+            "segment result:"
+            + ", ".join(str(arg) for arg in self.segment_parts)
+        )
         self._start_pos = self.segment_parts[self._stage_id]
         self._end_pos = self.segment_parts[self._stage_id + 1]
@@ -514,22 +570,30 @@ class PipelineLayer(Layer):
     def _print_segmentation_for_debug(self):
         # print information for debug
-        for stage in range(self._num_stages *
-                           self._num_virtual_pipeline_stages):
+        for stage in range(
+            self._num_stages * self._num_virtual_pipeline_stages
+        ):
             start = self.segment_parts[stage]
             end = self.segment_parts[stage + 1]
-            logger.info("stage={}, global_rank={} ,layer_number={}".format(
-                stage, self.global_rank, end - start))
+            logger.info(
+                "stage={}, global_rank={} ,layer_number={}".format(
+                    stage, self.global_rank, end - start
+                )
+            )
             for index, layer in enumerate(self._layers_desc[start:end]):
                 logger.info("{}: {}".format(index + start, str(layer)))

         if self._num_virtual_pipeline_stages > 1:
             for stage in range(self._num_stages):
-                stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
-                    stage)
-                for i in range(stage, self._total_stages_with_virtual_stages,
-                               self._num_stages):
+                stage_to_virtual_stage_info = (
+                    "stage {} contains virtual stages: ".format(stage)
+                )
+                for i in range(
+                    stage,
+                    self._total_stages_with_virtual_stages,
+                    self._num_stages,
+                ):
                     stage_to_virtual_stage_info += " {},".format(i)
                 logger.info(stage_to_virtual_stage_info)
@@ -575,9 +639,11 @@ class PipelineLayer(Layer):
                 if layer.layer_name not in self.shared_layers:
                     self.shared_layers[layer.layer_name] = layer.build_layer()
                     self.shared_weight_attrs[
-                        layer.layer_name] = layer.shared_weight_attr
+                        layer.layer_name
+                    ] = layer.shared_weight_attr
                     for param in self.shared_layers[
-                            layer.layer_name].parameters():
+                        layer.layer_name
+                    ].parameters():
                         setattr(param, "is_firstly_shared", True)

                 if layer.forward_func is None:
@@ -585,8 +651,11 @@ class PipelineLayer(Layer):
                 else:
                     run_function.append(
-                        partial(layer.forward_func,
-                                self.shared_layers[layer.layer_name]))
+                        partial(
+                            layer.forward_func,
+                            self.shared_layers[layer.layer_name],
+                        )
+                    )

             elif isinstance(layer, LayerDesc):
                 model = layer.build_layer()
@@ -615,11 +684,15 @@ class PipelineLayer(Layer):
     def forward(self, input, chunk_id=None):
         if chunk_id is not None:
             assert isinstance(chunk_id, int), "chunk_id should be an int"
-            assert self._num_virtual_pipeline_stages > 1, \
-                "chunk_id is only valid when using virtual pipeline stage"
-            assert chunk_id < len(self._model_chunks), \
-                "The virtual pipeline only has {} chunks, " \
-                "but received chunk_id {}.".format(len(self._model_chunks),
-                                                   chunk_id)
+            assert (
+                self._num_virtual_pipeline_stages > 1
+            ), "chunk_id is only valid when using virtual pipeline stage"
+            assert chunk_id < len(self._model_chunks), (
+                "The virtual pipeline only has {} chunks, "
+                "but received chunk_id {}.".format(
+                    len(self._model_chunks), chunk_id
+                )
+            )

             # Get the target model chunk.
             model_chunk = self._model_chunks[chunk_id]
             # Update the self.run_function to the target run functions.
@@ -637,20 +710,25 @@ class PipelineLayer(Layer):
         funcs = self.run_function[start_idx:end_idx]

         if not isinstance(input, tuple):
-            input = (input, )
+            input = (input,)

         if self._need_recompute(funcs, input):
-            input = recompute_hybrid(self.recompute_ctx,
-                                     self.forward_function(start_idx, end_idx),
-                                     *input)
+            input = recompute_hybrid(
+                self.recompute_ctx,
+                self.forward_function(start_idx, end_idx),
+                *input
+            )
         else:
             input = self.forward_function(start_idx, end_idx)(*input)

         return input

     def _need_recompute(self, funcs, inputs):
-        if not any(input_.stop_gradient == False for input_ in inputs
-                   if isinstance(input_, paddle.Tensor)):
+        if not any(
+            input_.stop_gradient == False
+            for input_ in inputs
+            if isinstance(input_, paddle.Tensor)
+        ):
             return False

         params = [f.parameters() for f in funcs if isinstance(f, Layer)]
@@ -674,11 +752,18 @@ class PipelineLayer(Layer):
             if self._num_virtual_pipeline_stages > 1:
                 # add virtual pipeline info to the save path
                 assert local_chunk_id is not None
-                virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(
-                    local_chunk_id)
-            layer_save_path = os.path.join(ckpt_dir,
-                                           'layer_{:0>2d}'.format(idx))
-            layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams'
+                virtual_pipeline_stage_message = (
+                    "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
+                )
+            layer_save_path = os.path.join(
+                ckpt_dir, 'layer_{:0>2d}'.format(idx)
+            )
+            layer_save_path = (
+                layer_save_path
+                + virtual_pipeline_stage_message
+                + rank_message
+                + '-model_states.pdparams'
+            )
             return layer_save_path

         def _save_model(run_functions, local_chunk_id=None):
@@ -701,7 +786,8 @@ class PipelineLayer(Layer):
     def set_state_dir(self, path):
         assert os.path.exists(
-            path), "{} not found, please check the path".format(path)
+            path
+        ), "{} not found, please check the path".format(path)

         def _load_model(run_functions, local_chunk_id=None):
             for idx, layer in enumerate(run_functions):
@@ -715,21 +801,26 @@ class PipelineLayer(Layer):
                     pos_offset = self._start_poss[local_chunk_id]
                 layer_idx = idx + pos_offset
-                layer_save_path = os.path.join(path,
-                                               'layer_{0:0>2d}'.format(layer_idx))
+                layer_save_path = os.path.join(
+                    path, 'layer_{0:0>2d}'.format(layer_idx)
+                )
                 if self._num_virtual_pipeline_stages > 1:
                     # add virtual pipeline info to the path
                     assert local_chunk_id is not None
-                    layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format(
-                        local_chunk_id)
-                model_files = glob.glob(layer_save_path +
-                                        "*model_states.pdparams")
+                    layer_save_path = (
+                        layer_save_path
+                        + "-virtual_pp_stage_{:0>2d}".format(local_chunk_id)
+                    )
+                model_files = glob.glob(
+                    layer_save_path + "*model_states.pdparams"
+                )
                 model_files.sort()
                 mp_rank = self._topo.get_coord(self.global_rank).model
                 mp_world_size = self._topo.get_dim('model')
                 num_files = len(model_files)
-                load_param_path = model_files[mp_rank * num_files //
-                                              mp_world_size]
+                load_param_path = model_files[
+                    mp_rank * num_files // mp_world_size
+                ]
                 model_state_dict = paddle.load(load_param_path)
                 layer.set_state_dict(model_state_dict)
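One non-cosmetic detail in this load path: model_files[mp_rank * num_files // mp_world_size] picks evenly spaced checkpoint shards per tensor-parallel rank. A hedged illustration (the file names below are made up; only the indexing arithmetic comes from the code above):

    # Hypothetical sorted shard list, loaded by two tensor-parallel ranks.
    model_files = [
        "layer_00-tensor_00-model_states.pdparams",
        "layer_00-tensor_01-model_states.pdparams",
        "layer_00-tensor_02-model_states.pdparams",
        "layer_00-tensor_03-model_states.pdparams",
    ]
    mp_world_size = 2
    num_files = len(model_files)
    for mp_rank in range(mp_world_size):
        # rank 0 -> index 0, rank 1 -> index 2: stride of num_files // mp_world_size
        print(mp_rank, "->", model_files[mp_rank * num_files // mp_world_size])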