Commit 503f422e: add mp_sync config. (#53254)
Authored by wuhuachaocoding on Apr 25, 2023; committed via GitHub on Apr 25, 2023.
Parent commit: 00f747f2
Showing 5 changed files with 150 additions and 34 deletions:

  paddle/fluid/framework/distributed_strategy.proto (+1, -0)
  python/paddle/distributed/fleet/base/distributed_strategy.py (+6, -0)
  python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py (+65, -30)
  python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py (+48, -4)
  python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py (+30, -0)
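In short, the commit adds per-tensor synchronization options for the model-parallel (mp) group: sync_param, sync_grad, and sync_moment select what to synchronize, sync_mode selects how ("broadcast" or averaging), and sync_param_name selects which parameters by name substring. A minimal usage sketch assembled from the test configurations below (the surrounding model and training loop are assumed):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    "mp_configs": {
        "sync_param": True,        # sync parameters after the optimizer step
        "sync_grad": False,        # sync gradients before the optimizer step
        "sync_moment": False,      # sync Adam/AdamW moments after the optimizer step
        "sync_mode": "broadcast",  # "broadcast" from the source rank, or "average"
        "sync_param_name": ["embedding", "layer_norm", ".b_"],
    },
}
fleet.init(is_collective=True, strategy=strategy)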
paddle/fluid/framework/distributed_strategy.proto

@@ -55,6 +55,7 @@ message MpConfig {
   optional bool sync_param = 1 [ default = false ];
   optional bool sync_grad = 2 [ default = false ];
   optional bool sync_moment = 3 [ default = false ];
+  optional string sync_mode = 4 [ default = 'broadcast' ];
 }

 message PpConfig {
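The new sync_mode field controls how the optimizer's _insert_sync helper (added below) synchronizes a tensor across the mp group: "broadcast" copies the source rank's value to all ranks, while any other value (the tests use "average") falls back to an all_reduce followed by scaling with 1.0 / nranks. A toy single-process illustration of the arithmetic (plain Python, not a Paddle API):

# Values a hypothetical 4-rank mp group might hold for one tensor element.
rank_values = [1.0, 2.0, 3.0, 4.0]

# sync_mode == "broadcast": every rank adopts the source rank's value.
broadcast_result = [rank_values[0]] * 4  # [1.0, 1.0, 1.0, 1.0]

# Any other sync_mode: all_reduce (sum) then scale by 1 / nranks, i.e. average.
average = sum(rank_values) / len(rank_values)
average_result = [average] * 4           # [2.5, 2.5, 2.5, 2.5]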
python/paddle/distributed/fleet/base/distributed_strategy.py

@@ -146,6 +146,8 @@ class DistributedStrategy:
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])

         self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
+        self.sync_param_name = ["embedding", "layer_norm", ".b_"]
+
         self.__lock_attr = True
         logger.info("distributed strategy initialized")

@@ -1698,6 +1700,10 @@ class DistributedStrategy:
                 )

             if "mp_configs" in configs:
+                if "sync_param_name" in configs["mp_configs"]:
+                    self.sync_param_name = configs["mp_configs"]["sync_param_name"]
+                    configs["mp_configs"].pop("sync_param_name")
+
                 assign_configs_value(
                     self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
                 )
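Note that sync_param_name lives on the Python DistributedStrategy object rather than in the protobuf message, which is why it is popped from mp_configs before assign_configs_value runs: MpConfig (above) has no such field. The entries act as plain substring patterns against parameter names, as _filter_fn in the next file shows. A small standalone sketch of the matching rule, with made-up parameter names:

def matches_sync_list(param_name, sync_param_name):
    # A parameter is selected if any configured pattern is a substring of its name.
    return any(pattern in param_name for pattern in sync_param_name)

patterns = ["embedding", "layer_norm", ".b_"]
assert matches_sync_list("word_embedding_0.w_0", patterns)  # "embedding" matches
assert matches_sync_list("linear_3.b_0", patterns)          # ".b_" matches
assert not matches_sync_list("linear_3.w_0", patterns)      # no pattern matches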
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

@@ -303,9 +303,20 @@ class HybridParallelOptimizer:
                 inner_opt._grad_clip, hcg
             )

-    def _filter_fn(self, param):
+    def _insert_sync(self, sync_var, src, mp_group, sync_mode):
+        if sync_mode == "broadcast":
+            paddle.distributed.broadcast(
+                sync_var, src=src, group=mp_group, sync_op=True
+            )
+        else:
+            paddle.distributed.all_reduce(
+                sync_var, group=mp_group, sync_op=True
+            )
+            sync_var.scale_(1.0 / mp_group.nranks)
+
+    def _filter_fn(self, param, strategy):
         p_name = param.name
-        tar_param = ["embedding", "layer_norm", ".b_"]
+        tar_param = strategy.sync_param_name
         if param.is_distributed is False:
             for tar in tar_param:
                 if tar in p_name:
@@ -329,26 +340,48 @@ class HybridParallelOptimizer:
                 or mp_configs.sync_moment
             ):
                 params = sorted(
-                    [p for p in parameters_list if self._filter_fn(p)],
+                    [
+                        p
+                        for p in parameters_list
+                        if self._filter_fn(p, fleet.fleet._user_defined_strategy)
+                    ],
                     key=lambda p: p.name,
                 )

+            # Grad sync before opt
             if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
                 for p in params:
-                    if p.grad is None:
-                        continue
-                    paddle.distributed.broadcast(
-                        p.grad, src=src_rank, group=mp_group, sync_op=True
-                    )
+                    if hasattr(p, "main_grad") and p.main_grad is not None:
+                        assert p.grad is None
+                        self._insert_sync(
+                            p.main_grad, src_rank, mp_group, mp_configs.sync_mode
+                        )
+                    elif p.grad is not None:
+                        self._insert_sync(
+                            p.grad, src_rank, mp_group, mp_configs.sync_mode
+                        )

             self._inner_opt.step()

+            # Param sync after opt
             if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
                 for p in params:
-                    paddle.distributed.broadcast(
-                        p, src=src_rank, group=mp_group, sync_op=True
-                    )
+                    self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode)

+                    # Master param sync after opt
+                    if (
+                        hasattr(self._inner_opt, "_multi_precision")
+                        and self._inner_opt._multi_precision
+                        and p.name in self._inner_opt._master_weights
+                    ):
+                        self._insert_sync(
+                            self._inner_opt._master_weights[p.name],
+                            src_rank,
+                            mp_group,
+                            mp_configs.sync_mode,
+                        )
+
+            # Moment sync after opt
             if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
                 for p in params:
                     # support opt state of adam and adamw to broadcast now.
@@ -357,28 +390,30 @@ class HybridParallelOptimizer:
                     if isinstance(
                         self._inner_opt,
                         (paddle.optimizer.Adam, paddle.optimizer.AdamW),
                     ):
                         if (
-                            self._inner_opt._multi_precision
-                            and p.name in self._master_weights
+                            p.name
+                            in self._inner_opt._accumulators[
+                                self._inner_opt._moment1_acc_str
+                            ]
                         ):
-                            paddle.distributed.broadcast(
-                                self._inner_opt._master_weights[p.name],
-                                src=src_rank,
-                                group=mp_group,
-                                sync_op=True,
-                            )
-
                             moment1 = self._inner_opt._get_accumulator(
                                 self._inner_opt._moment1_acc_str, p
                             )
-                            moment2 = self._inner_opt._get_accumulator(
-                                self._inner_opt._moment2_acc_str, p
-                            )
-                            paddle.distributed.broadcast(
-                                moment1, src=src_rank, group=mp_group, sync_op=True
-                            )
-                            paddle.distributed.broadcast(
-                                moment2, src=src_rank, group=mp_group, sync_op=True
-                            )
+                            self._insert_sync(
+                                moment1, src_rank, mp_group, mp_configs.sync_mode
+                            )
+
+                        if (
+                            p.name
+                            in self._inner_opt._accumulators[
+                                self._inner_opt._moment2_acc_str
+                            ]
+                        ):
+                            moment2 = self._inner_opt._get_accumulator(
+                                self._inner_opt._moment2_acc_str, p
+                            )
+                            self._insert_sync(
+                                moment2, src_rank, mp_group, mp_configs.sync_mode
+                            )

     @no_grad()
     @framework.dygraph_only
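With sync_param enabled, filtered parameters (and, under fp16 O2, their master weights and moments) should end up identical across the mp group after each step. A hedged debugging sketch for checking that invariant; it assumes an initialized collective context and uses paddle.distributed.all_gather to compare one tensor across ranks:

import paddle
import paddle.distributed as dist

def is_synced_across_group(tensor, mp_group):
    # Gather this tensor from every rank in the mp group ...
    gathered = []
    dist.all_gather(gathered, tensor, group=mp_group)
    # ... and check that every copy matches the first rank's copy.
    return all(
        bool(paddle.equal_all(gathered[0], other)) for other in gathered[1:]
    )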
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py

@@ -202,6 +202,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
         self,
         batchs,
         fp16=False,
+        amp_level="O1",
         mp_sync_param=False,
         mp_sync_grad=False,
         mp_sync_moment=False,

@@ -232,6 +233,11 @@ class TestDistMPSyncTraning(unittest.TestCase):
             learning_rate=0.1, parameters=model.parameters()
         )

+        if fp16 and amp_level == "O2":
+            model, optimizer = paddle.amp.decorate(
+                models=model, optimizers=optimizer, level='O2'
+            )
+
         strategy = fleet.fleet._user_defined_strategy
         strategy.hybrid_configs = {
             "dp_degree": self.data_parallel_size,

@@ -246,15 +252,15 @@ class TestDistMPSyncTraning(unittest.TestCase):
         model = fleet.distributed_model(model)
         optimizer = fleet.distributed_optimizer(optimizer)

-        return self.train_batch(batchs, model, optimizer, fp16)
+        return self.train_batch(batchs, model, optimizer, fp16, amp_level)

-    def train_batch(self, batchs, model, optimizer, fp16=False):
+    def train_batch(self, batchs, model, optimizer, fp16=False, amp_level="O1"):
         losses = []
         if fp16:
             scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
             scaler = fleet.distributed_scaler(scaler)

         for batch in batchs:
-            with paddle.amp.auto_cast(enable=fp16, level='O1'):
+            with paddle.amp.auto_cast(enable=fp16, level=amp_level):
                 output = model(batch)
                 loss = output.mean()
                 losses.append(loss.numpy())

@@ -295,7 +301,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
         for i in range(len(losses)):
             np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)

-        # test fp16
+        # test fp16 O1
         losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
         losses_sync_fp16 = self.build_model_optimizer_train(
             batchs,

@@ -310,6 +316,24 @@ class TestDistMPSyncTraning(unittest.TestCase):
                 losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
             )

+        # test fp16 O2
+        losses_fp16_O2 = self.build_model_optimizer_train(
+            batchs, fp16=True, amp_level="O2"
+        )
+        losses_sync_fp16_O2 = self.build_model_optimizer_train(
+            batchs,
+            fp16=True,
+            amp_level="O2",
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses_fp16_O2)):
+            np.testing.assert_allclose(
+                losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-6
+            )
+
     def test_mp_sync_param(self):
         self.mp_sync_base(mp_sync_param=True)

@@ -325,6 +349,26 @@ class TestDistMPSyncTraning(unittest.TestCase):
         )

+class TestDistMPSyncModelTraning(TestDistMPSyncTraning):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": False,
+                "sync_grad": False,
+                "sync_moment": False,
+                "sync_mode": "average",
+                "sync_param_name": ["embedding", "layer_norm", ".b_"],
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+
 class TestDistMPTraning(unittest.TestCase):
     def setUp(self):
         strategy = fleet.DistributedStrategy()
python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py

@@ -84,6 +84,36 @@ class TestStrategyConfig(unittest.TestCase):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)

+    def test_hybrid_parallel_mp_configs(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "mp_configs": {
+                "sync_param": True,
+                "sync_grad": False,
+                "sync_moment": False,
+                "sync_mode": "broadcast",
+                "sync_param_name": ["embedding", "layer_norm", ".w", ".b_"],
+            },
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_param, True)
+        self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_grad, False)
+        self.assertEqual(
+            strategy.hybrid_configs["mp_configs"].sync_moment, False
+        )
+        self.assertEqual(
+            strategy.hybrid_configs["mp_configs"].sync_mode, "broadcast"
+        )
+        self.assertEqual(
+            strategy.sync_param_name, ["embedding", "layer_norm", ".w", ".b_"]
+        )
+
     def test_hybrid_parallel_configs_order(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.hybrid_configs = {