Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4c77a908
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
4c77a908
编写于
1月 14, 2022
作者:
B
Baibaifan
提交者:
GitHub
1月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add dygraph sharding stage3 (#38052)
上级
556d5097
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
960 addition
and
17 deletion
+960
-17
paddle/pten/core/dense_tensor.cc
paddle/pten/core/dense_tensor.cc
+4
-0
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
...stributed/fleet/meta_parallel/sharding/sharding_stage3.py
+675
-0
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
...istributed/fleet/meta_parallel/sharding/sharding_utils.py
+14
-17
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
...n/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+233
-0
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
...dle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
+31
-0
未找到文件。
paddle/pten/core/dense_tensor.cc
浏览文件 @
4c77a908
...
@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
...
@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
}
}
void
DenseTensor
::
ShareBufferWith
(
const
DenseTensor
&
tensor
)
{
void
DenseTensor
::
ShareBufferWith
(
const
DenseTensor
&
tensor
)
{
if
(
storage_
==
nullptr
)
{
storage_
=
make_intrusive
<
paddle
::
experimental
::
SharedStorage
>
(
paddle
::
platform
::
CPUPlace
());
}
if
(
storage_
!=
nullptr
&&
tensor
.
storage_
!=
nullptr
)
{
if
(
storage_
!=
nullptr
&&
tensor
.
storage_
!=
nullptr
)
{
storage_
->
set_data_shared
(
tensor
.
storage_
->
data_shared
());
storage_
->
set_data_shared
(
tensor
.
storage_
->
data_shared
());
}
}
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
0 → 100644
浏览文件 @
4c77a908
此差异已折叠。
点击以展开。
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
浏览文件 @
4c77a908
...
@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
...
@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
param_grads
=
[]
param_grads
=
[]
param_grads_fp16
=
[]
param_grads_fp16
=
[]
param_grads_fp32
=
[]
param_grads_fp32
=
[]
if
hasattr
(
optimizer
,
"update_slice"
):
optimizer
.
update_slice
()
optimizer
.
update_scaler
=
True
if
getattr
(
optimizer
.
_optim
,
'_param_groups'
,
None
)
and
isinstance
(
if
getattr
(
optimizer
.
_optim
,
'_param_groups'
,
None
)
and
isinstance
(
optimizer
.
_optim
.
_param_groups
[
0
],
dict
):
optimizer
.
_optim
.
_param_groups
[
0
],
dict
):
...
@@ -161,27 +164,21 @@ def ShardingScaler(scaler):
...
@@ -161,27 +164,21 @@ def ShardingScaler(scaler):
if
param
.
_grad_ivar
()
is
not
None
:
if
param
.
_grad_ivar
()
is
not
None
:
param_grads
.
append
(
param
.
_grad_ivar
())
param_grads
.
append
(
param
.
_grad_ivar
())
if
param
.
_grad_ivar
(
if
param
.
_grad_ivar
(
).
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
).
dtype
in
[
core
.
VarDesc
.
VarType
.
FP16
,
paddle
.
float16
]
:
param_grads_fp16
.
append
(
param
.
_grad_ivar
())
param_grads_fp16
.
append
(
param
.
_grad_ivar
())
else
:
else
:
param_grads_fp32
.
append
(
param
.
_grad_ivar
())
param_grads_fp32
.
append
(
param
.
_grad_ivar
())
else
:
else
:
param_grads
=
[
for
param
in
optimizer
.
_optim
.
_parameter_list
:
param
.
_grad_ivar
()
for
param
in
optimizer
.
_optim
.
_parameter_list
if
param
.
grad
is
not
None
:
if
param
.
_grad_ivar
()
is
not
None
param_grads
.
append
(
param
.
grad
)
]
if
param
.
grad
.
dtype
in
[
param_grads_fp16
=
[
core
.
VarDesc
.
VarType
.
FP16
,
paddle
.
float16
param
.
_grad_ivar
()
for
param
in
optimizer
.
_optim
.
_parameter_list
]:
if
(
param
.
_grad_ivar
()
is
not
None
param_grads_fp16
.
append
(
param
.
grad
)
)
and
(
param
.
_grad_ivar
().
dtype
==
core
.
VarDesc
.
VarType
.
FP16
else
:
)
param_grads_fp32
.
append
(
param
.
grad
)
]
param_grads_fp32
=
[
param
.
_grad_ivar
()
for
param
in
optimizer
.
_optim
.
_parameter_list
if
(
param
.
_grad_ivar
()
is
not
None
)
and
(
param
.
_grad_ivar
().
dtype
==
core
.
VarDesc
.
VarType
.
FP32
)
]
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool
))
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool
))
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
4c77a908
...
@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
...
@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_stage2
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_stage2
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_stage3
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_parallelizer
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_parallelizer
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers
)
list
(
APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper
)
list
(
APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper
)
...
@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
...
@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3
)
list
(
REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer
)
list
(
REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
...
@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
...
@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
0 → 100644
浏览文件 @
4c77a908
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
argparse
import
ast
import
time
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.distributed
import
fleet
from
paddle.fluid.dygraph
import
nn
from
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2
import
ShardingOptimizerStage2
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2
import
ShardingStage2
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3
import
ShardingStage3
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_utils
import
ShardingScaler
epoch
=
10
batch_size
=
32
paddle
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
base_lr
=
0.1
momentum_rate
=
0.9
l2_decay
=
1e-4
fleet
.
init
(
is_collective
=
True
)
class
MLP
(
fluid
.
Layer
):
def
__init__
(
self
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MLP
,
self
).
__init__
()
self
.
_linear1
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear2
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear3
=
Linear
(
linear_size
,
10
)
def
forward
(
self
,
inputs
):
y
=
self
.
_linear1
(
inputs
)
y
=
self
.
_linear2
(
y
)
y
=
self
.
_linear3
(
y
)
return
y
def
reader_decorator
(
linear_size
=
1000
):
def
__reader__
():
for
_
in
range
(
100
):
img
=
np
.
random
.
rand
(
linear_size
).
astype
(
'float32'
)
label
=
np
.
ones
(
1
).
astype
(
'int64'
)
yield
img
,
label
return
__reader__
def
optimizer_setting
(
model
,
use_pure_fp16
,
opt_group
=
False
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
[{
"params"
:
model
.
parameters
()
}]
if
opt_group
else
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
grad_clip
=
clip
,
multi_precision
=
use_pure_fp16
)
return
optimizer
def
train_mlp
(
model
,
sharding_stage
,
use_pure_fp16
=
False
,
accumulate_grad
=
False
,
opt_group
=
False
,
recompute
=
False
):
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
if
opt_group
:
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
,
opt_group
=
opt_group
)
else
:
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
)
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
ShardingScaler
(
scaler
)
if
sharding_stage
==
2
:
optimizer
=
ShardingOptimizerStage2
(
params
=
model
.
parameters
(),
optim
=
optimizer
,
group
=
group
)
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
buffer_max_size
=
2
**
21
,
accumulate_grads
=
accumulate_grad
)
elif
sharding_stage
==
3
:
model
=
ShardingStage3
(
model
,
optimizer
=
optimizer
,
group
=
group
,
sync_comm
=
recompute
)
train_reader
=
paddle
.
batch
(
reader_decorator
(),
batch_size
=
batch_size
,
drop_last
=
True
)
train_loader
=
paddle
.
io
.
DataLoader
.
from_generator
(
capacity
=
32
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
train_loader
.
set_sample_list_generator
(
train_reader
)
for
eop
in
range
(
epoch
):
model
.
train
()
for
batch_id
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
True
,
level
=
'O2'
):
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
if
not
accumulate_grad
:
if
not
use_pure_fp16
:
avg_loss
.
backward
()
optimizer
.
step
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
if
accumulate_grad
:
if
not
use_pure_fp16
:
avg_loss
.
backward
()
optimizer
.
step
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
if
sharding_stage
==
3
:
model
.
get_all_parameters
()
return
model
.
parameters
()
def
test_stage2_stage3
():
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
=
MLP
(),
MLP
(),
MLP
(
),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
()
state_dict
=
mlp
.
state_dict
()
mlp1
.
set_state_dict
(
state_dict
)
mlp2
.
set_state_dict
(
state_dict
)
mlp3
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
mlp7
.
set_state_dict
(
state_dict
)
mlp8
.
set_state_dict
(
state_dict
)
# fp32
stage2_params
=
train_mlp
(
mlp1
,
sharding_stage
=
2
,
use_pure_fp16
=
False
,
opt_group
=
True
)
stage3_params
=
train_mlp
(
mlp2
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
opt_group
=
True
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
# fp32 accumulate grad
stage2_params
=
train_mlp
(
mlp3
,
sharding_stage
=
2
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
opt_group
=
True
)
stage3_params
=
train_mlp
(
mlp4
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
opt_group
=
True
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
# fp16
stage2_params
=
train_mlp
(
mlp5
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
opt_group
=
False
)
stage3_params
=
train_mlp
(
mlp6
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
# fp16 recompute
stage3_params
=
train_mlp
(
mlp7
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
)
stage3_params_re
=
train_mlp
(
mlp8
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
,
recompute
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
for
j
in
range
(
len
(
stage3_params_re
)):
if
stage3_params
[
i
].
name
==
stage3_params_re
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
j
].
numpy
(),
rtol
=
1e-6
)
return
if
__name__
==
'__main__'
:
test_stage2_stage3
()
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
0 → 100644
浏览文件 @
4c77a908
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestDygraphShardingStage3
(
TestMultipleGpus
):
# check sharding logic as well as the accuracy with single mode
def
test_dygraph_sharding_optimizer_stage3
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage3.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录