Unverified commit 828f87ae
Authored on Dec 06, 2021 by Baibaifan; committed by GitHub on Dec 06, 2021.
sharding_stage2_pfp16 (#37836)
Parent: 3e33ef5a
Showing 3 changed files with 36 additions and 20 deletions (+36, -20)
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py (+14, -0)
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py (+8, -3)
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py (+14, -17)
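For orientation, this commit threads pure-fp16 ("pfp16") support through stage-2 sharding: the optimizer wrapper detects trainable fp16 parameters, keeps fp32 master weights for the shard it owns, and the unit test swaps its stage-1 baseline for a plain data-parallel baseline. The snippet below is a minimal usage sketch assembled from the test changes further down; the constructor signatures of ShardingOptimizerStage2 and ShardingStage2 are assumptions inferred from this diff, not guaranteed by it.

# Hypothetical end-to-end sketch (two GPUs assumed); the sharding wrapper
# arguments are inferred from the test below and may not match the real API.
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])          # two-rank sharding group

model = paddle.nn.Linear(1024, 1024)                  # stand-in for the MLP used by the test
optim = paddle.optimizer.AdamW(parameters=model.parameters())

# Pure fp16: cast the layer to fp16, keep fp32 copies when saving.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')

# Shard optimizer state across the group; with this commit the wrapper keeps
# fp32 master weights for the fp16 parameters assigned to this rank.
optim = ShardingOptimizerStage2(params=model.parameters(), optim=optim, group=group)
model = ShardingStage2(model, optim, group=group)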
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py (view file @ 828f87ae)
@@ -83,8 +83,14 @@ class ShardingOptimizerStage2(Optimizer):
         # Default information
         self._optim_defaults = kw
         self._optim = optim
+        assert hasattr(self._optim, "_master_weights"
+                       ), "Must use optimizer with _master_weights attribute"
         self._local_params = params
         self._default_device = device
+        self._pfp16 = len(
+            list(
+                filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
+                       self._local_params))) > 0

         assert group is not None, "Distributed communication group is must be gived"
         self.group = group
@@ -98,6 +104,12 @@ class ShardingOptimizerStage2(Optimizer):
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self.update_opt_status()

+    def _generate_master_params(self, trainable_params):
+        for param in trainable_params:
+            if param.dtype == Type.fp16.value:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+
     def update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
         """
@@ -207,6 +219,8 @@ class ShardingOptimizerStage2(Optimizer):
             # Merge all the trainable params in a single InternalStorage
             trainable_params = list(filter(lambda x: x.trainable, params))
+            if self._pfp16 and dst_rank == self.rank:
+                self._generate_master_params(trainable_params)
             if trainable_params:
                 param_storage = ParamStorage(
                     size=self.rank_buffer_size[dtype][dst_rank],
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py (view file @ 828f87ae)
@@ -30,6 +30,7 @@ from paddle import nn
 import paddle.distributed as dist

 from ...utils.internal_storage import GradStorage
+from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from .sharding_utils import Taskflow, Type
@@ -70,6 +71,11 @@ class ShardingStage2(nn.Layer):
         self._layer = layer
         self._sharding_optimizers = [sharding_optimizer] if not isinstance(
             sharding_optimizer, list) else sharding_optimizer
+        assert all(
+            list(
+                map(lambda opt: isinstance(opt, ShardingOptimizerStage2),
+                    self._sharding_optimizers))
+        ), "Please use ShardingOptimizerStage2 optimizer"
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
@@ -88,8 +94,7 @@ class ShardingStage2(nn.Layer):
         # Global statistical parameters
         self._all_params = list(
-            chain(
-                *[optim.local_params for optim in self._sharding_optimizers]))
+            chain(*[optim.local_params for optim in self._sharding_optimizers]))
         self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
@@ -436,7 +441,7 @@ class ShardingStage2(nn.Layer):
                         ._fill))

         self._grad_storage_list = list(
             chain(*[
                 self._grad_storages[dtype].values()
                 for dtype in self._grad_storages.keys()
             ]))
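Both `_all_params` and `_grad_storage_list` in the hunks above rely on the same flattening idiom: `itertools.chain(*...)` splices one level of nesting into a single list. A self-contained, Paddle-free illustration with placeholder values:

from itertools import chain

# Shape mirrors self._grad_storages: {dtype: {dst_rank: storage}}.
grad_storages = {
    "float16": {0: "fp16_storage_rank0", 1: "fp16_storage_rank1"},
    "float32": {0: "fp32_storage_rank0"},
}

# chain(*iterables) concatenates the per-dtype value views into one flat list.
grad_storage_list = list(
    chain(*[grad_storages[dtype].values() for dtype in grad_storages.keys()]))

print(grad_storage_list)
# ['fp16_storage_rank0', 'fp16_storage_rank1', 'fp32_storage_rank0']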
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py (view file @ 828f87ae)
@@ -24,7 +24,6 @@ from paddle.fluid.dygraph.nn import Linear
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import nn
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
@@ -70,7 +69,7 @@ def reader_decorator():
     return __reader__


-def optimizer_setting(model, use_pure_fp16, stage=1):
+def optimizer_setting(model, use_pure_fp16):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
         parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
               use_pure_fp16=False,
               all_test=False,
               accumulate_grad=False):
-    if sharding_stage == 1:
+    if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(
-        model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
+    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if use_pure_fp16:
-        model, optimizer = paddle.amp.decorate(
-            models=model,
-            optimizers=optimizer,
-            level='O2',
-            save_dtype='float32')
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')

     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
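The test now calls paddle.amp.decorate without the optimizer: when optimizers is omitted, decorate returns only the converted model, whereas the removed lines unpacked a (model, optimizer) pair. The raw AdamW is then handed to ShardingOptimizerStage2 untouched. A hedged sketch of the two call shapes (assumes a GPU build, since O2 casts parameters to fp16):

import paddle

model_a = paddle.nn.Linear(8, 8)
opt_a = paddle.optimizer.AdamW(parameters=model_a.parameters())

# Old form: both model and optimizer are decorated and returned as a pair.
model_a, opt_a = paddle.amp.decorate(
    models=model_a, optimizers=opt_a, level='O2', save_dtype='float32')

# New form in this test: only the model is decorated; the optimizer is wrapped
# by ShardingOptimizerStage2 afterwards instead of by amp.
model_b = paddle.nn.Linear(8, 8)
model_b = paddle.amp.decorate(models=model_b, level='O2', save_dtype='float32')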
@@ -164,7 +159,7 @@ def train_mlp(model,
     return model.parameters()


-def test_stage1_stage2():
+def test_dp_stage2():
     mlp = MLP()
     state_dict = mlp.state_dict()
     mlp1 = MLP()
@@ -175,11 +170,13 @@ def test_stage1_stage2():
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
-    stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
-    stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
-    for i in range(len(stage1_params)):
-        np.testing.assert_allclose(
-            stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
+    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    for i in range(len(dp_params)):
+        for j in range(len(stage2_params)):
+            if dp_params[i].name == stage2_params[j].name:
+                np.testing.assert_allclose(
+                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)

     stage2_params = train_mlp(
         mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@ def test_stage1_stage2():
 if __name__ == '__main__':
-    test_stage1_stage2()
+    test_dp_stage2()