Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
20e19776
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
20e19776
编写于
12月 02, 2021
作者:
B
Baibaifan
提交者:
GitHub
12月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add dygraph sharding stage2 (#37707)
上级
29ebf621
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
743 addition
and
6 deletion
+743
-6
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
...optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+0
-6
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
...stributed/fleet/meta_parallel/sharding/sharding_stage2.py
+505
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
...n/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
+204
-0
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
...dle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
+31
-0
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
浏览文件 @
20e19776
...
...
@@ -68,7 +68,6 @@ class ShardingOptimizerStage2(Optimizer):
broadcast_fp16
=
False
,
offload
=
False
,
device
=
"gpu"
,
accumulation_steps
=
None
,
**
kw
):
super
().
__init__
(
optim
.
_learning_rate
,
params
,
kw
)
...
...
@@ -86,7 +85,6 @@ class ShardingOptimizerStage2(Optimizer):
self
.
_optim
=
optim
self
.
_local_params
=
params
self
.
_default_device
=
device
self
.
_accumulation_steps
=
accumulation_steps
assert
group
is
not
None
,
"Distributed communication group is must be gived"
self
.
group
=
group
...
...
@@ -136,10 +134,6 @@ class ShardingOptimizerStage2(Optimizer):
def
local_params
(
self
):
return
self
.
_local_params
@
property
def
accumulation_steps
(
self
):
return
self
.
_accumulation_steps
@
property
def
param2rank
(
self
):
"""Map the params to the rank which owns them"""
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
0 → 100644
浏览文件 @
20e19776
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Taken and modified for fairscale from:
# https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/sharded_ddp.py
#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
import
os
import
contextlib
import
logging
import
time
import
functools
import
numpy
as
np
from
itertools
import
chain
from
functools
import
reduce
from
collections
import
deque
import
paddle
from
paddle
import
nn
import
paddle.distributed
as
dist
from
...utils.internal_storage
import
GradStorage
from
.sharding_utils
import
Taskflow
,
Type
def
_trainable
(
param
):
return
param
.
trainable
class
ShardingStage2
(
nn
.
Layer
):
"""
A wrapper for Sharding Stage2 Layer in Dygraph.
.. warning: ShardingStage2 encapsulates the layer strategy and integrates it into the nn.Layer.
.. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
"""
# TODO (Baibaifan)
# Feature Notes::
# 1. Unified memory for param and param.grad to InternalStorage.
# 2. Divide param.grad according to rank to centrally apply for and release GPU memory.
# 3. Dynamically adjust training parameters and models。
# 4. Support offload function.
# 5. Support the establishment of independent communication groups.
def
__init__
(
self
,
layer
,
sharding_optimizer
,
group
,
sync_buffers
=
False
,
pertrain_sync_models
=
True
,
buffer_max_size
=
2
**
23
,
#8MB
auto_refresh_trainable
=
True
,
device
=
"gpu"
,
use_grad_storage
=
True
,
accumulate_grads
=
False
):
super
().
__init__
()
# training options
self
.
_layer
=
layer
self
.
_sharding_optimizers
=
[
sharding_optimizer
]
if
not
isinstance
(
sharding_optimizer
,
list
)
else
sharding_optimizer
self
.
_sync_buffers
=
sync_buffers
self
.
_auto_refresh_trainable
=
auto_refresh_trainable
# Gradient accumulation, Gradient flip
self
.
_accumulate_grads
=
accumulate_grads
# Communication related attributes
assert
group
is
not
None
,
"Distributed communication group is must be gived"
self
.
_group
=
group
self
.
_world_size_scaling
=
1.0
/
self
.
_group
.
nranks
assert
self
.
_group
.
nranks
>
1
,
"Training must be distributed, ranks must be greater than 1"
self
.
_rank
=
self
.
_group
.
rank
self
.
_global_root_rank
=
0
# picking rank 0 as the reference
self
.
_global_ranks
=
self
.
_group
.
ranks
self
.
_default_device
=
device
# Global statistical parameters
self
.
_all_params
=
list
(
chain
(
*
[
optim
.
local_params
for
optim
in
self
.
_sharding_optimizers
]))
self
.
_trainable_params
=
[]
self
.
_grad_reduced
=
[]
self
.
_trainable_param2rank
=
{}
self
.
_trainable_param2align
=
{}
self
.
_trainable_mask
=
list
(
map
(
_trainable
,
self
.
_all_params
))
self
.
_param_grads
=
[]
# Set grad storage size & Display param sizes and model sizes
model_size
=
sum
(
[
np
.
prod
(
p
.
shape
)
for
p
in
self
.
_layer
.
parameters
()]).
item
()
self
.
_buffer_max_size
=
self
.
_rank_buffer_size
(
buffer_max_size
,
model_size
)
self
.
_use_grad_storage
=
use_grad_storage
self
.
_grad_storages
=
{}
# {dtype: {rank: GradStorage}}
self
.
_has_grad_storage
=
[]
self
.
_grad_storage_list
=
[]
# Set backward pass hooks
self
.
_bw_hooks
=
[]
# Synchronous all ranks models
if
pertrain_sync_models
:
self
.
_sync_params_and_buffers
()
# Set tasks flow
self
.
_tasks_flow
=
deque
()
def
forward
(
self
,
*
inputs
,
**
kwargs
):
"""
A wrapper for Sharding Stage2 layer.
- Fresh trainable params or rebuild grad storage
- Sync layer's buffer params
- Clear all flags states
- Forward for origin layers
"""
# Whether to need to reset trainable parameters
needs_fresh
=
len
(
self
.
_bw_hooks
)
==
0
and
self
.
training
if
self
.
_auto_refresh_trainable
:
needs_fresh
|=
self
.
_detect_train_change
()
# Front hook
self
.
_init_internal_storage
(
needs_fresh
)
# Sync layer's buffers state
if
self
.
_sync_buffers
:
self
.
__sync_buffers
()
# Normal FW on the base model
fw
=
self
.
_layer
(
*
inputs
,
**
kwargs
)
return
fw
def
clear_gradients
(
self
):
"""
Set zero to the gradient of the optimizer's current rank trainable parameters.
"""
# Release grad storages
for
dtype
in
self
.
_grad_storages
.
keys
():
if
self
.
_rank
in
self
.
_grad_storages
[
dtype
].
keys
():
self
.
_grad_storages
[
dtype
][
self
.
_rank
].
buffer
.
zero_
()
# Release params
for
param
in
self
.
_trainable_params
:
if
param
.
name
in
self
.
_param_grads
and
param
.
grad
is
not
None
:
param
.
clear_gradient
()
def
grad_scale
(
self
):
"""
Before the gradient accumulation, scale the gradient.
"""
# Scale grad storages
for
dtype
in
self
.
_grad_storages
.
keys
():
if
self
.
_rank
in
self
.
_grad_storages
[
dtype
].
keys
():
self
.
_grad_storages
[
dtype
][
self
.
_rank
].
buffer
.
scale_
(
scale
=
self
.
_world_size_scaling
)
# Scale params
for
param
in
self
.
_trainable_params
:
if
param
.
name
in
self
.
_param_grads
and
param
.
grad
is
not
None
:
param
.
grad
.
scale_
(
scale
=
self
.
_world_size_scaling
)
param
.
_reset_grad_inplace_version
()
def
_init_internal_storage
(
self
,
needs_fresh
):
"""
Judge Fresh trainable params or rebuild grad storage.
"""
if
needs_fresh
:
self
.
_fresh_trainable
()
else
:
self
.
_build_grad_storages
()
# Clear all flags state
self
.
_clear_counters
()
def
to
(
self
,
device
=
None
,
dtype
=
None
,
blocking
=
True
):
"""
Synchronously or asynchronously convert the data type of the layer, the device is not supported now.
"""
assert
device
==
self
.
_default_device
,
"New devices are not supported, because of the optimizer state is not sync"
def
_fresh_trainable
(
self
):
""" Whether to update training parameters. """
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
if
reduce
(
lambda
x
,
y
:
x
or
y
,
self
.
_grad_reduced
,
False
):
logging
.
warning
(
"Grads waiting to be reduced."
)
self
.
_trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
self
.
_all_params
))
self
.
_trainable_params
.
sort
(
key
=
lambda
x
:
np
.
prod
(
x
.
shape
))
self
.
_trainable_param2rank
=
{}
for
optim
in
self
.
_sharding_optimizers
:
# Need to be wrappered for Sharding Stage2 Optimizer
if
len
(
optim
.
param_storages
.
keys
())
==
0
:
optim
.
update_opt_status
()
# Get the parameters split by the optimizer according to rank
for
per_rank_params
in
optim
.
dtype_rank_params
.
values
(
):
# all the params from all ranks
for
params
in
per_rank_params
:
for
param
in
filter
(
lambda
x
:
x
.
trainable
,
params
):
self
.
_trainable_param2rank
[
param
.
name
]
=
optim
.
param2rank
[
param
.
name
]
self
.
_trainable_param2align
[
param
.
name
]
=
optim
.
_param2align
[
param
.
name
]
self
.
_setup_use_grad_storage
()
# wait next func hook support
self
.
_setup_backward_hooks
()
@
paddle
.
no_grad
()
def
__sync_buffers
(
self
):
"""
Sync all the param buffers from all ranks (exp: batch norm statistics).
"""
for
buffer
in
self
.
_layer
.
buffers
(
include_sublayers
=
True
):
dist
.
broadcast
(
buffer
,
self
.
_global_root_rank
,
self
.
_group
,
use_calc_stream
=
True
)
# Multi stream operation will be supported later
dist
.
wait
(
tensor
=
buffer
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
def
__getattr__
(
self
,
name
):
"""Forward missing attributes to wrapped layer."""
try
:
return
super
().
__getattr__
(
name
)
except
AttributeError
:
return
getattr
(
self
.
_layer
,
name
)
@
paddle
.
no_grad
()
def
_clear_counters
(
self
):
"""Reset all the grad reduce and call counters."""
if
self
.
training
:
self
.
_grad_reduced
=
[
True
for
_
in
self
.
_trainable_params
]
if
self
.
_use_grad_storage
:
for
grad_storage
in
self
.
_grad_storage_list
:
grad_storage
.
reset_checked_in
()
if
not
self
.
_accumulate_grads
:
self
.
_grads_flipped
=
False
def
_get_reduce_fn
(
self
,
index
,
param
,
dst_rank
):
"""
There are two ways to reduce gradient.
- 1. Do not use use_grad_storage or exceeded buffer_max_size will be reduced separately.
- 2. Use grad_storage Reduce the storage to get the full gradient from different ranks.
"""
if
not
self
.
_use_grad_storage
or
not
self
.
_has_grad_storage
[
index
]:
# Direct reduction
@
paddle
.
no_grad
()
def
reduce
(
*
_
):
# Skip gradient reduction, do not change status information
if
self
.
_grad_reduced
[
index
]:
assert
param
.
grad
is
not
None
,
"Parameter gradient cannot be None"
# Change reduce information
self
.
_grad_reduced
[
index
]
=
False
if
not
self
.
_accumulate_grads
:
param
.
grad
.
scale_
(
scale
=
self
.
_world_size_scaling
)
param
.
_reset_grad_inplace_version
()
# Clear the gradient that does not belong to the current rank through the callback function
def
cleanup
():
if
dst_rank
!=
self
.
_rank
:
param
.
clear_gradient
(
False
)
# Synchronize the reduce parameter gradient
self
.
_tasks_flow
.
append
(
Taskflow
(
task
=
dist
.
reduce
(
tensor
=
param
.
grad
,
dst
=
dst_rank
,
group
=
self
.
_group
,
use_calc_stream
=
True
),
callback
=
cleanup
))
# Multi stream operation will be supported later
dist
.
wait
(
tensor
=
param
.
grad
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
# Clear the task flow and trigger callback to clear the redundant gradient
self
.
_clear_task_flow
()
else
:
# Buffer reduction
@
paddle
.
no_grad
()
def
reduce
(
*
_
):
# Skip gradient reduction, do not change status information
if
self
.
_grad_reduced
[
index
]:
assert
param
.
grad
is
not
None
,
"Parameter gradient cannot be None"
# Change reduce information
self
.
_grad_reduced
[
index
]
=
False
grad_storage
=
self
.
_grad_storages
[
param
.
dtype
][
dst_rank
]
grad_storage
.
params_checked_in
+=
1
if
grad_storage
.
all_checked_in
:
assert
grad_storage
.
buffer
is
not
None
# Normalize all ranks grad_storage
if
not
self
.
_accumulate_grads
:
grad_storage
.
buffer
.
scale_
(
scale
=
self
.
_world_size_scaling
)
# Clearing up the grad_storage buffer
def
cleanup
():
if
dst_rank
!=
self
.
_rank
:
for
p
in
grad_storage
.
_params
:
p
.
clear_gradient
(
False
)
p
.
_gradient_set_empty
(
False
)
grad_storage
.
buffer
.
value
().
get_tensor
().
_clear
(
)
# Reduce the bucket
grad_storage
.
sent
=
True
self
.
_tasks_flow
.
append
(
Taskflow
(
task
=
dist
.
reduce
(
tensor
=
grad_storage
.
buffer
,
dst
=
grad_storage
.
destination
,
group
=
self
.
_group
,
use_calc_stream
=
True
),
callback
=
cleanup
))
# Multi stream operation will be supported later
dist
.
wait
(
tensor
=
grad_storage
.
buffer
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
# Clear the task flow and trigger callback to clear the redundant gradient
self
.
_clear_task_flow
()
return
reduce
def
_setup_backward_hooks
(
self
):
"""
Set the backward hook to synchronize the gradients of all rank by reduce group ranks.
"""
# Remove previous backward hooks
while
len
(
self
.
_bw_hooks
)
>
0
:
self
.
_bw_hooks
.
pop
().
remove
()
# Go through the parameters, attach the hook
self
.
_grad_accs
=
[]
if
not
self
.
training
:
return
for
index
,
param
in
enumerate
(
self
.
_trainable_params
):
dst_rank
=
self
.
_trainable_param2rank
[
param
.
name
]
reduce_function
=
self
.
_get_reduce_fn
(
index
,
param
,
dst_rank
)
self
.
_bw_hooks
.
append
(
param
.
_register_backward_hook
(
reduce_function
))
@
paddle
.
no_grad
()
def
_sync_params_and_buffers
(
self
):
"""
Sync all model states for all ranks
"""
for
t
in
self
.
_layer
.
parameters
():
dist
.
broadcast
(
t
,
src
=
self
.
_global_root_rank
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
# Multi stream operation will be supported later
dist
.
wait
(
tensor
=
t
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
def
_setup_use_grad_storage
(
self
):
"""
Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
"""
if
not
self
.
_use_grad_storage
:
return
# According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank
self
.
_grad_storages
=
{}
self
.
_has_grad_storage
=
[
False
for
_
in
self
.
_trainable_params
]
for
index
,
param
in
enumerate
(
self
.
_trainable_params
):
dst_rank
=
self
.
_trainable_param2rank
[
param
.
name
]
if
param
.
dtype
not
in
self
.
_grad_storages
.
keys
():
self
.
_grad_storages
[
param
.
dtype
]
=
{}
if
dst_rank
not
in
self
.
_grad_storages
[
param
.
dtype
].
keys
():
self
.
_grad_storages
[
param
.
dtype
][
dst_rank
]
=
GradStorage
(
self
.
_buffer_max_size
[
param
.
dtype
],
dtype
=
param
.
dtype
,
device
=
self
.
_default_device
,
destination
=
dst_rank
,
parm2align
=
self
.
_trainable_param2align
)
# Criteria to decide whether this parameter is to be put in GradStorage
if
self
.
_grad_storages
[
param
.
dtype
][
dst_rank
].
can_add_grad_view
(
param
,
self
.
_trainable_param2align
[
param
.
name
]):
self
.
_grad_storages
[
param
.
dtype
][
dst_rank
].
add_grad
(
param
,
self
.
_trainable_param2align
[
param
.
name
])
self
.
_has_grad_storage
[
index
]
=
True
else
:
self
.
_param_grads
.
append
(
param
.
name
)
print
(
"Can not add param: {}, param's shape: {}, param align: {}, grad_storages fill: {}, "
.
format
(
param
.
name
,
param
.
shape
,
self
.
_trainable_param2align
[
param
.
name
],
self
.
_grad_storages
[
param
.
dtype
][
dst_rank
]
.
_fill
))
self
.
_grad_storage_list
=
list
(
chain
(
*
[
self
.
_grad_storages
[
dtype
].
values
()
for
dtype
in
self
.
_grad_storages
.
keys
()
]))
def
_clear_task_flow
(
self
):
"""Try to consume the previous tasks."""
while
len
(
self
.
_tasks_flow
)
>
0
:
task
=
self
.
_tasks_flow
.
popleft
()
if
task
.
callback
is
not
None
:
task
.
callback
()
def
_detect_train_change
(
self
):
# Current trainable parameters
trainable_mask
=
list
(
map
(
_trainable
,
self
.
_all_params
))
# Whether parameters trainability changed
trainability_changed
=
trainable_mask
!=
self
.
_trainable_mask
# The whole model is not trainable but we still have grad hooks
trainability_changed
|=
not
self
.
training
and
len
(
self
.
_bw_hooks
)
>
0
if
trainability_changed
:
logging
.
warning
(
"Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
)
self
.
_trainable_mask
=
trainable_mask
return
trainability_changed
def
_build_grad_storages
(
self
):
"""
Rebuild grad storages.
"""
# Rebuild fp16/fp32 grad storages
for
dtype
in
self
.
_grad_storages
.
keys
():
for
dst_rank
,
grad_storage
in
self
.
_grad_storages
[
dtype
].
items
():
if
dst_rank
!=
self
.
_rank
:
grad_storage
.
manumal_relase
()
grad_storage
.
rebuild
()
def
_rank_buffer_size
(
self
,
buffer_max_size
,
model_size
):
"""
Generate the minimum buffer size for each rank & Display param sizes and model sizes.
"""
# Initialize buffer size
rank_buffer_size
=
{}
for
shard_opt
in
self
.
_sharding_optimizers
:
if
shard_opt
.
rank_buffer_size
:
for
dtype
in
shard_opt
.
rank_buffer_size
.
keys
():
sizes
=
max
(
shard_opt
.
rank_buffer_size
[
dtype
].
values
())
rank_buffer_size
[
dtype
]
=
min
(
sizes
,
buffer_max_size
)
if
Type
.
fp16
.
value
in
rank_buffer_size
.
keys
():
# FP16 GradStorage and model size
print
(
"====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
.
format
(
rank_buffer_size
[
Type
.
fp16
.
value
]
/
2
**
19
,
model_size
/
2
**
19
))
if
Type
.
fp32
.
value
in
rank_buffer_size
.
keys
():
# FP32 GradStorage and model size
print
(
"====== FP32 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
.
format
(
rank_buffer_size
[
Type
.
fp32
.
value
]
/
2
**
18
,
model_size
/
2
**
18
))
return
rank_buffer_size
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
20e19776
...
...
@@ -33,6 +33,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
APPEND DIST_TEST_OPS test_dygraph_sharding_stage2
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_parallelizer
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers
)
list
(
APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper
)
...
...
@@ -244,6 +245,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2
)
list
(
REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
...
...
@@ -1039,6 +1041,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
0 → 100644
浏览文件 @
20e19776
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
argparse
import
ast
import
time
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.distributed
import
fleet
from
paddle.fluid.dygraph
import
nn
from
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer
import
DygraphShardingOptimizer
from
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2
import
ShardingOptimizerStage2
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2
import
ShardingStage2
seed
=
2021
epoch
=
2
batch_size
=
32
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
hybrid_configs
=
{
"dp_degree"
:
2
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
"sharding_degree"
:
1
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
np
.
random
.
seed
(
seed
)
paddle
.
seed
(
seed
)
class
MLP
(
fluid
.
Layer
):
def
__init__
(
self
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MLP
,
self
).
__init__
()
self
.
_linear1
=
Linear
(
10000
,
10000
)
self
.
_linear2
=
Linear
(
10000
,
10000
)
self
.
_linear3
=
Linear
(
10000
,
10
)
def
forward
(
self
,
inputs
):
y
=
self
.
_linear1
(
inputs
)
y
=
self
.
_linear2
(
y
)
y
=
self
.
_linear3
(
y
)
return
y
def
reader_decorator
():
def
__reader__
():
for
_
in
range
(
100
):
img
=
np
.
random
.
rand
(
10000
).
astype
(
'float32'
)
label
=
np
.
ones
(
1
).
astype
(
'int64'
)
yield
img
,
label
return
__reader__
def
optimizer_setting
(
model
,
use_pure_fp16
,
stage
=
1
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
grad_clip
=
clip
,
multi_precision
=
use_pure_fp16
)
return
optimizer
def
train_mlp
(
model
,
sharding_stage
,
use_pure_fp16
=
False
,
all_test
=
False
,
accumulate_grad
=
False
):
if
sharding_stage
==
1
:
hcg
=
fleet
.
get_hybrid_communicate_group
()
group
=
hcg
.
get_check_parallel_group
()
else
:
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
,
stage
=
sharding_stage
)
if
use_pure_fp16
:
model
,
optimizer
=
paddle
.
amp
.
decorate
(
models
=
model
,
optimizers
=
optimizer
,
level
=
'O2'
,
save_dtype
=
'float32'
)
if
sharding_stage
==
2
:
optimizer
=
ShardingOptimizerStage2
(
params
=
model
.
parameters
(),
optim
=
optimizer
,
group
=
group
)
if
all_test
:
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
,
accumulate_grads
=
accumulate_grad
)
else
:
model
=
ShardingStage2
(
model
,
optimizer
,
group
=
group
)
else
:
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
model
=
fleet
.
distributed_model
(
model
)
train_reader
=
paddle
.
batch
(
reader_decorator
(),
batch_size
=
batch_size
,
drop_last
=
True
)
train_loader
=
paddle
.
io
.
DataLoader
.
from_generator
(
capacity
=
32
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
train_loader
.
set_sample_list_generator
(
train_reader
)
for
eop
in
range
(
epoch
):
model
.
train
()
for
batch_id
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
enable
=
use_pure_fp16
,
level
=
'O2'
):
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
avg_loss
.
backward
()
if
accumulate_grad
and
batch_id
==
2
:
model
.
grad_scale
()
optimizer
.
step
()
model
.
clear_gradients
()
return
model
.
parameters
()
if
not
accumulate_grad
:
optimizer
.
step
()
if
sharding_stage
==
2
:
model
.
clear_gradients
()
else
:
optimizer
.
clear_grad
()
if
all_test
and
batch_id
==
2
:
return
model
.
parameters
()
if
sharding_stage
==
2
:
model
.
to
(
device
=
"gpu"
)
return
model
.
parameters
()
def
test_stage1_stage2
():
mlp
=
MLP
()
state_dict
=
mlp
.
state_dict
()
mlp1
=
MLP
()
mlp2
=
MLP
()
mlp3
=
MLP
()
mlp4
=
MLP
()
mlp1
.
set_state_dict
(
state_dict
)
mlp2
.
set_state_dict
(
state_dict
)
mlp3
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
stage1_params
=
train_mlp
(
mlp
,
sharding_stage
=
1
,
use_pure_fp16
=
False
)
stage2_params
=
train_mlp
(
mlp
,
sharding_stage
=
2
,
use_pure_fp16
=
False
)
for
i
in
range
(
len
(
stage1_params
)):
np
.
testing
.
assert_allclose
(
stage1_params
[
i
].
numpy
(),
stage2_params
[
i
].
numpy
(),
rtol
=
1e-6
)
stage2_params
=
train_mlp
(
mlp3
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
all_test
=
True
)
stage2_accumulate_grad
=
train_mlp
(
mlp4
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
all_test
=
True
,
accumulate_grad
=
True
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage2_accumulate_grad
)):
if
stage2_params
[
i
].
name
==
stage2_accumulate_grad
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage2_accumulate_grad
[
j
].
numpy
(),
rtol
=
1e-6
)
return
if
__name__
==
'__main__'
:
test_stage1_stage2
()
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
0 → 100644
浏览文件 @
20e19776
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestDygraphShardingStage2
(
TestMultipleGpus
):
# check sharding logic as well as the accuracy with single mode
def
test_dygraph_sharding_optimizer_stage2
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage2.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录