Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0449b841
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0449b841
编写于
11月 03, 2022
作者:
R
Roc
提交者:
GitHub
11月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Save dygraph model for auto inference (#47463)
上级
21277904
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
911 addition
and
3 deletion
+911
-3
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+9
-0
python/paddle/fluid/dygraph/layers.py
python/paddle/fluid/dygraph/layers.py
+12
-0
python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
...dle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
+8
-0
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py
...unittests/collective/fleet/dygraph_save_for_auto_infer.py
+472
-0
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py
...ests/collective/fleet/test_dygraph_save_for_auto_infer.py
+49
-0
python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
...ddle/fluid/tests/unittests/collective/fleet/testslist.csv
+1
-0
python/paddle/incubate/distributed/utils/io/__init__.py
python/paddle/incubate/distributed/utils/io/__init__.py
+1
-1
python/paddle/incubate/distributed/utils/io/dist_save.py
python/paddle/incubate/distributed/utils/io/dist_save.py
+2
-2
python/paddle/incubate/distributed/utils/io/save_for_auto.py
python/paddle/incubate/distributed/utils/io/save_for_auto.py
+357
-0
未找到文件。
python/paddle/distributed/fleet/layers/mpu/mp_layers.py
浏览文件 @
0449b841
...
@@ -144,6 +144,8 @@ class VocabParallelEmbedding(Layer):
...
@@ -144,6 +144,8 @@ class VocabParallelEmbedding(Layer):
)
)
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
if
self
.
weight
.
is_distributed
:
setattr
(
self
.
weight
,
"split_axis"
,
0
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
is_mp
:
if
self
.
is_mp
:
...
@@ -276,6 +278,9 @@ class ColumnParallelLinear(Layer):
...
@@ -276,6 +278,9 @@ class ColumnParallelLinear(Layer):
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
if
self
.
weight
.
is_distributed
:
setattr
(
self
.
weight
,
"split_axis"
,
1
)
if
has_bias
:
if
has_bias
:
# initialize bias to zero like Megatron
# initialize bias to zero like Megatron
self
.
bias
=
self
.
create_parameter
(
self
.
bias
=
self
.
create_parameter
(
...
@@ -285,6 +290,8 @@ class ColumnParallelLinear(Layer):
...
@@ -285,6 +290,8 @@ class ColumnParallelLinear(Layer):
is_bias
=
True
,
is_bias
=
True
,
)
)
self
.
bias
.
is_distributed
=
True
if
self
.
is_mp
else
False
self
.
bias
.
is_distributed
=
True
if
self
.
is_mp
else
False
if
self
.
bias
.
is_distributed
:
setattr
(
self
.
bias
,
"split_axis"
,
0
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -437,6 +444,8 @@ class RowParallelLinear(Layer):
...
@@ -437,6 +444,8 @@ class RowParallelLinear(Layer):
)
)
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
self
.
weight
.
is_distributed
=
True
if
self
.
is_mp
else
False
if
self
.
weight
.
is_distributed
:
setattr
(
self
.
weight
,
"split_axis"
,
0
)
if
has_bias
:
if
has_bias
:
self
.
bias
=
self
.
create_parameter
(
self
.
bias
=
self
.
create_parameter
(
...
...
python/paddle/fluid/dygraph/layers.py
浏览文件 @
0449b841
...
@@ -62,6 +62,17 @@ _first_cap_re = re.compile('(.)([A-Z][a-z]+)')
...
@@ -62,6 +62,17 @@ _first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re
=
re
.
compile
(
'([a-z])([A-Z])'
)
_all_cap_re
=
re
.
compile
(
'([a-z])([A-Z])'
)
def
_scope_dist2single
(
dist_scope
):
mapping
=
{
"row_parallel_linear"
:
"linear"
,
"column_parallel_linear"
:
"linear"
,
"vocab_parallel_embedding"
:
"embedding"
,
# "parallel_cross_entropy": "cross_entropy", while mp_layer has parallel_cross_entropy,
# but there is no parameters so the mapping of parallel_cross_entropy is not neccessary.
}
return
mapping
.
get
(
dist_scope
,
dist_scope
)
def
_convert_camel_to_snake
(
name
):
def
_convert_camel_to_snake
(
name
):
s1
=
_first_cap_re
.
sub
(
r
'\1_\2'
,
name
)
s1
=
_first_cap_re
.
sub
(
r
'\1_\2'
,
name
)
return
_all_cap_re
.
sub
(
r
'\1_\2'
,
s1
).
lower
()
return
_all_cap_re
.
sub
(
r
'\1_\2'
,
s1
).
lower
()
...
@@ -137,6 +148,7 @@ class Layer(object):
...
@@ -137,6 +148,7 @@ class Layer(object):
self
.
training
=
True
self
.
training
=
True
if
name_scope
is
None
:
if
name_scope
is
None
:
name_scope
=
_convert_camel_to_snake
(
self
.
__class__
.
__name__
)
name_scope
=
_convert_camel_to_snake
(
self
.
__class__
.
__name__
)
name_scope
=
_scope_dist2single
(
name_scope
)
self
.
_full_name
=
unique_name
.
generate
(
name_scope
)
self
.
_full_name
=
unique_name
.
generate
(
name_scope
)
self
.
_helper
=
LayerObjectHelper
(
self
.
_full_name
)
self
.
_helper
=
LayerObjectHelper
(
self
.
_full_name
)
self
.
_built
=
False
self
.
_built
=
False
...
...
python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
浏览文件 @
0449b841
...
@@ -952,3 +952,11 @@ if((WITH_GPU) AND (LINUX))
...
@@ -952,3 +952,11 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties
(
test_dygraph_dist_save_load
set_tests_properties
(
test_dygraph_dist_save_load
PROPERTIES TIMEOUT
"200"
LABELS
"RUN_TYPE=DIST"
)
PROPERTIES TIMEOUT
"200"
LABELS
"RUN_TYPE=DIST"
)
endif
()
endif
()
if
((
WITH_GPU
)
AND
(
LINUX
))
py_test_modules
(
test_dygraph_save_for_auto_infer MODULES test_dygraph_save_for_auto_infer
ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_dygraph_save_for_auto_infer
PROPERTIES TIMEOUT
"300"
LABELS
"RUN_TYPE=DIST"
)
endif
()
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py
0 → 100644
浏览文件 @
0449b841
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
import
numpy
as
np
import
tempfile
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Linear
,
Embedding
from
paddle.distributed
import
fleet
from
paddle.distributed.fleet.layers.mpu.mp_layers
import
(
RowParallelLinear
,
ColumnParallelLinear
,
VocabParallelEmbedding
,
)
from
paddle.distributed.auto_parallel
import
engine
from
paddle.distributed.sharding.group_sharded
import
group_sharded_parallel
from
paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers
import
(
PipelineLayer
,
LayerDesc
,
)
import
sys
import
subprocess
import
argparse
import
copy
from
paddle
import
distributed
as
dist
from
paddle.distributed.utils.log_utils
import
get_logger
from
paddle.fluid.dataloader.dataset
import
IterableDataset
from
paddle.incubate.distributed.utils.io
import
save_for_auto_inference
logger
=
get_logger
(
"INFO"
,
__file__
)
epoch
=
2
linear_size
=
1000
class
MLP_pipe
(
PipelineLayer
):
def
__init__
(
self
,
embedding_size
=
1000
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
,
):
desc
=
[
LayerDesc
(
VocabParallelEmbedding
,
num_embeddings
=
embedding_size
,
embedding_dim
=
linear_size
,
),
LayerDesc
(
RowParallelLinear
,
in_features
=
linear_size
,
out_features
=
linear_size
,
has_bias
=
True
,
),
LayerDesc
(
ColumnParallelLinear
,
in_features
=
linear_size
,
out_features
=
linear_size
,
gather_output
=
True
,
has_bias
=
True
,
),
LayerDesc
(
Linear
,
input_dim
=
linear_size
,
output_dim
=
10
),
]
super
(
MLP_pipe
,
self
).
__init__
(
desc
,
num_stages
=
2
,
loss_fn
=
paddle
.
nn
.
CrossEntropyLoss
(),
topology
=
fleet
.
get_hybrid_communicate_group
().
_topo
,
)
class
MLP_Hybrid
(
fluid
.
Layer
):
def
__init__
(
self
,
embedding_size
=
1000
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
,
):
super
(
MLP_Hybrid
,
self
).
__init__
()
self
.
embedding
=
VocabParallelEmbedding
(
embedding_size
,
linear_size
)
self
.
_linear1
=
RowParallelLinear
(
linear_size
,
linear_size
,
has_bias
=
True
,
input_is_parallel
=
True
)
self
.
_linear2
=
ColumnParallelLinear
(
linear_size
,
linear_size
,
gather_output
=
True
,
has_bias
=
True
)
self
.
_linear3
=
Linear
(
linear_size
,
10
)
def
forward
(
self
,
src
):
inputs
=
self
.
embedding
(
src
)
# slice for a bug in row parallel linear
mp_group
=
(
fleet
.
get_hybrid_communicate_group
().
get_model_parallel_group
()
)
step
=
inputs
.
shape
[
-
1
]
//
mp_group
.
nranks
mp_rank
=
dist
.
get_rank
(
mp_group
)
mp_rank
=
mp_rank
if
mp_rank
>=
0
else
0
inputs
=
inputs
[...,
step
*
mp_rank
:
step
*
mp_rank
+
step
]
y
=
self
.
_linear1
(
inputs
)
y
=
self
.
_linear2
(
y
)
y
=
self
.
_linear3
(
y
)
return
y
class
MLP
(
fluid
.
Layer
):
def
__init__
(
self
,
embedding_size
=
1000
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
,
):
super
(
MLP
,
self
).
__init__
()
self
.
embedding
=
Embedding
((
embedding_size
,
linear_size
))
self
.
_linear1
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear2
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear3
=
Linear
(
linear_size
,
10
)
def
forward
(
self
,
src
):
inputs
=
self
.
embedding
(
src
)
y
=
self
.
_linear1
(
inputs
)
y
=
self
.
_linear2
(
y
)
y
=
self
.
_linear3
(
y
)
return
y
def
gen_uniq_random_numbers
(
low
,
high
,
size
,
seed
):
assert
np
.
prod
(
size
)
<=
high
-
low
pool
=
list
(
range
(
low
,
high
))
data
=
np
.
zeros
(
size
).
astype
(
"int32"
).
reshape
(
-
1
)
np
.
random
.
seed
(
10245
)
for
i
in
range
(
np
.
prod
(
size
)):
pos
=
int
(
np
.
random
.
randint
(
0
,
len
(
pool
)))
data
[
i
]
=
pool
[
pos
]
pool
.
remove
(
pool
[
pos
])
np
.
random
.
seed
(
seed
)
return
data
.
reshape
(
size
)
class
RangeIterableDataset
(
IterableDataset
):
def
__init__
(
self
,
data_path
,
ebd
=
1000
,
start
=
0
,
end
=
100
,
linear_size
=
1000
,
seed
=
1024
):
self
.
start
=
start
self
.
end
=
end
self
.
img
=
gen_uniq_random_numbers
(
0
,
1000
,
(
100
,
1
),
seed
)
def
__iter__
(
self
):
for
idx
in
range
(
self
.
start
,
self
.
end
):
label
=
np
.
ones
(
1
).
astype
(
'int32'
)
yield
self
.
img
[
idx
],
label
def
optimizer_setting
(
args
,
model
):
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.0
if
args
.
strategy
==
"static"
else
0.01
,
parameters
=
model
.
parameters
(),
weight_decay
=
0.01
,
)
return
optimizer
def
train_mlp
(
args
,
model
,
loss
,
opt_state
=
None
,
save_model
=
False
):
optimizer
=
optimizer_setting
(
args
,
model
=
model
)
if
args
.
strategy
in
[
"mp"
,
"dp"
,
"pp"
]:
model
=
fleet
.
distributed_model
(
model
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
elif
args
.
strategy
==
"sharding_stage2"
:
model
,
optimizer
,
_
=
wrap_sharding_2_3
(
model
,
optimizer
,
None
,
False
,
2
)
elif
args
.
strategy
==
"sharding_stage3"
:
model
,
optimizer
,
_
=
wrap_sharding_2_3
(
model
,
optimizer
,
None
,
False
,
3
)
elif
args
.
strategy
!=
"single"
:
raise
ValueError
(
f
"not supported strategy:
{
args
.
strategy
}
"
)
dataset
=
RangeIterableDataset
(
data_path
=
os
.
path
.
join
(
args
.
output_dir
,
"data.npy"
),
seed
=
args
.
seed
)
train_loader
=
paddle
.
io
.
DataLoader
(
dataset
,
batch_size
=
100
,
drop_last
=
True
)
if
dist
.
get_world_size
()
>
1
:
pp_degree
=
(
fleet
.
get_hybrid_communicate_group
().
get_pipe_parallel_world_size
()
)
else
:
pp_degree
=
0
model
.
train
()
for
epo
in
range
(
epoch
):
for
step
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
if
pp_degree
<=
1
:
out
=
model
(
img
)
avg_loss
=
loss
(
out
,
label
)
paddle
.
device
.
cuda
.
synchronize
()
avg_loss
.
backward
()
optimizer
.
step
()
else
:
avg_loss
=
model
.
train_batch
(
data
,
optimizer
)
model
.
eval
()
print
(
"=============== predict in dygraph mode ================="
)
for
step
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
if
pp_degree
<=
1
:
out
=
model
(
img
)
out
=
out
.
numpy
()
else
:
out
=
model
.
eval_batch
(
data
)
out
=
np
.
array
(
out
)
paddle
.
device
.
cuda
.
synchronize
()
if
save_model
:
return
model
,
optimizer
,
out
return
None
def
train_mlp_static
(
args
,
model
,
loss
,
opt_state
=
None
,
save_model
=
False
):
optimizer
=
optimizer_setting
(
args
,
model
=
model
)
model
=
engine
.
Engine
(
model
,
loss
=
loss
,
optimizer
=
optimizer
,
strategy
=
None
)
dataset
=
RangeIterableDataset
(
data_path
=
os
.
path
.
join
(
args
.
output_dir
,
"data.npy"
),
seed
=
args
.
seed
)
model
.
load
(
os
.
path
.
join
(
args
.
load_dir
,
"saved"
),
load_optimizer
=
False
)
model
.
fit
(
dataset
,
epochs
=
1
)
model
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
"static_save"
))
paddle
.
device
.
cuda
.
synchronize
()
print
(
"=============== predict in static mode ================="
)
out
=
model
.
predict
(
dataset
,
verbose
=
1000
)
if
save_model
:
return
model
,
optimizer
return
out
def
step_check
(
output_dir
):
p1
=
os
.
path
.
join
(
output_dir
,
"static.npy"
)
p2
=
os
.
path
.
join
(
output_dir
,
"dygraph.npy"
)
m1
=
np
.
load
(
p1
).
reshape
(
-
1
)
m2
=
np
.
load
(
p2
).
reshape
(
-
1
)
try
:
assert
np
.
allclose
(
m1
,
m2
,
rtol
=
1e-5
,
atol
=
1e-6
)
except
:
diff
=
m1
-
m2
logger
.
error
(
f
"max diff
{
diff
.
max
()
}
, min diff:
{
diff
.
min
()
}
"
)
logger
.
error
(
f
"
{
m1
[:
10
]
}
"
)
logger
.
error
(
f
"
{
m2
[:
10
]
}
"
)
raise
ValueError
(
"diff is too large"
)
def
step_save
(
strategy
,
output_dir
,
seed
):
python_exe
=
sys
.
executable
# save data
os
.
makedirs
(
output_dir
+
"/logs"
,
exist_ok
=
True
)
filename
=
os
.
path
.
basename
(
__file__
)
if
strategy
!=
"single"
:
cmd
=
(
f
"
{
python_exe
}
-m paddle.distributed.launch --log_dir
{
output_dir
}
/logs"
f
" --gpus 0,1
{
filename
}
--cmd save --strategy
{
strategy
}
--output_dir
{
output_dir
}
--seed
{
seed
}
"
)
else
:
cmd
=
f
"
{
python_exe
}
{
filename
}
--cmd save --strategy
{
strategy
}
--output_dir
{
output_dir
}
--seed
{
seed
}
"
logger
.
info
(
f
"exe:
{
cmd
}
"
)
p
=
subprocess
.
Popen
(
cmd
.
split
())
p
.
communicate
()
assert
p
.
poll
()
==
0
def
step_load
(
curent_strateggy
,
saved_dir
,
seed
):
python_exe
=
sys
.
executable
os
.
makedirs
(
f
"
{
saved_dir
}
/load/logs"
,
exist_ok
=
True
)
filename
=
os
.
path
.
basename
(
__file__
)
# load dp
cmd
=
(
f
"
{
python_exe
}
-m paddle.distributed.launch --log_dir
{
saved_dir
}
/load/logs"
f
" --gpus 0
{
filename
}
--cmd load --strategy
{
curent_strateggy
}
--output_dir
{
saved_dir
}
--load_dir
{
saved_dir
}
--seed
{
seed
}
"
)
logger
.
info
(
f
"exe:
{
cmd
}
"
)
env
=
copy
.
copy
(
os
.
environ
)
env
[
"CUDA_VISIBLE_DEVICES"
]
=
"0"
p
=
subprocess
.
Popen
(
cmd
.
split
(),
env
=
env
)
p
.
communicate
()
assert
p
.
poll
()
==
0
def
wrap_sharding_2_3
(
model
,
optimizer
,
scaler
,
sharding_offload
,
stage
):
group
=
fleet
.
get_hybrid_communicate_group
().
get_sharding_parallel_group
()
level
=
"p_g_os"
if
stage
==
3
else
"os_g"
return
group_sharded_parallel
(
model
=
model
,
optimizer
=
optimizer
,
level
=
level
,
scaler
=
scaler
,
group
=
group
,
offload
=
sharding_offload
,
)
def
test_save_load
(
args
):
np
.
random
.
seed
(
args
.
seed
)
paddle
.
seed
(
args
.
seed
)
if
args
.
cmd
==
"main"
:
run_case
(
args
)
return
paddle
.
distributed
.
init_parallel_env
()
strategy
=
fleet
.
DistributedStrategy
()
if
args
.
strategy
==
"dp"
:
strategy
.
hybrid_configs
=
{
"dp_degree"
:
2
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
"sharding_degree"
:
1
,
}
elif
args
.
strategy
in
[
"sharding_stage2"
,
"sharding_stage3"
]:
strategy
.
hybrid_configs
=
{
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
"sharding_degree"
:
2
,
}
elif
args
.
strategy
==
"mp"
:
strategy
.
hybrid_configs
=
{
"dp_degree"
:
1
,
"mp_degree"
:
2
,
"pp_degree"
:
1
,
"sharding_degree"
:
1
,
}
elif
args
.
strategy
==
"pp"
:
strategy
.
hybrid_configs
=
{
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
2
,
"sharding_degree"
:
1
,
}
strategy
.
pipeline_configs
=
{
"accumulate_steps"
:
10
,
"micro_batch_size"
:
10
,
}
elif
args
.
strategy
==
"static"
:
paddle
.
enable_static
()
elif
args
.
strategy
!=
"single"
:
raise
ValueError
(
f
"Not supported strategy:
{
args
.
strategy
}
"
)
loss
=
paddle
.
nn
.
CrossEntropyLoss
()
fleet
.
set_log_level
(
"INFO"
)
if
dist
.
get_world_size
()
<=
1
:
mlp1
=
MLP
()
if
args
.
strategy
==
"static"
:
out_static
=
train_mlp_static
(
args
,
mlp1
,
loss
,
save_model
=
False
)
np
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
"static.npy"
),
out_static
)
else
:
model
,
_
,
out_dygraph
=
train_mlp
(
args
,
mlp1
,
loss
,
save_model
=
True
)
np
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
"dygraph.npy"
),
out_dygraph
)
else
:
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
pp_group
=
(
fleet
.
get_hybrid_communicate_group
().
get_pipe_parallel_group
()
)
if
pp_group
.
nranks
>
1
:
mlp1
=
MLP_pipe
()
else
:
mlp1
=
MLP_Hybrid
()
model
,
_
,
out_dygraph
=
train_mlp
(
args
,
mlp1
,
loss
,
save_model
=
True
)
if
(
dist
.
get_world_size
()
==
0
or
dist
.
get_rank
()
==
dist
.
get_world_size
()
-
1
):
np
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
"dygraph.npy"
),
out_dygraph
)
if
args
.
cmd
==
"save"
:
save_for_auto_inference
(
os
.
path
.
join
(
args
.
output_dir
,
"saved"
),
model
)
def
run_case
(
args
):
saving_strategy
=
args
.
test_case
.
split
(
":"
)[
0
]
loading_strategy
=
args
.
test_case
.
split
(
":"
)[
1
]
output_dir
=
tempfile
.
mkdtemp
()
if
os
.
path
.
isdir
(
output_dir
):
shutil
.
rmtree
(
output_dir
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
try
:
step_save
(
saving_strategy
,
output_dir
,
args
.
seed
)
step_load
(
loading_strategy
,
output_dir
,
args
.
seed
+
1
)
step_check
(
output_dir
)
except
Exception
as
e
:
shutil
.
rmtree
(
output_dir
)
raise
RuntimeError
(
f
"Test failed.
\n
{
e
.
__str__
()
}
"
)
shutil
.
rmtree
(
output_dir
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--cmd"
,
default
=
"main"
,
choices
=
[
"main"
,
"save"
,
"load"
]
)
parser
.
add_argument
(
"--strategy"
,
required
=
False
,
choices
=
[
"single"
,
"dp"
,
"mp"
,
"pp"
,
"sharding_stage2"
,
"sharding_stage3"
,
"static"
,
],
)
parser
.
add_argument
(
"--load_way"
,
choices
=
[
"paddle.load"
,
"load"
],
required
=
False
)
parser
.
add_argument
(
"--load_dir"
,
required
=
False
)
parser
.
add_argument
(
"--output_dir"
,
required
=
False
)
parser
.
add_argument
(
"--output_param_path"
,
required
=
False
)
parser
.
add_argument
(
"--test_case"
,
required
=
False
,
choices
=
[
"dp:static"
,
"mp:static"
,
"pp:static"
,
"sharding_stage2:static"
,
"sharding_stage3:static"
,
"single:static"
,
],
)
parser
.
add_argument
(
"--gather_to"
,
required
=
False
,
default
=
0
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
2022
)
args
=
parser
.
parse_args
()
test_save_load
(
args
)
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py
0 → 100644
浏览文件 @
0449b841
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
unittest
import
subprocess
import
sys
def
strategy_test
(
saving
,
seed
=
1024
,
loading
=
"static"
):
cmd
=
f
"
{
sys
.
executable
}
dygraph_save_for_auto_infer.py --test_case
{
saving
}
:
{
loading
}
--cmd main --seed
{
seed
}
"
p
=
subprocess
.
Popen
(
cmd
.
split
())
p
.
communicate
()
assert
p
.
poll
()
==
0
class
TestHybrid
(
unittest
.
TestCase
):
def
test_dygraph_save_load_dp_sharding_stage2
(
self
):
strategy_test
(
"dp"
)
strategy_test
(
"mp"
)
strategy_test
(
"pp"
)
class
TestSharding
(
unittest
.
TestCase
):
def
test_dygraph_save_load_dp_sharding_stage2
(
self
):
strategy_test
(
"sharding_stage2"
)
strategy_test
(
"sharding_stage3"
)
class
TestSingleCard
(
unittest
.
TestCase
):
def
test_dygraph_save_load_dp_sharding_stage2
(
self
):
strategy_test
(
"single"
)
if
__name__
==
"__main__"
:
os
.
environ
[
"FLAGS_enable_eager_mode"
]
=
"1"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"0,1"
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
浏览文件 @
0449b841
...
@@ -84,3 +84,4 @@ test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_
...
@@ -84,3 +84,4 @@ test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_
test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
python/paddle/incubate/distributed/utils/io/__init__.py
浏览文件 @
0449b841
...
@@ -12,5 +12,5 @@
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.dist_save
import
save
from
.dist_save
import
save
,
save_for_auto_inference
from
.dist_load
import
load
from
.dist_load
import
load
python/paddle/incubate/distributed/utils/io/dist_save.py
浏览文件 @
0449b841
...
@@ -20,10 +20,10 @@ from paddle.distributed.fleet.utils.log_util import logger
...
@@ -20,10 +20,10 @@ from paddle.distributed.fleet.utils.log_util import logger
from
paddle.fluid.framework
import
dygraph_only
from
paddle.fluid.framework
import
dygraph_only
import
copy
import
copy
import
sys
import
sys
from
.save_for_auto
import
save_for_auto_inference
from
paddle.distributed.fleet.utils.log_util
import
logger
from
paddle.distributed.fleet.utils.log_util
import
logger
__all__
=
[
"save"
]
__all__
=
[
"save"
,
"save_for_auto_inference"
]
@
dygraph_only
@
dygraph_only
...
...
python/paddle/incubate/distributed/utils/io/save_for_auto.py
0 → 100644
浏览文件 @
0449b841
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.distributed
as
dist
import
paddle.distributed.fleet
as
fleet
import
re
import
paddle
from
paddle.distributed.fleet.utils.log_util
import
logger
import
os
import
pickle
from
paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3
import
(
GroupShardedStage3
,
)
from
paddle.fluid.framework
import
dygraph_only
import
copy
import
numpy
as
np
__all__
=
[
"save_for_auto_inference"
]
@
dygraph_only
def
save_for_auto_inference
(
path_prefix
,
dist_model
,
cvt2cpu
=
False
):
"""
Description:
Save model parameters for auto parallel inference.
Supporting dp + mp + pp + sharding(stage1), dp + sharding stage2-3.
MoE not sdupported till MoE is supported in auto parallel mode.
Args:
path_prefix: path prefix to save
If `path_preifx` ends with path sepreator,
the path is processed as a directory and parameters will be saved in it,
automatically named saved_parameters.
Otherwisw, the parameters will be saved with name
path_preifx_dist{global_rank}.pdparams and path_preifx_dist{global_rank}.pdattrs
dist_model:
model in distributed modeß
cvt2cpu: wheather to move parameters to CPU when using sharding stage 3.
The var is invalid if not using sharding stage 3.
Returns:
None
Examples:
dist_model = build_distributed_model()
path_prefix = "path/to/save_infer"
save_for_auto_inference(path_prefix, dist_model=dist_model, original_model=single_model, cvt2cpu=False)
Outputs:
path/to/save_infer_dist0.pdparams path/to/save_infer_dist1.pdparams path/to/save_infer_dist2.pdparams ...
path/to/save_infer_dist0.pdattr path/to/save_infer_dist1.pdattr path/to/save_infer_dist2.pdattr ...
"""
save_dir
,
basename_prefix
=
_get_abs_saved_prefix
(
path_prefix
)
if
isinstance
(
dist_model
,
GroupShardedStage3
):
dist_model
.
get_all_parameters
(
cvt2cpu
)
wrapped_dict
=
_get_wrapped_dist_state_dict
(
dist_model
.
state_dict
())
global_rank
=
paddle
.
distributed
.
get_rank
()
# save parameters
paddle
.
save
(
wrapped_dict
,
os
.
path
.
join
(
save_dir
,
f
"
{
basename_prefix
}
_dist
{
global_rank
}
.pdparams"
),
)
# save attributes
_save_param_attr
(
wrapped_dict
,
os
.
path
.
join
(
save_dir
,
f
"
{
basename_prefix
}
_dist
{
global_rank
}
.pdattr"
),
)
# unset dims mapping after saving attrs
for
_
,
dist_param
in
wrapped_dict
.
items
():
_unset_dims_mapping
(
dist_param
)
def
_is_first_used
(
param
):
return
not
hasattr
(
param
,
"is_firstly_shared"
)
or
param
.
is_firstly_shared
def
_get_all_ranks_of_pp
(
pp_rank
,
dp_degree
,
mp_degree
,
pp_degree
):
"""
Description:
get all global ranks involving given pp_rank
"""
process_group
=
[]
world_size
=
dp_degree
*
mp_degree
*
pp_degree
for
i
in
range
(
dp_degree
):
for
k
in
range
(
mp_degree
):
process_group
.
append
(
i
*
world_size
//
dp_degree
+
pp_rank
*
world_size
//
dp_degree
//
pp_degree
+
k
)
return
process_group
def
_save_param_attr
(
state_dict_
,
path
,
dims_mapping_dict
=
None
):
"""
Description:
save params' attr dict
Args:
state_dict_:
state for which to save attrs, when the state is optimzier state, the master and LRScheduler will be reomoved.
path:
path to save
dims_mapping_dict:
Dims mapping dict, mapping from parameter name in state_dict_ to dims_mapping.
If parameter in state_dict_ has attribute 'dims_mapping', the dims_mapping is ignored.
If parameter has no attribute 'dims_mapping', the dims mapping must contains the parameter's name.
"""
state_dict
=
copy
.
copy
(
state_dict_
)
# remove master_weights and LRScheduler, which needs no parameter attributes to save
state_dict
.
pop
(
"master_weights"
,
None
)
state_dict
.
pop
(
"LR_Scheduler"
,
None
)
if
dims_mapping_dict
is
not
None
:
assert
isinstance
(
dims_mapping_dict
,
dict
),
"dims_mapping_dict must be an instance of dict"
for
k
in
state_dict
.
keys
():
assert
(
k
in
dims_mapping_dict
),
f
"param
{
k
}
cannot find dims mapping in dims_mapping_dict"
if
dist
.
get_world_size
()
>
1
:
hcg
=
fleet
.
get_hybrid_communicate_group
()
dp_degree
=
hcg
.
get_data_parallel_world_size
()
mp_degree
=
hcg
.
get_model_parallel_world_size
()
pp_degree
=
hcg
.
get_pipe_parallel_world_size
()
sharding_degree
=
hcg
.
get_sharding_parallel_world_size
()
dp_degree
=
dp_degree
*
sharding_degree
pp_group
=
hcg
.
get_pipe_parallel_group
()
else
:
pp_degree
=
1
dp_degree
=
1
mp_degree
=
1
pp_group
=
None
hcg
=
None
logger
.
debug
(
f
"dp degree * sharding degree :
{
dp_degree
}
"
)
logger
.
debug
(
f
"mp degree:
{
mp_degree
}
"
)
logger
.
debug
(
f
"pp degree:
{
pp_degree
}
"
)
pp_rank
=
dist
.
get_rank
(
pp_group
)
# Why condition 'pp_rank < 0' exists?
# Because if pp_degree = 1, pp_rank is set -1
pp_rank
=
0
if
pp_rank
<=
0
else
pp_rank
if
dist
.
get_world_size
()
>
1
:
process_group
=
_get_all_ranks_of_pp
(
pp_rank
,
dp_degree
,
mp_degree
,
pp_degree
)
else
:
process_group
=
[
0
]
attr_dict
=
{}
for
k
,
v
in
state_dict
.
items
():
dims
=
len
(
v
.
shape
)
logger
.
debug
(
f
"shape: ,
{
k
}
,
{
dims
}
"
)
attr_d
=
{
"process_shape"
:
[
dp_degree
,
mp_degree
]
if
hcg
else
[
1
],
"process_group"
:
process_group
,
"dims_mapping"
:
v
.
dims_mapping
if
hasattr
(
v
,
"dims_mapping"
)
else
[
-
1
for
_
in
v
.
shape
],
}
attr_dict
[
k
]
=
attr_d
with
open
(
path
,
"wb"
)
as
f
:
pickle
.
dump
(
attr_dict
,
f
)
def
_unset_dims_mapping
(
param
):
if
hasattr
(
param
,
"dims_mapping"
):
delattr
(
param
,
"dims_mapping"
)
def
_get_dims_mapping
(
dist_parameter
,
mp_group
):
"""
Description:
return the sliting mapping:
{tensor_name: spiting_strategy}
Args:
dist_parameters(list): distributed model parameters
mp_group(ProcessGroup): Model Parallel communication group
Return:
The sliting mapping
Examples:
spliting_strategy's format (-1, -1, -1, 0), meaing the dims
of the tennsor is 4 and it is splited along the first strategy axis in mesh
Mesh Examples: (2, 4) means dp=2, mp=4
"""
import
numpy
as
np
dist_shape
=
np
.
array
(
dist_parameter
.
shape
)
if
hasattr
(
dist_parameter
,
"split_axis"
):
aixs
=
getattr
(
dist_parameter
,
"split_axis"
)
mapping
=
[
-
1
for
_
in
dist_shape
]
mapping
[
aixs
]
=
1
logger
.
debug
(
f
"
{
dist_parameter
.
name
}
has attr split_axis: mapping:
{
mapping
}
"
)
else
:
mapping
=
[
-
1
for
_
in
dist_shape
]
logger
.
debug
(
f
"normal parameter:
{
dist_parameter
.
name
}
"
)
return
mapping
def
_get_abs_saved_prefix
(
path_prefix
):
"""
Description:
Get absolute dir path and basename prefix of path_prefix, with making path_prefix's directories.
If path_prefix is a directory name, basename is set 'saved_parameters'.
If path_prefix is a file name, basename is extracted from path_prefix.
Args:
path_prefix: str
Return:
(dirpath: str, basename: str)
"""
abs_prefix
=
os
.
path
.
abspath
(
path_prefix
)
if
abs_prefix
[
-
1
]
==
os
.
path
.
sep
:
save_dir
=
abs_prefix
basename_prefix
=
"saved_parameters"
else
:
save_dir
=
os
.
path
.
dirname
(
abs_prefix
)
basename_prefix
=
os
.
path
.
basename
(
abs_prefix
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
return
save_dir
,
basename_prefix
def
_name_mapping_dist2single
(
state_dict
,
pp_group
):
key_list
=
[]
param_keys
=
[
v
.
name
for
_
,
v
in
state_dict
.
items
()
if
isinstance
(
v
,
paddle
.
Tensor
)
and
_is_first_used
(
v
)
]
if
pp_group
.
nranks
==
1
:
return
{
k
:
k
for
k
in
param_keys
}
dist
.
all_gather_object
(
key_list
,
param_keys
,
pp_group
)
# find how many a op in a each pp:
# {"linear:"[0, 2,0,1,1,...]}
param_types
=
{}
matcher
=
re
.
compile
(
r
"^\w+_\d+(?=\.)"
)
for
pp
,
keys
in
enumerate
(
key_list
):
param_type_idx
=
{}
for
k
in
keys
:
matched
=
matcher
.
search
(
k
)
logger
.
debug
(
f
"matched:
{
k
}
:
{
matched
}
"
)
assert
(
matched
is
not
None
),
f
"the name of param, '
{
k
}
', is not satisfyied the format 'name_idx.xxx'"
name_idx
=
k
[
matched
.
start
()
:
matched
.
end
()]
logger
.
debug
(
f
"get param_type_idx:
{
name_idx
}
"
)
if
name_idx
in
param_type_idx
:
continue
name
=
"_"
.
join
(
name_idx
.
split
(
"_"
)[:
-
1
])
idx
=
int
(
name_idx
.
split
(
"_"
)[
-
1
])
param_type_idx
.
update
({
name_idx
:
(
name
,
idx
)})
if
name
not
in
param_types
:
param_types
[
name
]
=
[
0
]
*
pp_group
.
nranks
param_types
[
name
][
pp
]
+=
1
# check if continous
types_idx
=
{}
for
_
,
v
in
param_type_idx
.
items
():
if
v
[
0
]
not
in
types_idx
:
types_idx
.
update
({
v
[
0
]:
[
v
[
1
]]})
else
:
types_idx
[
v
[
0
]].
append
(
v
[
1
])
for
k
,
v
in
types_idx
.
items
():
assert
v
==
list
(
range
(
v
[
0
],
v
[
-
1
]
+
1
)
),
f
"
{
k
}
is not continous:
{
v
}
"
logger
.
debug
(
f
"param type:
{
param_types
}
"
)
# analyse starting index
for
k
in
param_types
.
keys
():
param_types
[
k
]
=
np
.
cumsum
([
0
]
+
param_types
[
k
][:
-
1
])
logger
.
debug
(
f
"params type:
{
param_types
}
"
)
name_mapping
=
{}
pp_rank
=
dist
.
get_rank
(
pp_group
)
for
k
in
key_list
[
pp_rank
]:
matched
=
matcher
.
search
(
k
)
name_idx
=
k
[
matched
.
start
()
:
matched
.
end
()]
name
=
"_"
.
join
(
name_idx
.
split
(
"_"
)[:
-
1
])
idx
=
int
(
name_idx
.
split
(
"_"
)[
-
1
])
logger
.
debug
(
f
"idx:
{
idx
}
"
)
new_idx
=
param_types
[
name
][
pp_rank
]
+
idx
logger
.
debug
(
f
"new idx:
{
new_idx
}
"
)
new_name_idx
=
name
+
"_"
+
str
(
new_idx
)
name_mapping
[
k
]
=
new_name_idx
+
k
[
matched
.
end
()
:]
return
name_mapping
def
_get_wrapped_dist_state_dict
(
dist_state_dict
):
wrapped_state_dict
=
dict
()
if
dist
.
get_world_size
()
<=
1
:
for
_
,
v
in
dist_state_dict
.
items
():
wrapped_state_dict
[
v
.
name
]
=
v
return
wrapped_state_dict
hcg
=
fleet
.
get_hybrid_communicate_group
()
pp_group
=
hcg
.
get_pipe_parallel_group
()
mp_group
=
hcg
.
get_model_parallel_group
()
logger
.
debug
(
"execute _name_mapping_dist2single"
)
name_mapping
=
_name_mapping_dist2single
(
dist_state_dict
,
pp_group
)
for
_
,
v
in
dist_state_dict
.
items
():
if
not
_is_first_used
(
v
):
logger
.
debug
(
f
"not first used :
{
v
.
name
}
"
)
continue
wrapped_state_dict
[
name_mapping
[
v
.
name
]]
=
v
setattr
(
v
,
"dims_mapping"
,
_get_dims_mapping
(
v
,
mp_group
))
logger
.
debug
(
f
"saving param:
{
v
.
name
}
->
{
name_mapping
[
v
.
name
]
}
shape:
{
v
.
shape
}
"
)
return
wrapped_state_dict
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录