Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f33f2444
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f33f2444
编写于
7月 01, 2021
作者:
J
JZ-LIANG
提交者:
GitHub
7月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Dygraph/sharding (#33633)
* dygraph sharding * update unitest hybrid_parallel_communicate_group
上级
3e82a794
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
777 addition
and
21 deletion
+777
-21
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+14
-4
python/paddle/distributed/fleet/base/topology.py
python/paddle/distributed/fleet/base/topology.py
+43
-12
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
...ptimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+198
-0
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
...optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+16
-4
python/paddle/distributed/fleet/meta_parallel/__init__.py
python/paddle/distributed/fleet/meta_parallel/__init__.py
+1
-0
python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py
...ddle/distributed/fleet/meta_parallel/sharding_parallel.py
+33
-0
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
...on/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+43
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py
...luid/tests/unittests/hybrid_parallel_communicate_group.py
+2
-1
python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py
...e/fluid/tests/unittests/hybrid_parallel_sharding_model.py
+297
-0
python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py
...le/fluid/tests/unittests/test_hybrid_parallel_topology.py
+93
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
...uid/tests/unittests/test_parallel_dygraph_dataparallel.py
+2
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py
...ests/unittests/test_parallel_dygraph_sharding_parallel.py
+31
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
f33f2444
...
...
@@ -47,6 +47,7 @@ message HybridConfig {
optional
int32
dp_degree
=
1
[
default
=
-
1
];
optional
int32
mp_degree
=
2
[
default
=
1
];
optional
int32
pp_degree
=
3
[
default
=
1
];
optional
int32
sharding_degree
=
4
[
default
=
1
];
}
message
AMPConfig
{
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
f33f2444
...
...
@@ -30,7 +30,7 @@ from paddle.fluid.dygraph import parallel_helper
from
.
import
topology
as
tp
from
.topology
import
ParallelMode
from
..meta_parallel
import
TensorParallel
,
model_parallel_random_seed
from
..meta_parallel
import
PipelineParallel
from
..meta_parallel
import
PipelineParallel
,
ShardingParallel
from
..meta_optimizers
import
HybridParallelOptimizer
from
..meta_optimizers
import
HybridParallelGradScaler
...
...
@@ -295,9 +295,11 @@ class Fleet(object):
self
.
dp_degree
=
self
.
hybrid_configs
[
"dp_degree"
]
self
.
mp_degree
=
self
.
hybrid_configs
[
"mp_degree"
]
self
.
pp_degree
=
self
.
hybrid_configs
[
"pp_degree"
]
self
.
sharding_degree
=
self
.
hybrid_configs
[
"sharding_degree"
]
assert
self
.
mp_degree
>=
0
,
"mp_degree should be greater or equal to 0"
assert
self
.
pp_degree
>=
0
,
"pp_degree should be greater or equal to 0"
assert
self
.
sharding_degree
>=
0
,
"sharding_degree should be greater or equal to 0"
self
.
mp_degree
=
max
(
self
.
mp_degree
,
1
)
self
.
pp_degree
=
max
(
self
.
pp_degree
,
1
)
...
...
@@ -309,8 +311,11 @@ class Fleet(object):
self
.
dp_degree
=
max
(
self
.
dp_degree
,
1
)
self
.
_topology
=
tp
.
CommunicateTopology
(
hybrid_group_names
=
[
"data"
,
"pipe"
,
"model"
],
dims
=
[
self
.
dp_degree
,
self
.
pp_degree
,
self
.
mp_degree
])
hybrid_group_names
=
[
"data"
,
"pipe"
,
"sharding"
,
"model"
],
dims
=
[
self
.
dp_degree
,
self
.
pp_degree
,
self
.
sharding_degree
,
self
.
mp_degree
])
self
.
_hcg
=
tp
.
HybridCommunicateGroup
(
self
.
_topology
)
...
...
@@ -886,7 +891,11 @@ class Fleet(object):
assert
model
is
not
None
,
"model should not be None"
if
self
.
worker_num
()
<=
1
:
return
model
if
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
DATA_PARALLEL
:
if
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
SHARDING_PARALLEL
:
distributed_model
=
ShardingParallel
(
model
,
self
.
_hcg
,
strategy
=
self
.
_user_defined_strategy
)
elif
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
DATA_PARALLEL
:
distributed_model
=
paddle
.
DataParallel
(
model
,
comm_buffer_size
=
self
.
_user_defined_strategy
.
...
...
@@ -901,6 +910,7 @@ class Fleet(object):
elif
self
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
PIPELINE_PARALLEL
:
distributed_model
=
PipelineParallel
(
model
,
self
.
_hcg
,
strategy
=
self
.
_user_defined_strategy
)
return
distributed_model
@
dygraph_only
...
...
python/paddle/distributed/fleet/base/topology.py
浏览文件 @
f33f2444
...
...
@@ -30,12 +30,13 @@ class ParallelMode(object):
DATA_PARALLEL
=
0
TENSOR_PARALLEL
=
1
PIPELINE_PARALLEL
=
2
SHARDING_PARALLEL
=
3
class
CommunicateTopology
(
object
):
def
__init__
(
self
,
hybrid_group_names
=
[
"data"
,
"pipe"
,
"model"
],
dims
=
[
1
,
1
,
1
]):
hybrid_group_names
=
[
"data"
,
"pipe"
,
"
sharding"
,
"
model"
],
dims
=
[
1
,
1
,
1
,
1
]):
self
.
_parallel_names
=
hybrid_group_names
self
.
_dims
=
dims
self
.
coordinate
=
collections
.
namedtuple
(
'Coordinate'
,
...
...
@@ -122,15 +123,17 @@ class HybridCommunicateGroup(object):
self
.
_dp_degree
=
self
.
_topo
.
get_dim
(
'data'
)
self
.
_mp_degree
=
self
.
_topo
.
get_dim
(
'model'
)
self
.
_pp_degree
=
self
.
_topo
.
get_dim
(
'pipe'
)
self
.
_sharding_degree
=
self
.
_topo
.
get_dim
(
'sharding'
)
self
.
_data_parallel_id
=
self
.
_get_data_parallel_id
()
self
.
_model_parallel_id
=
self
.
_get_model_parallel_id
()
self
.
_sharding_parallel_id
=
self
.
_get_sharding_parallel_id
()
self
.
stage_id
=
self
.
_get_pipe_parallel_id
()
assert
self
.
_check_vaild_topo
(
),
"Here is an unreasonable topogy setting. world_size: {}, but"
\
"
dp_num: {}, mp_num: {}, pp_num: {}"
.
format
(
self
.
nranks
,
self
.
_dp_degree
,
self
.
_mp_degree
,
self
.
_
p
p_degree
)
"
mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}"
.
format
(
self
.
nranks
,
self
.
_mp_degree
,
self
.
_
sharding_degree
,
self
.
_pp_degree
,
self
.
_d
p_degree
)
# create comm group for data parallel
self
.
_dp_group
,
self
.
_dp_comm_group
=
self
.
_set_comm_group
(
"data"
)
...
...
@@ -141,6 +144,10 @@ class HybridCommunicateGroup(object):
# create comm group for pipe parallel
self
.
_pp_group
,
self
.
_pp_comm_group
=
self
.
_set_comm_group
(
"pipe"
)
# create comm group for sharding parallel
self
.
_sharding_group
,
self
.
_sharding_comm_group
=
self
.
_set_comm_group
(
"sharding"
)
# create global group for check inf_nan / clip global norm
self
.
_check_group
,
self
.
_check_comm_group
=
self
.
_set_check_group
(
"data"
)
...
...
@@ -149,19 +156,26 @@ class HybridCommunicateGroup(object):
self
.
is_first_stage
=
(
self
.
stage_id
==
0
)
self
.
is_last_stage
=
(
self
.
stage_id
==
(
self
.
_pp_degree
-
1
))
debug_str
=
"HybridParallelInfo: rank_id: %d, dp_degree: %d, "
\
"mp_degree: %d, pp_degree: %d"
%
(
self
.
global_rank
,
self
.
_dp_degree
,
self
.
_mp_degree
,
self
.
_pp_degree
)
debug_str
+=
", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s"
%
(
self
.
_dp_group
,
self
.
_mp_group
,
self
.
_pp_group
,
self
.
_check_group
)
debug_str
=
"HybridParallelInfo: rank_id: %d, mp_degree: %d, "
\
"sharding_degree: %d, pp_degree: %d, dp_degree: %d"
%
(
self
.
global_rank
,
self
.
_mp_degree
,
self
.
_sharding_degree
,
self
.
_pp_degree
,
self
.
_dp_degree
)
debug_str
+=
", mp_group: %s, sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s"
%
(
self
.
_mp_group
,
self
.
_sharding_group
,
self
.
_pp_group
,
self
.
_dp_group
,
self
.
_check_group
)
logger
.
info
(
debug_str
)
global
_HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP
=
self
def
get_parallel_mode
(
self
):
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if
self
.
_mp_degree
==
1
and
self
.
_pp_degree
==
1
:
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
# adding its parallel logic within that parallelism
# when use sharding alone, it should have its own parallelism for its parallel logic
# TODO modify 3 others parallel to support sharding
if
self
.
_mp_degree
==
1
and
self
.
_pp_degree
==
1
and
self
.
_dp_degree
==
1
and
self
.
_sharding_degree
>
1
:
return
ParallelMode
.
SHARDING_PARALLEL
elif
self
.
_mp_degree
==
1
and
self
.
_pp_degree
==
1
:
return
ParallelMode
.
DATA_PARALLEL
elif
self
.
_mp_degree
>
1
and
self
.
_pp_degree
==
1
:
# initialize the seed
...
...
@@ -170,7 +184,7 @@ class HybridCommunicateGroup(object):
return
ParallelMode
.
PIPELINE_PARALLEL
def
_check_vaild_topo
(
self
):
return
self
.
_dp_degree
*
self
.
_mp_degree
*
self
.
_pp_degree
==
self
.
nranks
return
self
.
_dp_degree
*
self
.
_mp_degree
*
self
.
_pp_degree
*
self
.
_sharding_degree
==
self
.
nranks
def
_set_comm_group
(
self
,
parallel_method
=
"data"
):
parallel_group
=
[]
...
...
@@ -255,6 +269,23 @@ class HybridCommunicateGroup(object):
def
get_pipe_parallel_group
(
self
):
return
self
.
_pp_comm_group
# sharding parallel message:
def
_get_sharding_parallel_id
(
self
):
return
self
.
_topo
.
get_coord
(
self
.
global_rank
).
sharding
def
get_sharding_parallel_rank
(
self
):
return
self
.
_sharding_parallel_id
def
get_sharding_parallel_world_size
(
self
):
return
self
.
_sharding_degree
def
get_sharding_parallel_group
(
self
):
return
self
.
_sharding_comm_group
def
get_sharding_parallel_group_src_rank
(
self
):
# TODO should the src rank related to the shard rank for each parameter ?
return
self
.
_sharding_comm_group
.
ranks
[
0
]
# check parallel group
def
get_check_parallel_group
(
self
):
return
self
.
_check_comm_group
...
...
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
0 → 100755
浏览文件 @
f33f2444
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######
from
functools
import
reduce
import
paddle
from
paddle
import
framework
from
...utils.log_util
import
logger
def
_is_trainable
(
param
:
paddle
.
Tensor
)
->
bool
:
return
not
param
.
stop_gradient
class
DygraphShardingOptimizer
(
object
):
"""
A wrapper for Sharding Optimizer in Dygraph.
.. warning: DygraphShardingOptimizer is experimental and subject to change.
.. ZeRO: https://arxiv.org/abs/1910.02054
"""
# TODO (JZ-LIANG)
# TO support following featrues in future:
# 1. fused update parameter sync
# 2. parameters_groups
# 3. dynamic trainable params, which is the case bewteen pretraining and finetuning
# 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
def
__init__
(
self
,
hcg
,
user_defined_strategy
,
params
,
inner_optimizer_class
,
**
inner_optimizer_kargs
,
):
if
not
isinstance
(
params
,
list
):
raise
TypeError
(
"`parameters` argument given to the DygraphShardingOptimizer should be "
"an iterable of paddle Tensors, but got argument type is `{}`."
.
format
(
type
(
params
)))
self
.
_parameter_list
=
params
self
.
_reference_is_trainable_params
=
list
(
map
(
_is_trainable
,
self
.
_parameter_list
))
self
.
_inner_optimizer_class
=
inner_optimizer_class
self
.
_inner_optimizer_kargs
=
inner_optimizer_kargs
# sharding parallel information
# TODO better way to get the hcg & user_defined_strategy
self
.
_hcg
=
hcg
self
.
_user_defined_strategy
=
user_defined_strategy
self
.
_sharding_world_size
=
self
.
_hcg
.
get_sharding_parallel_world_size
()
self
.
_sharding_rank
=
self
.
_hcg
.
get_sharding_parallel_rank
()
# logic partitioning
self
.
_build_sharding_mapping
()
# actually create opt ops
self
.
_buid_inner_optimizer
()
def
clear_grad
(
self
):
"""
should clear grad for all parameters in model
"""
for
p
in
self
.
_parameter_list
:
if
not
p
.
stop_gradient
:
p
.
clear_gradient
()
def
_build_sharding_mapping
(
self
):
self
.
_rank2params
=
self
.
_partition_parameters
()
self
.
_param2rank
=
self
.
_map_param_to_rank
()
def
_partition_parameters
(
self
):
"""
Partitions parameters among sharding ranks.
Return:
Dict[int, List]
"""
# TODO(JZ-LIANG) support multiple partition methods
# method1: greedy even but unorder
# method2: roughly even with oreder
mapping
=
{}
for
rank_
in
range
(
self
.
_sharding_world_size
):
mapping
[
rank_
]
=
[]
sizes
=
[
0
]
*
self
.
_sharding_world_size
for
param
in
self
.
_parameter_list
:
rank
=
sizes
.
index
(
min
(
sizes
))
mapping
[
rank
].
append
(
param
)
numel
=
reduce
(
lambda
x
,
y
:
x
*
y
,
param
.
shape
)
assert
numel
>
0
,
"param [{}] should larger than 0, but it is [{}]"
.
format
(
param
.
name
,
numel
)
sizes
[
rank
]
+=
numel
return
mapping
def
_map_param_to_rank
(
self
):
"""
mapping parameters to the shard which holds it.
Return:
Dict[str, int]
"""
mapping
=
{}
for
rank
,
params
in
self
.
_rank2params
.
items
():
for
param
in
params
:
mapping
[
param
.
name
]
=
rank
return
mapping
def
_buid_inner_optimizer
(
self
):
# we rely on the inner opt to determine whether a parameter is stop_gradient or not:
# create moment
# update related ops: clip, regular, opt
self
.
_inner_optimizer
=
self
.
_inner_optimizer_class
(
parameters
=
self
.
_rank2params
[
self
.
_sharding_rank
],
**
self
.
_inner_optimizer_kargs
)
def
_sharding_sync_parameters
(
self
):
"""
sync parameter across sharding group
"""
# TODO speed up this functional
logger
.
debug
(
"sharding start sync parameters"
)
with
framework
.
no_grad
():
# TODO detach not need (?)
for
rank
,
params
in
self
.
_rank2params
.
items
():
for
param
in
params
:
paddle
.
distributed
.
broadcast
(
param
,
# the collective API need src rank to be the global rank id
# instead of the relative logic rank id within group
src
=
self
.
_hcg
.
get_sharding_parallel_group
().
ranks
[
rank
],
group
=
self
.
_hcg
.
get_sharding_parallel_group
(),
use_calc_stream
=
True
)
def
_update_trainable
(
self
):
"""
allow user to update trainable parameters list during training
"""
raise
NotImplementedError
def
minimize
(
self
,
loss
,
startup_program
=
None
,
parameters
=
None
,
no_grad_set
=
None
):
# NOTE in dygraph mode, the only different between step and minimize is that minimize
# allow user to customize the parameters for updating on each step
input_param_names
=
set
([
param
.
name
for
param
in
parameters
])
parameters
=
list
(
filter
(
lambda
x
:
x
.
name
in
input_param_names
,
self
.
_rank2params
[
self
.
_sharding_rank
]))
result
=
self
.
_inner_optimizer
.
minimize
(
loss
,
startup_program
,
parameters
,
no_grad_set
)
# sync parameters accross sharding ranks
self
.
_sharding_sync_parameters
()
return
result
def
step
(
self
):
# TODO Check whether the model trainable param changed and update state accordingly
# actually updating
self
.
_inner_optimizer
.
step
()
# sync parameters accross sharding ranks
self
.
_sharding_sync_parameters
()
# TODO is it a good way to make _grad_clip a property
@
property
def
_grad_clip
(
self
):
assert
self
.
_inner_optimizer
is
not
None
,
"inner opt of sharding is not initiliazed."
return
self
.
_inner_optimizer
.
_grad_clip
def
__getattr__
(
self
,
item
):
return
getattr
(
self
.
_inner_optimizer
,
item
)
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
浏览文件 @
f33f2444
...
...
@@ -17,7 +17,7 @@ import sys
import
paddle
from
paddle.optimizer
import
Optimizer
from
paddle.fluid.clip
import
ClipGradByGlobalNorm
from
...utils.hybrid_parallel_util
import
fused_allreduce_gradients
from
...utils.hybrid_parallel_util
import
fused_allreduce_gradients
,
sharding_reduce_gradients
from
...base.topology
import
ParallelMode
from
paddle.fluid.dygraph
import
base
as
imperative_base
from
paddle.fluid
import
framework
...
...
@@ -98,6 +98,9 @@ class HybridParallelOptimizer:
self
.
_need_dp
=
(
self
.
_hcg
.
get_data_parallel_world_size
()
>
1
)
self
.
_sharding_enable
=
(
self
.
_hcg
.
get_sharding_parallel_world_size
()
>
1
)
if
isinstance
(
self
.
_inner_opt
.
_grad_clip
,
ClipGradByGlobalNorm
)
and
not
self
.
_use_dp_mode
:
logger
.
warning
(
"using ClipGradByGlobalNorm in TensorParallel, the origin "
\
...
...
@@ -108,6 +111,11 @@ class HybridParallelOptimizer:
@
imperative_base
.
no_grad
@
framework
.
dygraph_only
def
step
(
self
):
# Here should use global parameter list
if
self
.
_sharding_enable
:
sharding_reduce_gradients
(
list
(
self
.
_inner_opt
.
_parameter_list
),
self
.
_hcg
)
if
not
self
.
_use_dp_mode
and
self
.
_need_dp
:
fused_allreduce_gradients
(
list
(
self
.
_inner_opt
.
_parameter_list
),
self
.
_hcg
)
...
...
@@ -119,15 +127,19 @@ class HybridParallelOptimizer:
startup_program
=
None
,
parameters
=
None
,
no_grad_set
=
None
):
assert
isinstance
(
loss
,
Variable
),
"The loss should be an Tensor."
parameter_list
=
parameters
if
parameters
\
else
self
.
_parameter_list
else
self
.
_inner_opt
.
_parameter_list
# Here should use global parameter list
if
self
.
_sharding_enable
:
sharding_reduce_gradients
(
list
(
self
.
_inner_opt
.
_parameter_list
),
self
.
_hcg
)
if
not
self
.
_use_dp_mode
and
self
.
_need_dp
:
fused_allreduce_gradients
(
list
(
parameter_list
),
self
.
_hcg
)
return
self
.
_inner_opt
.
minimize
(
loss
,
startup_program
,
parameter
s
,
return
self
.
_inner_opt
.
minimize
(
loss
,
startup_program
,
parameter
_list
,
no_grad_set
)
def
__getattr__
(
self
,
item
):
...
...
python/paddle/distributed/fleet/meta_parallel/__init__.py
浏览文件 @
f33f2444
...
...
@@ -24,5 +24,6 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401
from
.parallel_layers
import
get_rng_state_tracker
# noqa: F401
from
.tensor_parallel
import
TensorParallel
# noqa: F401
from
.pipeline_parallel
import
PipelineParallel
# noqa: F401
from
.sharding_parallel
import
ShardingParallel
# noqa: F401
__all__
=
[]
python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py
0 → 100644
浏览文件 @
f33f2444
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.fluid.dygraph.layers
import
Layer
from
.meta_parallel_base
import
MetaParallelBase
from
..utils.hybrid_parallel_util
import
broadcast_sharding_parameters
from
..utils.log_util
import
logger
__all__
=
[]
class
ShardingParallel
(
MetaParallelBase
):
def
__init__
(
self
,
layers
,
hcg
,
**
kwargs
):
super
(
ShardingParallel
,
self
).
__init__
(
layers
,
hcg
,
**
kwargs
)
def
_prepare_for_model
(
self
):
logger
.
info
(
"start broadcast sharding parameters"
)
broadcast_sharding_parameters
(
self
.
_layers
,
self
.
_hcg
)
# TODO (JZ-LIANG) to support Sharding-DP
logger
.
info
(
"sharding's parameters is ready"
)
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
浏览文件 @
f33f2444
...
...
@@ -119,3 +119,46 @@ def fused_allreduce_gradients(parameter_list, hcg):
logger
.
debug
(
"dp start fuse allreduce gradients"
)
with
framework
.
no_grad
():
_apply_collective_grads
(
parameter_list
,
data_parallel_group
)
def
sharding_reduce_gradients
(
parameter_list
,
hcg
):
# TODO allreduce --> reduce
# TODO merge grad / nrank with dp
logger
.
debug
(
"sharding start gradients sync"
)
with
framework
.
no_grad
():
sharding_nrank
=
hcg
.
get_sharding_parallel_group
().
nranks
for
param
in
parameter_list
:
if
param
.
trainable
and
(
param
.
_grad_ivar
()
is
not
None
):
g_var
=
param
.
_grad_ivar
()
# need use trace_op to allreduce
# paddle.distributed.all_reduce(
# g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"c_allreduce_sum"
,
inputs
=
{
'X'
:
g_var
},
outputs
=
{
'Out'
:
g_var
},
attrs
=
{
'ring_id'
:
hcg
.
get_sharding_parallel_group
().
id
,
'use_calc_stream'
:
True
})
# grad / sharding_rank
div_factor
=
paddle
.
to_tensor
(
sharding_nrank
,
dtype
=
g_var
.
dtype
)
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"elementwise_div"
,
inputs
=
{
'X'
:
g_var
,
'Y'
:
div_factor
},
outputs
=
{
'Out'
:
g_var
},
attrs
=
{
'axis'
:
-
1
})
def
broadcast_sharding_parameters
(
model
,
hcg
):
# TODO TO save memory, use un-fused broadcast to avoid potentional OOM
logger
.
debug
(
"sharding start init parameters sync"
)
sharding_parallel_group
=
hcg
.
get_sharding_parallel_group
()
src_rank
=
hcg
.
get_sharding_parallel_group_src_rank
()
sync_params_buffers
(
model
,
sharding_parallel_group
,
src_rank
,
is_model_parallel
=
False
)
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
f33f2444
...
...
@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
#remove distribute unittests.
...
...
@@ -185,6 +186,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_base_single
)
...
...
@@ -882,6 +884,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py
浏览文件 @
f33f2444
...
...
@@ -21,7 +21,8 @@ from paddle.distributed import fleet
class
TestNewGroupAPI
(
object
):
def
__init__
(
self
):
paddle
.
distributed
.
init_parallel_env
()
topo
=
fleet
.
CommunicateTopology
([
"data"
,
"model"
,
"pipe"
],
[
2
,
1
,
1
])
topo
=
fleet
.
CommunicateTopology
([
"data"
,
"model"
,
"sharding"
,
"pipe"
],
[
2
,
1
,
1
,
1
])
self
.
hcg
=
fleet
.
HybridCommunicateGroup
(
topo
)
d1
=
np
.
array
([
1
,
2
,
3
])
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py
0 → 100644
浏览文件 @
f33f2444
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
numpy
as
np
import
random
import
paddle.distributed
as
dist
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet
as
fleet
from
paddle.io
import
DataLoader
,
Dataset
from
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer
import
DygraphShardingOptimizer
import
unittest
vocab_size
=
20
hidden_size
=
10
inner_size
=
8
output_size
=
10
seq_length
=
2
batch_size
=
4
STEPS
=
10
def
parallel_matmul
(
lm_output
,
logit_weights
,
parallel_output
):
hcg
=
fleet
.
get_hybrid_communicate_group
()
model_parallel_group
=
hcg
.
get_model_parallel_group
()
world_size
=
hcg
.
get_model_parallel_world_size
()
rank
=
hcg
.
get_model_parallel_rank
()
if
world_size
>
1
:
input_parallel
=
paddle
.
distributed
.
collective
.
_c_identity
(
lm_output
,
group
=
model_parallel_group
)
logits
=
paddle
.
matmul
(
input_parallel
,
logit_weights
,
transpose_y
=
True
)
if
parallel_output
:
return
logits
return
paddle
.
distributed
.
collective
.
_c_concat
(
logits
,
group
=
model_parallel_group
)
else
:
logits
=
paddle
.
matmul
(
lm_output
,
logit_weights
,
transpose_y
=
True
)
return
logits
class
SimpleMPNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
,
mp_id
):
super
(
SimpleMPNet
,
self
).
__init__
()
if
mp_id
==
0
:
init_fc1_data
=
np_fc1
[:,
:(
inner_size
//
2
)]
init_fc2_data
=
np_fc2
[:(
inner_size
//
2
),
:]
else
:
init_fc1_data
=
np_fc1
[:,
(
inner_size
//
2
):]
init_fc2_data
=
np_fc2
[(
inner_size
//
2
):,
:]
self
.
linear1
=
fleet
.
meta_parallel
.
ColumnParallelLinear
(
hidden_size
,
inner_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
init_fc1_data
)),
gather_output
=
False
,
has_bias
=
True
)
self
.
linear2
=
fleet
.
meta_parallel
.
RowParallelLinear
(
inner_size
,
hidden_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
init_fc2_data
)),
input_is_parallel
=
True
,
has_bias
=
True
)
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
embedding
=
fleet
.
meta_parallel
.
VocabParallelEmbedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
))
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
parallel_matmul
(
x
,
self
.
embedding
.
weight
,
False
)
return
x
class
SimpleDPNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
):
super
(
SimpleDPNet
,
self
).
__init__
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
hidden_size
,
inner_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc1
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
linear2
=
paddle
.
nn
.
Linear
(
inner_size
,
hidden_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc2
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)))
self
.
embedding
=
paddle
.
nn
.
Embedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
))
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
paddle
.
matmul
(
x
,
self
.
embedding
.
weight
,
transpose_y
=
True
)
return
x
class
TestDistMPTraning
(
unittest
.
TestCase
):
def
setUp
(
self
):
random
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
paddle
.
seed
(
2021
)
self
.
strategy
=
fleet
.
DistributedStrategy
()
self
.
strategy
.
hybrid_configs
=
{
"sharding_degree"
:
2
,
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
self
.
strategy
)
self
.
data
=
[
np
.
random
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_length
,
))
for
_
in
range
(
STEPS
)
]
def
train_batch
(
self
,
batch
,
model
,
optimizer
):
output
=
model
(
batch
)
loss
=
output
.
mean
()
loss
.
backward
()
# do backward
optimizer
.
step
()
# update parameters
optimizer
.
clear_grad
()
return
loss
def
build_optimizer
(
self
,
model
,
strategy
=
None
,
is_sharding
=
True
,
Optimizer
=
"adam"
):
if
Optimizer
==
"adam"
:
if
is_sharding
:
optimizer
=
DygraphShardingOptimizer
(
hcg
=
fleet
.
get_hybrid_communicate_group
(),
user_defined_strategy
=
strategy
,
params
=
model
.
parameters
(),
inner_optimizer_class
=
paddle
.
optimizer
.
Adam
,
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
)
else
:
optimizer
=
paddle
.
optimizer
.
Adam
(
parameters
=
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
)
else
:
if
is_sharding
:
optimizer
=
DygraphShardingOptimizer
(
hcg
=
fleet
.
get_hybrid_communicate_group
(),
user_defined_strategy
=
strategy
,
params
=
model
.
parameters
(),
inner_optimizer_class
=
paddle
.
optimizer
.
Momentum
,
learning_rate
=
0.001
,
)
else
:
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
return
optimizer
def
build_model_optimizer
(
self
,
Optimizer
=
"adam"
):
hcg
=
fleet
.
get_hybrid_communicate_group
()
word_size
=
hcg
.
get_model_parallel_world_size
()
sharding_id
=
hcg
.
get_sharding_parallel_rank
()
dp_id
=
hcg
.
get_data_parallel_rank
()
rank_id
=
dist
.
get_rank
()
np_fc1
=
np
.
random
.
random_sample
((
hidden_size
,
inner_size
))
np_fc2
=
np
.
random
.
random_sample
((
inner_size
,
hidden_size
))
model_a
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
)
optimizer_a
=
self
.
build_optimizer
(
model_a
,
strategy
=
self
.
strategy
,
is_sharding
=
True
,
Optimizer
=
Optimizer
)
model_a
=
fleet
.
distributed_model
(
model_a
)
optimizer_a
=
fleet
.
distributed_optimizer
(
optimizer_a
)
model_b
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
)
optimizer_b
=
self
.
build_optimizer
(
model_b
,
strategy
=
self
.
strategy
,
is_sharding
=
False
,
Optimizer
=
Optimizer
)
return
model_a
,
optimizer_a
,
model_b
,
optimizer_b
def
sharding_model
(
self
,
Optimizer
,
sharded_accumulators
):
model_a
,
optimizer_a
,
model_b
,
optimizer_b
=
self
.
build_model_optimizer
(
Optimizer
=
Optimizer
)
self
.
assertTrue
(
isinstance
(
optimizer_a
.
_inner_opt
,
DygraphShardingOptimizer
))
for
idx
in
range
(
STEPS
):
if
idx
==
2
and
paddle
.
distributed
.
get_rank
()
==
0
:
self
.
assertTrue
(
set
(
optimizer_a
.
_inner_opt
.
_inner_optimizer
.
state_dict
()
.
keys
())
==
sharded_accumulators
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[
idx
][:
2
])
else
:
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[
idx
][
2
:])
batch_single
=
paddle
.
to_tensor
(
self
.
data
[
idx
])
loss_a
=
self
.
train_batch
(
batch_sharding
,
model_a
,
optimizer_a
)
loss_b
=
self
.
train_batch
(
batch_single
,
model_b
,
optimizer_b
)
for
j
in
range
(
len
(
model_a
.
parameters
())):
np
.
testing
.
assert_allclose
(
model_a
.
parameters
()[
j
].
numpy
(),
model_b
.
parameters
()[
j
].
numpy
(),
rtol
=
1e-6
)
def
test_sharding_adam
(
self
):
sharded_accumulators
=
set
([
'linear_0.w_0_moment1_0'
,
'linear_1.b_0_moment1_0'
,
'linear_2.b_0_moment1_0'
,
'embedding_0.w_0_moment1_0'
,
'linear_0.w_0_moment2_0'
,
'linear_1.b_0_moment2_0'
,
'linear_2.b_0_moment2_0'
,
'embedding_0.w_0_moment2_0'
,
'linear_0.w_0_beta1_pow_acc_0'
,
'linear_1.b_0_beta1_pow_acc_0'
,
'linear_2.b_0_beta1_pow_acc_0'
,
'embedding_0.w_0_beta1_pow_acc_0'
,
'linear_0.w_0_beta2_pow_acc_0'
,
'linear_1.b_0_beta2_pow_acc_0'
,
'linear_2.b_0_beta2_pow_acc_0'
,
'embedding_0.w_0_beta2_pow_acc_0'
])
self
.
sharding_model
(
Optimizer
=
"adam"
,
sharded_accumulators
=
sharded_accumulators
)
def
test_sharding_momentum
(
self
):
sharded_accumulators
=
set
([
'linear_6.w_0_velocity_0'
,
'linear_7.b_0_velocity_0'
,
'linear_8.b_0_velocity_0'
,
'embedding_2.w_0_velocity_0'
])
self
.
sharding_model
(
Optimizer
=
"Momentum"
,
sharded_accumulators
=
sharded_accumulators
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py
浏览文件 @
f33f2444
...
...
@@ -79,6 +79,99 @@ class TestCommunicateTopology(unittest.TestCase):
self
.
assertEqual
(
topo
.
get_dim_size
(
"mp"
),
2
)
self
.
assertEqual
(
topo
.
get_dim_size
(
"pp"
),
2
)
def
test_topology_4D
(
self
):
topo
=
fleet
.
CommunicateTopology
([
"dp"
,
"pp"
,
"sharding"
,
"mp"
],
[
2
,
2
,
2
,
2
])
# test get_comm_list
dp_comm_list
=
[[
0
,
8
],
[
1
,
9
],
[
2
,
10
],
[
3
,
11
],
[
4
,
12
],
[
5
,
13
],
[
6
,
14
],
[
7
,
15
]]
mp_comm_list
=
[[
0
,
1
],
[
2
,
3
],
[
4
,
5
],
[
6
,
7
],
[
8
,
9
],
[
10
,
11
],
[
12
,
13
],
[
14
,
15
]]
pp_comm_list
=
[[
0
,
4
],
[
1
,
5
],
[
2
,
6
],
[
3
,
7
],
[
8
,
12
],
[
9
,
13
],
[
10
,
14
],
[
11
,
15
]]
sharding_comm_list
=
[[
0
,
2
],
[
1
,
3
],
[
4
,
6
],
[
5
,
7
],
[
8
,
10
],
[
9
,
11
],
[
12
,
14
],
[
13
,
15
]]
np
.
testing
.
assert_array_equal
(
dp_comm_list
,
topo
.
get_comm_list
(
"dp"
))
np
.
testing
.
assert_array_equal
(
mp_comm_list
,
topo
.
get_comm_list
(
"mp"
))
np
.
testing
.
assert_array_equal
(
pp_comm_list
,
topo
.
get_comm_list
(
"pp"
))
np
.
testing
.
assert_array_equal
(
sharding_comm_list
,
topo
.
get_comm_list
(
"sharding"
))
# test get_hybrid_group_names
parallel_names
=
[
"dp"
,
"pp"
,
"sharding"
,
"mp"
]
np
.
testing
.
assert_array_equal
(
parallel_names
,
topo
.
get_hybrid_group_names
())
# test get_dims
np
.
testing
.
assert_array_equal
(
2
,
topo
.
get_dim
(
"dp"
))
np
.
testing
.
assert_array_equal
(
2
,
topo
.
get_dim
(
"mp"
))
np
.
testing
.
assert_array_equal
(
2
,
topo
.
get_dim
(
"pp"
))
np
.
testing
.
assert_array_equal
(
2
,
topo
.
get_dim
(
"sharding"
))
# test world size
self
.
assertEqual
(
topo
.
world_size
(),
16
)
# test get_rank
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
0
,
sharding
=
0
,
mp
=
0
),
0
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
0
,
sharding
=
0
,
mp
=
1
),
1
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
0
,
sharding
=
1
,
mp
=
0
),
2
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
0
,
sharding
=
1
,
mp
=
1
),
3
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
1
,
sharding
=
0
,
mp
=
0
),
4
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
1
,
sharding
=
0
,
mp
=
1
),
5
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
1
,
sharding
=
1
,
mp
=
0
),
6
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
0
,
pp
=
1
,
sharding
=
1
,
mp
=
1
),
7
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
0
,
sharding
=
0
,
mp
=
0
),
8
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
0
,
sharding
=
0
,
mp
=
1
),
9
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
0
,
sharding
=
1
,
mp
=
0
),
10
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
0
,
sharding
=
1
,
mp
=
1
),
11
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
1
,
sharding
=
0
,
mp
=
0
),
12
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
1
,
sharding
=
0
,
mp
=
1
),
13
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
1
,
sharding
=
1
,
mp
=
0
),
14
)
self
.
assertEqual
(
topo
.
get_rank
(
dp
=
1
,
pp
=
1
,
sharding
=
1
,
mp
=
1
),
15
)
# test get_coord
self
.
assertEqual
(
topo
.
get_coord
(
0
),
topo
.
coordinate
(
0
,
0
,
0
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
1
),
topo
.
coordinate
(
0
,
0
,
0
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
2
),
topo
.
coordinate
(
0
,
0
,
1
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
3
),
topo
.
coordinate
(
0
,
0
,
1
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
4
),
topo
.
coordinate
(
0
,
1
,
0
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
5
),
topo
.
coordinate
(
0
,
1
,
0
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
6
),
topo
.
coordinate
(
0
,
1
,
1
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
7
),
topo
.
coordinate
(
0
,
1
,
1
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
8
),
topo
.
coordinate
(
1
,
0
,
0
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
9
),
topo
.
coordinate
(
1
,
0
,
0
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
10
),
topo
.
coordinate
(
1
,
0
,
1
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
11
),
topo
.
coordinate
(
1
,
0
,
1
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
12
),
topo
.
coordinate
(
1
,
1
,
0
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
13
),
topo
.
coordinate
(
1
,
1
,
0
,
1
))
self
.
assertEqual
(
topo
.
get_coord
(
14
),
topo
.
coordinate
(
1
,
1
,
1
,
0
))
self
.
assertEqual
(
topo
.
get_coord
(
15
),
topo
.
coordinate
(
1
,
1
,
1
,
1
))
# test get_axis_list
self
.
assertEqual
(
topo
.
get_axis_list
(
"dp"
,
0
),
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"dp"
,
1
),
[
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"mp"
,
0
),
[
0
,
2
,
4
,
6
,
8
,
10
,
12
,
14
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"mp"
,
1
),
[
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"pp"
,
0
),
[
0
,
1
,
2
,
3
,
8
,
9
,
10
,
11
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"pp"
,
1
),
[
4
,
5
,
6
,
7
,
12
,
13
,
14
,
15
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"sharding"
,
0
),
[
0
,
1
,
4
,
5
,
8
,
9
,
12
,
13
])
self
.
assertEqual
(
topo
.
get_axis_list
(
"sharding"
,
1
),
[
2
,
3
,
6
,
7
,
10
,
11
,
14
,
15
])
# test get_dim_size
self
.
assertEqual
(
topo
.
get_dim_size
(
"dp"
),
2
)
self
.
assertEqual
(
topo
.
get_dim_size
(
"mp"
),
2
)
self
.
assertEqual
(
topo
.
get_dim_size
(
"pp"
),
2
)
self
.
assertEqual
(
topo
.
get_dim_size
(
"sharding"
),
2
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
浏览文件 @
f33f2444
...
...
@@ -124,6 +124,8 @@ class TestMultipleGpus(unittest.TestCase):
break
time
.
sleep
(
3
)
class
TestDataParallelGradientCheck
(
TestMultipleGpus
):
def
test_multiple_gpus_dynamic
(
self
):
self
.
run_mnist_2gpu
(
'parallel_dygraph_gradient_check.py'
)
...
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py
0 → 100644
浏览文件 @
f33f2444
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestHybridParallel
(
TestMultipleGpus
):
# check sharding logic as well as the accuracy with single mode
def
test_hybrid_parallel_sharding_logic
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录