Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6a3941e3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6a3941e3
编写于
10月 20, 2021
作者:
H
Haohongxiang
提交者:
GitHub
10月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs of ClipGradByGlobalNorm in HybridParallel (#36555)
* fix bugs of ClipGradByGlobalNorm * add unittests * add unittests
上级
17b4dd70
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
128 addition
and
20 deletion
+128
-20
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
...optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+58
-20
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py
...n/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py
+59
-0
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
...on/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
+4
-0
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
...n/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
+4
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py
.../tests/unittests/test_parallel_dygraph_tensor_parallel.py
+3
-0
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
浏览文件 @
6a3941e3
...
...
@@ -50,8 +50,11 @@ class HybridParallelClipGrad:
@
imperative_base
.
no_grad
def
_dygraph_clip
(
self
,
params_grads
):
params_and_grads
=
[]
sum_square_list_dist
=
[]
sum_square_list_not_dist
=
[]
sum_square_dist_fp16
=
[]
sum_square_dist_fp32
=
[]
sum_square_not_dist_fp16
=
[]
sum_square_not_dist_fp32
=
[]
for
p
,
g
in
params_grads
:
if
g
is
None
:
...
...
@@ -71,20 +74,51 @@ class HybridParallelClipGrad:
if
not_shared_enable
:
if
p
.
is_distributed
:
sum_square_list_dist
.
append
(
sum_square
)
if
p
.
dtype
==
paddle
.
float16
:
sum_square_dist_fp16
.
append
(
sum_square
)
elif
p
.
dtype
==
paddle
.
float32
:
sum_square_dist_fp32
.
append
(
sum_square
)
else
:
sum_square_list_not_dist
.
append
(
sum_square
)
global_norm_var_dist
=
layers
.
concat
(
sum_square_list_dist
)
if
len
(
sum_square_list_dist
)
!=
0
else
layers
.
concat
(
[
paddle
.
to_tensor
([
0.
])])
global_norm_var_dist
=
layers
.
reduce_sum
(
global_norm_var_dist
)
global_norm_var_not_dist
=
layers
.
concat
(
sum_square_list_not_dist
)
if
len
(
sum_square_list_not_dist
)
!=
0
else
layers
.
concat
(
[
paddle
.
to_tensor
([
0.
])])
global_norm_var_not_dist
=
layers
.
reduce_sum
(
global_norm_var_not_dist
)
if
p
.
dtype
==
paddle
.
float16
:
sum_square_not_dist_fp16
.
append
(
sum_square
)
elif
p
.
dtype
==
paddle
.
float32
:
sum_square_not_dist_fp32
.
append
(
sum_square
)
# global norm of distributed FP16 params_and_grads
if
len
(
sum_square_dist_fp16
)
==
0
:
global_norm_dist_fp16
=
paddle
.
to_tensor
([
0.
],
dtype
=
paddle
.
float32
)
else
:
global_norm_dist_fp16
=
layers
.
concat
(
sum_square_dist_fp16
)
global_norm_dist_fp16
=
layers
.
reduce_sum
(
global_norm_dist_fp16
)
global_norm_dist_fp16
=
paddle
.
cast
(
global_norm_dist_fp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP16 params_and_grads
if
len
(
sum_square_not_dist_fp16
)
==
0
:
global_norm_not_dist_fp16
=
paddle
.
to_tensor
(
[
0.
],
dtype
=
paddle
.
float32
)
else
:
global_norm_not_dist_fp16
=
layers
.
concat
(
sum_square_not_dist_fp16
)
global_norm_not_dist_fp16
=
layers
.
reduce_sum
(
global_norm_not_dist_fp16
)
global_norm_not_dist_fp16
=
paddle
.
cast
(
global_norm_not_dist_fp16
,
dtype
=
paddle
.
float32
)
# global norm of distributed FP32 params_and_grads
global_norm_dist_fp32
=
layers
.
concat
(
sum_square_dist_fp32
)
if
len
(
sum_square_dist_fp32
)
!=
0
else
paddle
.
to_tensor
(
[
0.
],
dtype
=
paddle
.
float32
)
global_norm_dist_fp32
=
layers
.
reduce_sum
(
global_norm_dist_fp32
)
# global norm of non-distributed FP32 params_and_grads
global_norm_not_dist_fp32
=
layers
.
concat
(
sum_square_not_dist_fp32
)
if
len
(
sum_square_not_dist_fp32
)
!=
0
else
paddle
.
to_tensor
(
[
0.
],
dtype
=
paddle
.
float32
)
global_norm_not_dist_fp32
=
layers
.
reduce_sum
(
global_norm_not_dist_fp32
)
global_norm_var_dist
=
global_norm_dist_fp16
+
global_norm_dist_fp32
global_norm_var_not_dist
=
global_norm_not_dist_fp16
+
global_norm_not_dist_fp32
# add all reduce to get global norm of distributed params_and_grads
if
self
.
_hcg
.
get_model_parallel_world_size
()
>
1
:
...
...
@@ -105,22 +139,26 @@ class HybridParallelClipGrad:
global_norm_var_not_dist
,
group
=
self
.
_hcg
.
get_sharding_parallel_group
())
global_norm_var
=
layers
.
sqrt
(
global_norm_var_dist
+
global_norm_var_not_dist
)
global_norm_var
_fp32
=
layers
.
sqrt
(
global_norm_var_dist
+
global_norm_var_not_dist
)
max_global_norm
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
global_norm_var
.
dtype
,
value
=
self
.
clip_norm
)
shape
=
[
1
],
dtype
=
global_norm_var
_fp32
.
dtype
,
value
=
self
.
clip_norm
)
clip_var
=
layers
.
elementwise_div
(
x
=
max_global_norm
,
y
=
layers
.
elementwise_max
(
x
=
global_norm_var
,
y
=
max_global_norm
))
x
=
global_norm_var_fp32
,
y
=
max_global_norm
))
clip_var_fp16
=
paddle
.
cast
(
clip_var
,
paddle
.
float16
)
for
p
,
g
in
params_grads
:
if
g
is
None
:
continue
if
getattr
(
p
,
'need_clip'
,
True
)
is
False
:
params_and_grads
.
append
((
p
,
g
))
continue
new_grad
=
layers
.
elementwise_mul
(
x
=
g
,
y
=
clip_var
)
if
p
.
dtype
==
paddle
.
float16
:
new_grad
=
layers
.
elementwise_mul
(
x
=
g
,
y
=
clip_var_fp16
)
else
:
new_grad
=
layers
.
elementwise_mul
(
x
=
g
,
y
=
clip_var
)
params_and_grads
.
append
((
p
,
new_grad
))
return
params_and_grads
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py
0 → 100644
浏览文件 @
6a3941e3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
numpy
as
np
from
hybrid_parallel_mp_model
import
TestDistMPTraning
import
paddle.distributed.fleet
as
fleet
import
unittest
class
TestMPFP16
(
TestDistMPTraning
):
def
build_optimizer
(
self
,
model
):
grad_clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
1.0
)
scheduler
=
paddle
.
optimizer
.
lr
.
ExponentialDecay
(
learning_rate
=
0.001
,
gamma
=
0.999
,
verbose
=
True
)
optimizer
=
paddle
.
optimizer
.
SGD
(
scheduler
,
grad_clip
=
grad_clip
,
parameters
=
model
.
parameters
())
model
,
optimizer
=
paddle
.
amp
.
decorate
(
models
=
model
,
optimizers
=
optimizer
,
level
=
'O2'
,
save_dtype
=
'float32'
)
return
optimizer
def
train_batch
(
self
,
batch
,
model
,
optimizer
,
is_mp
):
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
5160
)
if
is_mp
:
scaler
=
fleet
.
distributed_scaler
(
scaler
)
with
paddle
.
amp
.
auto_cast
(
enable
=
True
,
level
=
"O2"
):
output
=
model
(
batch
)
loss
=
output
.
mean
()
scaled
=
scaler
.
scale
(
loss
)
scaled
.
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
return
scaled
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
浏览文件 @
6a3941e3
...
...
@@ -61,11 +61,14 @@ class TestDistPPTraning(unittest.TestCase):
rank_id
=
dist
.
get_rank
()
set_random_seed
(
1024
,
dp_id
,
rank_id
)
grad_clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
1.0
)
#construct model a
model_a
=
AlexNet
(
10
)
scheduler_a
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer_a
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler_a
,
grad_clip
=
grad_clip
,
parameters
=
model_a
.
parameters
())
scaler_a
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
2
**
5
)
...
...
@@ -80,6 +83,7 @@ class TestDistPPTraning(unittest.TestCase):
scheduler_b
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer_b
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler_b
,
grad_clip
=
grad_clip
,
parameters
=
model_b
.
parameters
())
model_b
=
fleet
.
distributed_model
(
model_b
)
optimizer_b
=
fleet
.
distributed_optimizer
(
optimizer_b
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
浏览文件 @
6a3941e3
...
...
@@ -61,11 +61,14 @@ class TestDistPPTraning(unittest.TestCase):
rank_id
=
dist
.
get_rank
()
set_random_seed
(
1024
,
dp_id
,
rank_id
)
grad_clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
1.0
)
#construct model a
model_a
=
AlexNet
(
10
)
scheduler_a
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer_a
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler_a
,
grad_clip
=
grad_clip
,
parameters
=
model_a
.
parameters
())
scaler_a
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
2
**
5
)
...
...
@@ -75,6 +78,7 @@ class TestDistPPTraning(unittest.TestCase):
scheduler_b
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer_b
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler_b
,
grad_clip
=
grad_clip
,
parameters
=
model_b
.
parameters
())
param_len
=
len
(
model_a
.
parameters
())
...
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py
浏览文件 @
6a3941e3
...
...
@@ -30,6 +30,9 @@ class TestHybridParallel(TestMultipleGpus):
def
test_hybrid_parallel_mp_amp
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_mp_amp.py'
)
def
test_hybrid_parallel_mp_fp16
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_mp_fp16.py'
)
def
test_hybrid_parallel_mp_clip_grad
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_mp_clip_grad.py'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录