Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
327e5050
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
327e5050
编写于
12月 19, 2021
作者:
B
Baibaifan
提交者:
GitHub
12月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Integration sharding stage2 function (#38151)
上级
9e42fe9a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
119 addition
and
103 deletion
+119
-103
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
...optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+38
-38
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
...stributed/fleet/meta_parallel/sharding/sharding_stage2.py
+38
-17
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
...istributed/fleet/meta_parallel/sharding/sharding_utils.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
...luid/tests/unittests/dygraph_sharding_optimizer_stage2.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
...n/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
+33
-42
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py
.../fluid/tests/unittests/dygraph_sharding_stage2_offload.py
+8
-4
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
浏览文件 @
327e5050
...
...
@@ -16,21 +16,19 @@
#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
import
copy
import
time
import
logging
import
numpy
as
np
from
math
import
inf
from
itertools
import
chain
from
functools
import
reduce
from
collections
import
OrderedDict
import
paddle
import
paddle.fluid
as
fluid
from
paddle
import
framework
from
paddle.fluid
import
core
import
paddle.distributed
as
dist
from
paddle.optimizer
import
Optimizer
from
paddle.fluid.clip
import
ClipGradByGlobalNorm
from
paddle.distributed.collective
import
_get_global_group
from
...utils.internal_storage
import
ParamStorage
from
...meta_parallel.sharding.sharding_utils
import
Type
,
device_guard
,
ShardingClipGrad
...
...
@@ -59,14 +57,14 @@ class ShardingOptimizerStage2(Optimizer):
# Feature Notes:
# 1. Unified memory for parameters and parameters.grad to InternalStorage.
# 2. Support the segmentation of optimizer parameters and partial updating of parameters.
# 3. Dynamically adjust training parameters and models
。
# 3. Dynamically adjust training parameters and models
.
# 4. Support offload function.
# 5. Support the establishment of independent communication groups.
# 6. Broadcast_fp16 is not supported now.
def
__init__
(
self
,
params
,
optim
,
group
,
group
=
None
,
broadcast_fp16
=
False
,
offload
=
False
,
device
=
"gpu"
,
...
...
@@ -78,13 +76,16 @@ class ShardingOptimizerStage2(Optimizer):
self
.
_dtype_rank_params
=
OrderedDict
(
)
# {dtype:[param1,param2]} device, rank, params
self
.
_param2rank
=
{}
self
.
_segment_params
=
[]
self
.
_
_
segment_params
=
[]
self
.
_rank_buffer_size
=
{}
# {dtype: {rank: numel+alignment}}
self
.
_param2align
=
{}
# {param.name: align}
# Default information
self
.
_optim_defaults
=
kw
self
.
_optim
=
optim
self
.
_ori_parameter_list
=
self
.
_optim
.
_parameter_list
self
.
_ori_param_groups
=
self
.
_optim
.
_param_groups
assert
hasattr
(
self
.
_optim
,
"_master_weights"
),
"Must use optimizer with _master_weights attribute"
self
.
_local_params
=
params
...
...
@@ -94,8 +95,8 @@ class ShardingOptimizerStage2(Optimizer):
filter
(
lambda
x
:
x
.
trainable
and
x
.
dtype
==
Type
.
fp16
.
value
,
self
.
_local_params
)))
>
0
assert
group
is
not
None
,
"Distributed communication group is must be gived"
self
.
group
=
group
group
=
_get_global_group
()
if
group
is
None
else
group
self
.
world_size
=
group
.
nranks
self
.
rank
=
group
.
rank
...
...
@@ -119,7 +120,7 @@ class ShardingOptimizerStage2(Optimizer):
self
.
_master_params
=
{}
# Update optimizer parameters and adjust parameter storage and use according to rank.
self
.
update_opt_status
()
self
.
_
update_opt_status
()
def
_generate_master_params
(
self
,
trainable_params
):
if
self
.
offload
:
...
...
@@ -137,7 +138,7 @@ class ShardingOptimizerStage2(Optimizer):
self
.
_optim
.
_master_weights
[
param
.
name
]
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
def
update_opt_status
(
self
):
def
_
update_opt_status
(
self
):
"""Update optimizer status and parameter storage information, and special functions to be developed.
"""
# func 1
...
...
@@ -147,12 +148,12 @@ class ShardingOptimizerStage2(Optimizer):
# Segement helpers
def
segment_params
(
self
):
def
_
segment_params
(
self
):
"""
Divide all optimizer parameters equally into rank.
"""
if
len
(
self
.
_segment_params
)
==
0
:
self
.
_segment_params
,
param_lists
=
[
if
len
(
self
.
_
_
segment_params
)
==
0
:
self
.
_
_
segment_params
,
param_lists
=
[
[]
for
_
in
range
(
self
.
world_size
)
],
[[]
for
_
in
range
(
self
.
world_size
)]
sizes
=
[
0
]
*
self
.
world_size
...
...
@@ -165,9 +166,8 @@ class ShardingOptimizerStage2(Optimizer):
sizes
[
rank
]
+=
np
.
prod
(
param
.
shape
)
if
param
.
trainable
else
0
for
rank
,
params
in
enumerate
(
param_lists
):
# param_group_rank = copy.copy(params)
self
.
_segment_params
[
rank
].
extend
(
params
)
return
self
.
_segment_params
self
.
__segment_params
[
rank
].
extend
(
params
)
return
self
.
__segment_params
@
property
def
local_params
(
self
):
...
...
@@ -177,7 +177,7 @@ class ShardingOptimizerStage2(Optimizer):
def
param2rank
(
self
):
"""Map the params to the rank which owns them"""
if
len
(
self
.
_param2rank
)
==
0
:
for
rank
,
params
in
enumerate
(
self
.
segment_params
()):
for
rank
,
params
in
enumerate
(
self
.
_
segment_params
()):
for
param
in
params
:
self
.
_param2rank
[
param
.
name
]
=
rank
return
self
.
_param2rank
...
...
@@ -271,32 +271,31 @@ class ShardingOptimizerStage2(Optimizer):
"""
if
self
.
offload
:
self
.
_optim
.
_parameter_list
=
[
param
for
name
,
param
in
self
.
_master_params
.
items
()
]
params_list
=
list
(
self
.
_master_params
.
values
())
else
:
# Synchronize optimizer parameters for the current rank
if
len
(
self
.
dtype_rank_params
.
keys
(
))
==
1
and
Type
.
fp32
.
value
in
self
.
dtype_rank_params
.
keys
():
self
.
_optim
.
_parameter_list
=
self
.
dtype_rank_params
[
Type
.
fp32
.
value
][
self
.
rank
]
elif
len
(
self
.
dtype_rank_params
.
keys
(
))
==
1
and
Type
.
fp16
.
value
in
self
.
dtype_rank_params
.
keys
():
self
.
_optim
.
_parameter_list
=
self
.
dtype_rank_params
[
Type
.
fp16
.
value
][
self
.
rank
]
else
:
self
.
_optim
.
_parameter_list
=
self
.
dtype_rank_params
[
Type
.
fp16
.
value
][
self
.
rank
]
+
self
.
dtype_rank_params
[
Type
.
fp32
.
value
][
self
.
rank
]
params_list
=
[]
for
dtype
in
self
.
dtype_rank_params
.
keys
():
params_list
.
extend
(
self
.
dtype_rank_params
[
dtype
][
self
.
rank
])
params_name_list
=
list
(
map
(
lambda
p
:
p
.
name
,
params_list
))
if
not
isinstance
(
self
.
_optim
.
_param_groups
[
0
],
dict
):
self
.
_optim
.
_parameter_list
=
params_list
self
.
_optim
.
_param_groups
=
params_list
else
:
for
param_group
in
self
.
_optim
.
_param_groups
:
p_group
=
[]
for
param
in
param_group
[
'params'
]:
if
param
.
name
in
params_name_list
:
p_group
.
append
(
params_list
[
params_name_list
.
index
(
param
.
name
)])
param_group
[
'params'
]
=
p_group
# Run the optimizer of the current rank step
if
self
.
offload
:
with
device_guard
(
self
.
rank
,
self
.
offload_device
):
with
device_guard
(
device
=
self
.
offload_device
):
self
.
_optim
.
step
()
for
param
in
self
.
_optim
.
_parameter_list
:
self
.
_master_params
[
param
.
name
].
set_value
(
param
)
dev_id
=
0
if
paddle
.
get_device
()
==
"cpu"
else
int
(
paddle
.
get_device
().
split
(
":"
)[
1
])
...
...
@@ -312,10 +311,11 @@ class ShardingOptimizerStage2(Optimizer):
self
.
_broadcast_params
()
# Return full parameters to optimizer parameters
self
.
_optim
.
_parameter_list
=
self
.
_local_params
self
.
_optim
.
_parameter_list
=
self
.
_ori_parameter_list
self
.
_optim
.
_param_groups
=
self
.
_ori_param_groups
def
clear_cache
(
self
):
self
.
_segment_params
.
clear
()
def
_
clear_cache
(
self
):
self
.
_
_
segment_params
.
clear
()
self
.
_dtype_rank_params
.
clear
()
self
.
_param2rank
.
clear
()
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
浏览文件 @
327e5050
...
...
@@ -24,10 +24,12 @@ import numpy as np
from
itertools
import
chain
from
functools
import
reduce
from
collections
import
deque
from
types
import
MethodType
import
paddle
from
paddle
import
nn
import
paddle.distributed
as
dist
from
paddle.distributed.collective
import
_get_global_group
from
...utils.internal_storage
import
GradStorage
from
...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2
import
ShardingOptimizerStage2
...
...
@@ -57,7 +59,7 @@ class ShardingStage2(nn.Layer):
self
,
layer
,
sharding_optimizer
,
group
,
group
=
None
,
sync_buffers
=
False
,
pertrain_sync_models
=
True
,
buffer_max_size
=
2
**
23
,
#8MB
...
...
@@ -83,13 +85,12 @@ class ShardingStage2(nn.Layer):
self
.
_accumulate_grads
=
accumulate_grads
# Communication related attributes
assert
group
is
not
None
,
"Distributed communication group is must be gived"
self
.
_group
=
group
self
.
_world_size_scaling
=
1.0
/
self
.
_group
.
nranks
assert
self
.
_group
.
nranks
>
1
,
"Training must be distributed, ranks must be greater than 1"
self
.
_rank
=
self
.
_group
.
rank
group
=
_get_global_group
()
if
group
is
None
else
group
self
.
_world_size_scaling
=
1.0
/
group
.
nranks
assert
group
.
nranks
>
1
,
"Training must be distributed, ranks must be greater than 1"
self
.
_rank
=
group
.
rank
self
.
_global_root_rank
=
0
# picking rank 0 as the reference
self
.
_global_ranks
=
self
.
_group
.
ranks
self
.
_default_device
=
device
# Global statistical parameters
...
...
@@ -112,8 +113,8 @@ class ShardingStage2(nn.Layer):
self
.
_has_grad_storage
=
[]
self
.
_grad_storage_list
=
[]
#
o
ffload
# TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy
#
O
ffload
# TODO(haohongxiang): Now it's not
be
supported for multi-optimizers using Offload strategy
self
.
_offload_optims
=
list
(
filter
(
lambda
optim
:
optim
.
offload
,
self
.
_sharding_optimizers
))
if
len
(
self
.
_offload_optims
)
>
0
:
...
...
@@ -134,6 +135,11 @@ class ShardingStage2(nn.Layer):
# Set tasks flow
self
.
_tasks_flow
=
deque
()
# Define optimizer step and clear_grad
if
self
.
_accumulate_grads
:
self
.
_redefine_opt_step
()
self
.
_redefine_opt_clear
()
def
forward
(
self
,
*
inputs
,
**
kwargs
):
"""
A wrapper for Sharding Stage2 layer.
...
...
@@ -161,7 +167,7 @@ class ShardingStage2(nn.Layer):
return
fw
def
clear_gradients
(
self
):
def
_
clear_gradients
(
self
):
"""
Set zero to the gradient of the optimizer's current rank trainable parameters.
"""
...
...
@@ -176,7 +182,7 @@ class ShardingStage2(nn.Layer):
if
param
.
name
in
self
.
_param_grads
and
param
.
grad
is
not
None
:
param
.
clear_gradient
()
def
grad_scale
(
self
):
def
_
grad_scale
(
self
):
"""
Before the gradient accumulation, scale the gradient.
"""
...
...
@@ -287,9 +293,6 @@ class ShardingStage2(nn.Layer):
for
grad_storage
in
self
.
_grad_storage_list
:
grad_storage
.
reset_checked_in
()
if
not
self
.
_accumulate_grads
:
self
.
_grads_flipped
=
False
def
_get_reduce_fn
(
self
,
index
,
param
,
dst_rank
):
"""
There are two ways to reduce gradient.
...
...
@@ -412,7 +415,6 @@ class ShardingStage2(nn.Layer):
self
.
_bw_hooks
.
pop
().
remove
()
# Go through the parameters, attach the hook
self
.
_grad_accs
=
[]
if
not
self
.
training
:
return
...
...
@@ -500,9 +502,6 @@ class ShardingStage2(nn.Layer):
# Whether parameters trainability changed
trainability_changed
=
trainable_mask
!=
self
.
_trainable_mask
# The whole model is not trainable but we still have grad hooks
trainability_changed
|=
not
self
.
training
and
len
(
self
.
_bw_hooks
)
>
0
if
trainability_changed
:
logging
.
warning
(
"Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
...
...
@@ -548,3 +547,25 @@ class ShardingStage2(nn.Layer):
format
(
rank_buffer_size
[
Type
.
fp32
.
value
]
/
2
**
18
,
model_size
/
2
**
18
))
return
rank_buffer_size
def
_redefine_opt_step
(
self
):
if
not
self
.
_accumulate_grads
:
return
grad_func
=
self
.
_grad_scale
for
opt
in
self
.
_sharding_optimizers
:
opt_step
=
opt
.
step
def
_opt_step
(
self
):
grad_func
()
opt_step
()
opt
.
step
=
MethodType
(
_opt_step
,
opt
)
def
_redefine_opt_clear
(
self
):
clear_func
=
self
.
_clear_gradients
def
_opt_clear
(
self
):
clear_func
()
for
opt
in
self
.
_sharding_optimizers
:
opt
.
clear_grad
=
MethodType
(
_opt_clear
,
opt
)
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
浏览文件 @
327e5050
...
...
@@ -131,7 +131,7 @@ class ShardingClipGrad:
@
contextlib
.
contextmanager
def
device_guard
(
dev_id
,
device
=
"cpu"
):
def
device_guard
(
dev_id
=
0
,
device
=
"cpu"
):
origin_device
=
paddle
.
device
.
get_device
()
if
device
==
"cpu"
:
paddle
.
set_device
(
device
)
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py
浏览文件 @
327e5050
...
...
@@ -125,7 +125,7 @@ def train_mlp():
oss_optimizer
.
step
()
# oss_optimizer clear cache
oss_optimizer
.
clear_cache
()
oss_optimizer
.
_
clear_cache
()
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
浏览文件 @
327e5050
...
...
@@ -30,7 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar
seed
=
2021
epoch
=
2
batch_size
=
32
linear_size
=
1000
0
linear_size
=
1000
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
hybrid_configs
=
{
...
...
@@ -46,7 +46,7 @@ paddle.seed(seed)
class
MLP
(
fluid
.
Layer
):
def
__init__
(
self
,
linear_size
=
1000
0
,
param_attr
=
None
,
bias_attr
=
None
):
def
__init__
(
self
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MLP
,
self
).
__init__
()
self
.
_linear1
=
Linear
(
linear_size
,
linear_size
)
...
...
@@ -60,7 +60,7 @@ class MLP(fluid.Layer):
return
y
def
reader_decorator
(
linear_size
=
1000
0
):
def
reader_decorator
(
linear_size
=
1000
):
def
__reader__
():
for
_
in
range
(
100
):
img
=
np
.
random
.
rand
(
linear_size
).
astype
(
'float32'
)
...
...
@@ -70,10 +70,12 @@ def reader_decorator(linear_size=10000):
return
__reader__
def
optimizer_setting
(
model
,
use_pure_fp16
):
def
optimizer_setting
(
model
,
use_pure_fp16
,
opt_group
=
False
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
model
.
parameters
(),
parameters
=
[{
"params"
:
model
.
parameters
()
}]
if
opt_group
else
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
grad_clip
=
clip
,
...
...
@@ -85,27 +87,32 @@ def optimizer_setting(model, use_pure_fp16):
def
train_mlp
(
model
,
sharding_stage
,
use_pure_fp16
=
False
,
a
ll_test
=
False
,
accumulate_grad
=
False
):
a
ccumulate_grad
=
False
,
opt_group
=
False
):
if
sharding_stage
==
"dp"
:
hcg
=
fleet
.
get_hybrid_communicate_group
()
group
=
hcg
.
get_check_parallel_group
()
else
:
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
)
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
)
if
opt_group
:
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
,
opt_group
=
opt_group
)
else
:
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
)
if
sharding_stage
==
2
:
optimizer
=
ShardingOptimizerStage2
(
params
=
model
.
parameters
(),
optim
=
optimizer
,
group
=
group
)
if
a
ll_test
:
if
a
ccumulate_grad
:
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
accumulate_grads
=
accumulate_grad
)
model
,
optimizer
,
group
=
group
,
buffer_max_size
=
2
**
21
,
accumulate_grads
=
accumulate_grad
)
else
:
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
)
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
buffer_max_size
=
2
**
21
)
else
:
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
model
=
fleet
.
distributed_model
(
model
)
...
...
@@ -132,29 +139,16 @@ def train_mlp(model,
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
enable
=
use_pure_fp16
,
level
=
'O2'
):
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
avg_loss
.
backward
()
if
accumulate_grad
and
batch_id
==
2
:
model
.
grad_scale
()
optimizer
.
step
()
model
.
clear_gradients
()
return
model
.
parameters
()
if
not
accumulate_grad
:
optimizer
.
step
()
if
sharding_stage
==
2
:
model
.
clear_gradients
()
else
:
optimizer
.
clear_grad
()
optimizer
.
step
()
optimizer
.
clear_grad
()
if
a
ll_test
and
batch_id
==
2
:
if
a
ccumulate_grad
and
batch_id
==
2
:
return
model
.
parameters
()
return
model
.
parameters
()
...
...
@@ -171,22 +165,19 @@ def test_dp_stage2():
mlp2
.
set_state_dict
(
state_dict
)
mlp3
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
dp_params
=
train_mlp
(
mlp1
,
sharding_stage
=
"dp"
,
use_pure_fp16
=
False
)
stage2_params
=
train_mlp
(
mlp2
,
sharding_stage
=
2
,
use_pure_fp16
=
False
)
dp_params
=
train_mlp
(
mlp1
,
sharding_stage
=
"dp"
,
use_pure_fp16
=
False
,
opt_group
=
True
)
stage2_params
=
train_mlp
(
mlp2
,
sharding_stage
=
2
,
use_pure_fp16
=
False
,
opt_group
=
True
)
for
i
in
range
(
len
(
dp_params
)):
for
j
in
range
(
len
(
stage2_params
)):
if
dp_params
[
i
].
name
==
stage2_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
dp_params
[
i
].
numpy
(),
stage2_params
[
j
].
numpy
(),
rtol
=
1e-6
)
stage2_params
=
train_mlp
(
mlp3
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
all_test
=
True
)
stage2_params
=
train_mlp
(
mlp3
,
sharding_stage
=
2
)
stage2_accumulate_grad
=
train_mlp
(
mlp4
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
all_test
=
True
,
accumulate_grad
=
True
)
mlp4
,
sharding_stage
=
2
,
accumulate_grad
=
True
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage2_accumulate_grad
)):
if
stage2_params
[
i
].
name
==
stage2_accumulate_grad
[
j
].
name
:
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py
浏览文件 @
327e5050
...
...
@@ -33,7 +33,7 @@ from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting
seed
=
2021
epoch
=
2
batch_size
=
32
linear_size
=
8
000
linear_size
=
1
000
np
.
random
.
seed
(
seed
)
paddle
.
seed
(
seed
)
...
...
@@ -52,7 +52,12 @@ def train_mlp(model, offload=False):
optim
=
optimizer
,
group
=
group
,
offload
=
offload
)
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
accumulate_grads
=
True
)
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
buffer_max_size
=
2
**
21
,
accumulate_grads
=
True
)
train_reader
=
paddle
.
batch
(
reader_decorator
(
linear_size
),
batch_size
=
batch_size
,
drop_last
=
True
)
...
...
@@ -81,10 +86,9 @@ def train_mlp(model, offload=False):
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
scaler
.
scale
(
avg_loss
).
backward
()
model
.
grad_scale
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
model
.
clear_gradients
()
optimizer
.
clear_grad
()
for
dtype
in
optimizer
.
param_storages
:
for
dst_rank
,
param_storage
in
optimizer
.
param_storages
[
dtype
].
items
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录