Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4c4d3185
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4c4d3185
编写于
7月 19, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
7月 19, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Sharding stage 1 tensor fusion (#55427)
上级
f7cbfc4c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
477 addition
and
11 deletion
+477
-11
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+5
-0
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
...ptimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+77
-11
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
...ted/fleet/meta_parallel/sharding/group_sharded_storage.py
+14
-0
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
...on/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+192
-0
test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py
...ctive/fleet/hybrid_parallel_sharding_model_with_fusion.py
+186
-0
test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
...llective/fleet/test_parallel_dygraph_sharding_parallel.py
+3
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
4c4d3185
...
@@ -66,6 +66,10 @@ message PpConfig {
...
@@ -66,6 +66,10 @@ message PpConfig {
optional
bool
profiling
=
5
[
default
=
false
];
optional
bool
profiling
=
5
[
default
=
false
];
}
}
message
DygraphShardingConfig
{
optional
bool
tensor_fusion
=
1
[
default
=
false
];
}
message
HybridConfig
{
message
HybridConfig
{
optional
int32
dp_degree
=
1
[
default
=
-
1
];
optional
int32
dp_degree
=
1
[
default
=
-
1
];
optional
int32
mp_degree
=
2
[
default
=
1
];
optional
int32
mp_degree
=
2
[
default
=
1
];
...
@@ -73,6 +77,7 @@ message HybridConfig {
...
@@ -73,6 +77,7 @@ message HybridConfig {
optional
int32
sharding_degree
=
4
[
default
=
1
];
optional
int32
sharding_degree
=
4
[
default
=
1
];
optional
MpConfig
mp_configs
=
5
;
optional
MpConfig
mp_configs
=
5
;
optional
PpConfig
pp_configs
=
6
;
optional
PpConfig
pp_configs
=
6
;
optional
DygraphShardingConfig
sharding_configs
=
7
;
}
}
message
AMPConfig
{
message
AMPConfig
{
...
...
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
浏览文件 @
4c4d3185
...
@@ -18,8 +18,10 @@ from functools import reduce
...
@@ -18,8 +18,10 @@ from functools import reduce
import
paddle
import
paddle
from
paddle
import
framework
from
paddle
import
framework
from
paddle.distributed
import
fleet
from
...utils.log_util
import
logger
from
...utils.log_util
import
logger
from
...utils.tensor_fusion_helper
import
fused_parameters
def
_is_trainable
(
param
):
def
_is_trainable
(
param
):
...
@@ -62,15 +64,53 @@ class DygraphShardingOptimizer:
...
@@ -62,15 +64,53 @@ class DygraphShardingOptimizer:
self
.
_sharding_world_size
=
self
.
_hcg
.
get_sharding_parallel_world_size
()
self
.
_sharding_world_size
=
self
.
_hcg
.
get_sharding_parallel_world_size
()
self
.
_sharding_rank
=
self
.
_hcg
.
get_sharding_parallel_rank
()
self
.
_sharding_rank
=
self
.
_hcg
.
get_sharding_parallel_rank
()
strategy
=
fleet
.
fleet
.
_user_defined_strategy
self
.
tensor_fusion
=
strategy
.
hybrid_configs
[
'sharding_configs'
].
tensor_fusion
pp_overlap
=
strategy
.
hybrid_configs
[
'pp_configs'
].
sharding_comm_overlap
if
self
.
tensor_fusion
:
assert
(
not
pp_overlap
),
"Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."
self
.
_rank2params
=
self
.
_partition_parameters
()
self
.
_rank2params
=
self
.
_partition_parameters
()
self
.
_param2rank
=
self
.
_map_param_to_rank
()
self
.
_param2rank
=
self
.
_map_param_to_rank
()
self
.
_set_inner_opt_attr
(
if
not
self
.
tensor_fusion
:
'_parameter_list'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
self
.
_set_inner_opt_attr
(
)
'_parameter_list'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
self
.
_set_inner_opt_attr
(
)
'_param_groups'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
self
.
_set_inner_opt_attr
(
)
'_param_groups'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
)
else
:
self
.
_use_main_grad
=
hasattr
(
self
.
_parameter_list
[
0
],
"main_grad"
)
self
.
_rank2decay
=
{}
self
.
_rank2fused
=
{}
self
.
_tensor_fusion
()
decay_params
=
[
p
.
name
for
p
in
self
.
_rank2decay
[
self
.
_sharding_rank
]
]
all_params
=
self
.
_rank2fused
[
self
.
_sharding_rank
]
apply_decay_param_fun
=
lambda
x
:
x
in
decay_params
params
=
[]
for
v
in
self
.
_rank2fused
.
values
():
params
+=
v
self
.
_parameter_list
=
params
self
.
_param_groups
=
params
self
.
_set_inner_opt_attr
(
'_parameter_list'
,
all_params
)
self
.
_set_inner_opt_attr
(
'_param_groups'
,
all_params
)
origin_decay_param_fun
=
getattr
(
self
.
_inner_opt
,
'_apply_decay_param_fun'
,
None
)
if
origin_decay_param_fun
is
not
None
:
self
.
_set_inner_opt_attr
(
'_apply_decay_param_fun'
,
apply_decay_param_fun
)
def
clear_grad
(
self
,
set_to_zero
=
True
):
def
clear_grad
(
self
,
set_to_zero
=
True
):
"""
"""
...
@@ -85,7 +125,25 @@ class DygraphShardingOptimizer:
...
@@ -85,7 +125,25 @@ class DygraphShardingOptimizer:
p
.
main_grad
.
_clear
()
p
.
main_grad
.
_clear
()
p
.
main_grad
=
None
p
.
main_grad
=
None
elif
not
hasattr
(
p
,
"main_grad"
):
elif
not
hasattr
(
p
,
"main_grad"
):
p
.
clear_gradient
(
set_to_zero
)
if
self
.
tensor_fusion
:
if
set_to_zero
:
p
.
grad
.
zero_
()
else
:
p
.
grad
.
_clear
()
p
.
grad
=
None
else
:
p
.
clear_gradient
(
set_to_zero
)
def
_tensor_fusion
(
self
):
for
i
in
range
(
self
.
_sharding_world_size
):
params
=
self
.
_rank2params
[
i
]
decay_fused
,
all_fused
=
fused_parameters
(
params
,
self
.
_use_main_grad
)
self
.
_rank2decay
[
i
]
=
decay_fused
self
.
_rank2fused
[
i
]
=
all_fused
for
p
in
all_fused
:
self
.
_param2rank
[
p
.
name
]
=
i
def
_partition_parameters
(
self
):
def
_partition_parameters
(
self
):
"""
"""
...
@@ -167,7 +225,12 @@ class DygraphShardingOptimizer:
...
@@ -167,7 +225,12 @@ class DygraphShardingOptimizer:
logger
.
debug
(
"sharding start sync parameters"
)
logger
.
debug
(
"sharding start sync parameters"
)
with
framework
.
no_grad
():
with
framework
.
no_grad
():
# TODO detach not need (?)
# TODO detach not need (?)
for
rank
,
params
in
self
.
_rank2params
.
items
():
valid_rank_to_params
=
(
self
.
_rank2params
if
not
self
.
tensor_fusion
else
self
.
_rank2fused
)
for
rank
,
params
in
valid_rank_to_params
.
items
():
for
param
in
params
:
for
param
in
params
:
paddle
.
distributed
.
broadcast
(
paddle
.
distributed
.
broadcast
(
param
,
param
,
...
@@ -236,9 +299,12 @@ class DygraphShardingOptimizer:
...
@@ -236,9 +299,12 @@ class DygraphShardingOptimizer:
params_grads
=
self
.
_inner_opt
.
_grad_clip
(
params_grads
)
params_grads
=
self
.
_inner_opt
.
_grad_clip
(
params_grads
)
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
self
.
_set_inner_opt_attr
(
'_grad_clip'
,
None
)
self
.
_set_inner_opt_attr
(
'_grad_clip'
,
None
)
update_param_names
=
[
rank_params
=
(
p
.
name
for
p
in
self
.
_rank2params
[
self
.
_sharding_rank
]
self
.
_rank2params
[
self
.
_sharding_rank
]
]
if
not
self
.
tensor_fusion
else
self
.
_rank2fused
[
self
.
_sharding_rank
]
)
update_param_names
=
[
p
.
name
for
p
in
rank_params
]
update_params_grads
=
[
update_params_grads
=
[
(
p
,
g
)
for
p
,
g
in
params_grads
if
p
.
name
in
update_param_names
(
p
,
g
)
for
p
,
g
in
params_grads
if
p
.
name
in
update_param_names
]
]
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
浏览文件 @
4c4d3185
...
@@ -30,6 +30,14 @@ from paddle.framework import core
...
@@ -30,6 +30,14 @@ from paddle.framework import core
from
.group_sharded_utils
import
Type
,
cvt_to_device
,
device_guard
from
.group_sharded_utils
import
Type
,
cvt_to_device
,
device_guard
class
BufferWarper
(
core
.
eager
.
Tensor
):
def
__init__
(
self
):
super
().
__init__
()
self
.
need_clip
=
True
self
.
is_distributed
=
False
self
.
trainable
=
True
class
InternalStorage
:
class
InternalStorage
:
"""
"""
This is a basic class, which is responsible for consolidating the basic storage tensor.
This is a basic class, which is responsible for consolidating the basic storage tensor.
...
@@ -97,6 +105,12 @@ class InternalStorage:
...
@@ -97,6 +105,12 @@ class InternalStorage:
self
.
buffer
=
self
.
buffer
.
cast
(
dtype
=
dtype
)
self
.
buffer
=
self
.
buffer
.
cast
(
dtype
=
dtype
)
self
.
_dtype
=
dtype
self
.
_dtype
=
dtype
def
warp_buffer
(
self
):
tmp_buffer
=
BufferWarper
()
self
.
_buffer
=
self
.
buffer
tmp_buffer
.
get_tensor
().
_share_data_with
(
self
.
buffer
.
get_tensor
())
self
.
buffer
=
tmp_buffer
class
ParamStorage
(
InternalStorage
):
class
ParamStorage
(
InternalStorage
):
"""
"""
...
...
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
0 → 100644
浏览文件 @
4c4d3185
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
itertools
from
collections
import
OrderedDict
import
numpy
as
np
import
paddle
from
paddle.framework
import
core
alignment
=
{
"gpu"
:
256
,
}
align
=
{
paddle
.
float16
.
value
:
2
,
paddle
.
bfloat16
.
value
:
2
,
paddle
.
float32
.
value
:
4
,
}
def
assign_group_by_size
(
parameters
,
group_size
=
256
*
1024
*
1024
):
# TODO(Yuang Liu): make pp_utils/utils use this tensor fusion helper
is_sparse_gradient
=
[
False
]
*
len
(
parameters
)
group_indices
=
core
.
eager_assign_group_by_size
(
parameters
,
is_sparse_gradient
,
[
group_size
,
group_size
]
)
var_groups
=
OrderedDict
()
for
group_idx
,
indices
in
enumerate
(
group_indices
):
for
index
in
indices
:
var_groups
.
setdefault
(
group_idx
,
[]).
append
(
parameters
[
index
])
return
var_groups
def
flatten_dense_tensors
(
parameters
,
use_main_grad
):
from
paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage
import
(
GradStorage
,
ParamStorage
,
)
_buffer_size
=
0
_param2align
=
{}
dtype
=
parameters
[
0
].
dtype
for
param
in
parameters
:
assert
param
.
trainable
,
"param must be trainable..."
size
=
np
.
prod
(
param
.
shape
)
*
align
[
dtype
]
remaining
=
size
%
alignment
[
"gpu"
]
ali
=
0
if
remaining
==
0
else
alignment
[
"gpu"
]
-
remaining
align_
=
ali
//
align
[
dtype
]
_buffer_size
+=
np
.
prod
(
param
.
shape
)
+
align_
_param2align
[
param
.
name
]
=
align_
param_storage
=
ParamStorage
(
size
=
_buffer_size
,
dtype
=
dtype
,
device
=
"gpu"
)
param_storage
.
add_rank_params
(
parameters
,
_param2align
)
# process gradient
grad_dtype
=
paddle
.
float32
if
use_main_grad
else
dtype
grad_storage
=
GradStorage
(
size
=
_buffer_size
,
dtype
=
grad_dtype
,
device
=
"gpu"
,
destination
=
"0"
,
parm2align
=
_param2align
,
)
for
param
in
parameters
:
grad_storage
.
add_grad
(
param
,
_param2align
[
param
.
name
])
param_storage
.
warp_buffer
()
grad_storage
.
warp_buffer
()
if
not
use_main_grad
:
# param_storage --> grad_storage
param_storage
.
buffer
.
_copy_gradient_from
(
grad_storage
.
buffer
)
else
:
param_storage
.
buffer
.
main_grad
=
grad_storage
.
buffer
param_storage
.
buffer
.
stop_gradient
=
False
return
param_storage
,
grad_storage
def
obtain_storage
(
parameters
,
use_main_grad
,
clip
,
dist
):
if
len
(
parameters
)
<
1
:
return
[]
var_groups
=
assign_group_by_size
(
parameters
)
storage
=
[]
for
group_idx
,
parameters
in
var_groups
.
items
():
param_storage
,
grad_storage
=
flatten_dense_tensors
(
parameters
,
use_main_grad
)
param_storage
.
buffer
.
need_clip
=
clip
param_storage
.
buffer
.
is_distributed
=
dist
storage
.
append
(
param_storage
.
buffer
)
return
storage
def
filter_params
(
params
,
is_fp32
,
is_distributed
,
need_clip
):
params
=
list
(
filter
(
lambda
x
:
x
.
is_distributed
if
is_distributed
else
(
not
x
.
is_distributed
),
params
,
)
)
params
=
list
(
filter
(
lambda
x
:
getattr
(
x
,
'need_clip'
,
True
)
if
need_clip
else
(
not
getattr
(
x
,
'need_clip'
,
True
)),
params
,
)
)
params
=
list
(
filter
(
lambda
x
:
x
.
dtype
==
paddle
.
float32
if
is_fp32
else
x
.
dtype
!=
paddle
.
float32
,
params
,
)
)
dtype
=
None
for
p
in
params
:
if
dtype
is
None
:
dtype
=
p
.
dtype
else
:
assert
dtype
==
p
.
dtype
return
params
,
dtype
def
fused_parameters
(
parameters
,
use_main_grad
):
param_groups
=
[]
attrs
=
[]
is_fp32
=
[
True
,
False
]
is_distributed
=
[
True
,
False
]
need_clip
=
[
True
,
False
]
no_fp32_dtype
=
None
for
fp32
,
dist
,
clip
in
itertools
.
product
(
is_fp32
,
is_distributed
,
need_clip
):
params
,
dtype
=
filter_params
(
parameters
,
fp32
,
dist
,
clip
)
if
not
fp32
:
if
no_fp32_dtype
is
None
:
no_fp32_dtype
=
dtype
elif
dtype
is
not
None
:
assert
no_fp32_dtype
==
dtype
attrs
.
append
([
dtype
,
dist
,
clip
])
param_groups
.
append
(
params
)
decay_fused
=
[]
all_fused
=
[]
for
params
,
attr
in
zip
(
param_groups
,
attrs
):
decay_params
=
[]
other_params
=
[]
for
param
in
params
:
if
not
any
(
nd
in
param
.
name
for
nd
in
[
"bias"
,
"norm"
,
"b_0"
]):
decay_params
.
append
(
param
)
else
:
other_params
.
append
(
param
)
is_distributed
=
attr
[
1
]
need_clip
=
attr
[
2
]
decay
=
obtain_storage
(
decay_params
,
use_main_grad
,
need_clip
,
is_distributed
)
other
=
obtain_storage
(
other_params
,
use_main_grad
,
need_clip
,
is_distributed
)
decay_fused
+=
decay
all_fused
+=
decay
all_fused
+=
other
return
decay_fused
,
all_fused
test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py
0 → 100644
浏览文件 @
4c4d3185
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
unittest
import
numpy
as
np
import
paddle
from
paddle.distributed
import
fleet
from
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer
import
(
DygraphShardingOptimizer
,
)
vocab_size
=
20
hidden_size
=
10
inner_size
=
8
output_size
=
10
seq_length
=
2
batch_size
=
4
STEPS
=
10
class
SimpleDPNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
):
super
().
__init__
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
hidden_size
,
inner_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc1
)
),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)
),
)
self
.
linear2
=
paddle
.
nn
.
Linear
(
inner_size
,
hidden_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Assign
(
np_fc2
)
),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)
),
)
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
,
weight_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)
),
bias_attr
=
paddle
.
framework
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
)
),
)
self
.
embedding
=
paddle
.
nn
.
Embedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
),
)
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
paddle
.
matmul
(
x
,
self
.
embedding
.
weight
,
transpose_y
=
True
)
return
x
class
TestDistSharding
(
unittest
.
TestCase
):
def
setUp
(
self
):
random
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
paddle
.
seed
(
2021
)
self
.
strategy
=
fleet
.
DistributedStrategy
()
self
.
strategy
.
hybrid_configs
=
{
"sharding_degree"
:
2
,
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
}
self
.
strategy
.
hybrid_configs
[
"sharding_configs"
].
tensor_fusion
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
self
.
strategy
)
self
.
data
=
np
.
random
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_length
,
),
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
self
.
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[:
2
])
else
:
self
.
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[
2
:])
self
.
batch_single
=
paddle
.
to_tensor
(
self
.
data
)
def
train_batch
(
self
,
batch
,
model
,
optimizer
):
output
=
model
(
batch
)
loss
=
output
.
mean
()
loss
.
backward
()
optimizer
.
step
()
optimizer
.
clear_grad
()
return
loss
def
build_optimizer
(
self
,
model
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.5
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.001
,
grad_clip
=
clip
,
)
return
optimizer
def
build_model_optimizer
(
self
):
np_fc1
=
np
.
random
.
random_sample
((
hidden_size
,
inner_size
))
np_fc2
=
np
.
random
.
random_sample
((
inner_size
,
hidden_size
))
model_a
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
)
optimizer_a
=
self
.
build_optimizer
(
model_a
)
model_b
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
,
np_fc1
,
np_fc2
)
optimizer_b
=
self
.
build_optimizer
(
model_b
)
model_a
=
fleet
.
distributed_model
(
model_a
)
optimizer_a
=
fleet
.
distributed_optimizer
(
optimizer_a
)
return
model_a
,
optimizer_a
,
model_b
,
optimizer_b
def
sharding_model
(
self
):
(
model_a
,
optimizer_a
,
model_b
,
optimizer_b
,
)
=
self
.
build_model_optimizer
()
self
.
assertTrue
(
isinstance
(
optimizer_a
.
_inner_opt
,
DygraphShardingOptimizer
)
)
for
idx
in
range
(
STEPS
):
loss_a
=
self
.
train_batch
(
self
.
batch_sharding
,
model_a
,
optimizer_a
)
loss_b
=
self
.
train_batch
(
self
.
batch_single
,
model_b
,
optimizer_b
)
np
.
testing
.
assert_allclose
(
loss_a
,
loss_b
,
rtol
=
1e-6
,
atol
=
1e-6
)
for
j
in
range
(
len
(
model_a
.
parameters
())):
np
.
testing
.
assert_allclose
(
model_a
.
parameters
()[
j
].
numpy
(),
model_b
.
parameters
()[
j
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-7
,
)
def
test_sharding_adam
(
self
):
self
.
sharding_model
()
if
__name__
==
"__main__"
:
unittest
.
main
()
test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
浏览文件 @
4c4d3185
...
@@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus):
...
@@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus):
def
test_hybrid_parallel_sharding_logic
(
self
):
def
test_hybrid_parallel_sharding_logic
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model.py'
)
def
test_hybrid_parallel_sharding_tensor_fusion
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model_with_fusion.py'
)
def
test_hybrid_parallel_sharding_state_dict
(
self
):
def
test_hybrid_parallel_sharding_state_dict
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_state_dict.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_state_dict.py'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录