BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 0c2a51d2
Authored Nov 30, 2020 by WangXi; committed via GitHub on Nov 30, 2020

optimizer amp, all use fp16 communication, overlap last comm and compute (#28957)

Parent: 0b032fae

Showing 5 changed files with 75 additions and 21 deletions (+75 -21)
Changed files:
- paddle/fluid/operators/amp/check_finite_and_unscale_op.cu (+7 -5)
- paddle/fluid/operators/amp/update_loss_scaling_op.cu (+2 -1)
- python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py (+9 -0)
- python/paddle/fluid/contrib/mixed_precision/decorator.py (+28 -12)
- python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py (+29 -3)
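At a high level, the change replaces one check_finite_and_unscale op over all gradients (which had to wait for every allreduce to finish) with one check per gradient, so unscaling one gradient can overlap with the communication of the others, and the per-gradient flags are combined at the end. A minimal NumPy sketch of that per-gradient split, for orientation only; the helper names unscale_and_check and unscale_all are made up for illustration and are not Paddle APIs:

import numpy as np

def unscale_and_check(grad, loss_scaling):
    # Mirrors one split check_finite_and_unscale op: unscale a single
    # gradient and report whether it contains inf/nan.
    unscaled = grad * (1.0 / loss_scaling)
    found_inf = not np.all(np.isfinite(unscaled))
    return unscaled, found_inf

def unscale_all(grads, loss_scaling):
    # One op per gradient, so in the real pipeline each unscale can run
    # while the remaining gradients are still being communicated.
    results = [unscale_and_check(g, loss_scaling) for g in grads]
    unscaled_grads = [g for g, _ in results]
    found_inf = any(flag for _, flag in results)  # concat + reduce_any
    return unscaled_grads, found_inf

grads = [np.array([1.0, 2.0]) * 128.0, np.array([np.inf, 3.0])]
print(unscale_all(grads, loss_scaling=128.0)[1])  # True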
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -20,8 +20,9 @@ namespace paddle {
 namespace operators {

 template <typename T>
-__global__ void GpuInverse(const T* s, T* o) {
+__global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
   *o = Inverse<T>(*s);
+  *found_inf = false;
 }

 template <typename T>
@@ -30,10 +31,11 @@ __global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < num) {
-    if (!isfinite(in[idx])) {
+    T val = in[idx] * (*scale);
+    out[idx] = val;
+    if (!isfinite(val)) {
       *found_inf = true;
     }
-    out[idx] = *found_inf ? in[idx] : in[idx] * (*scale);
   }
 }
@@ -49,13 +51,13 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
     const T* scale_data = scale->data<T>();
     bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
-    cudaMemset(found_inf_data, false, found_inf->numel() * sizeof(bool));
     framework::Tensor inverse_scale =
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
     T* inverse_scale_v = inverse_scale.template data<T>();
-    GpuInverse<T><<<1, 1, 0, dev_ctx.stream()>>>(scale_data, inverse_scale_v);
+    InverseAndMemset<T><<<1, 1, 0, dev_ctx.stream()>>>(
+        scale_data, inverse_scale_v, found_inf_data);

     for (size_t i = 0; i < xs.size(); ++i) {
       const auto* x = xs[i];
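Semantically, the fused InverseAndMemset + CheckFiniteAndUnscale pair now inverts the scale and clears found_inf in one launch (no separate cudaMemset), always writes the unscaled value, and raises the flag if any element is non-finite. A NumPy reference sketch of that behavior, for illustration only and not the CUDA code itself:

import numpy as np

def check_finite_and_unscale_ref(xs, scale):
    # InverseAndMemset: invert the scale and clear the flag in one step.
    inverse_scale = 1.0 / scale
    found_inf = False
    outs = []
    # CheckFiniteAndUnscale: always write the unscaled value; raise the
    # flag if any element of any input turns out non-finite.
    for x in xs:
        out = x * inverse_scale
        if not np.all(np.isfinite(out)):
            found_inf = True
        outs.append(out)
    return outs, found_inf

outs, found_inf = check_finite_and_unscale_ref(
    [np.array([256.0, 512.0]), np.array([np.nan])], scale=128.0)
print(found_inf)  # True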
paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -61,13 +61,14 @@ class LazyZeroInputs<platform::CUDADeviceContext, T> {
     bool has_inf{false};
     memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data,
                  sizeof(bool), dev_ctx.stream());
+    dev_ctx.Wait();  // wait async copy
     if (has_inf) {
       VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
       for (size_t i = 0; i < xs.size(); ++i) {
         auto* out = outs[i];
         T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
         int num = out->numel();
-        cudaMemset(out_data, 0, num * sizeof(T));
+        cudaMemsetAsync(out_data, 0, num * sizeof(T), dev_ctx.stream());
       }
     }
   }
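The host-side pattern here is: copy the device-side found_inf flag to the host, wait for that async copy (the new dev_ctx.Wait()), and only then decide whether to lazily zero the gradients, now via cudaMemsetAsync on the same stream instead of a blocking cudaMemset. A rough NumPy sketch of that control flow, assuming the same inputs and outputs; this is not the CUDA implementation:

import numpy as np

def lazy_zero_inputs_ref(xs, found_inf):
    # In the GPU kernel, found_inf must first be copied device-to-host and
    # the copy waited on (dev_ctx.Wait()); here it is simply a bool.
    outs = [np.array(x, copy=True) for x in xs]
    if found_inf:
        # Infinite values were found in the grads: zero every output,
        # which the GPU code now does with cudaMemsetAsync on the stream.
        for out in outs:
            out[...] = 0
    return outs

print(lazy_zero_inputs_ref([np.array([1.0, 2.0])], found_inf=True))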
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
@@ -53,6 +53,15 @@ class AMPOptimizer(MetaOptimizerBase):
             config['incr_ratio'], config['decr_ratio'],
             config['use_dynamic_loss_scaling'])

+        # if worker_num > 1, all cards will communication with each other,
+        # add is_distributed to optimize amp, overlap communication and
+        # computation by split the check_finite_and_unscale op.
+        is_distributed = self.role_maker._worker_num() > 1
+        if self.user_defined_strategy.sharding:
+            # FIXME(wangxi). sharding failed when split check_finite_and_unscale
+            is_distributed = False
+        self.wrapped_opt._set_distributed(is_distributed)
+
     def _can_apply(self):
         if not self.role_maker._is_collective:
             return False
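This meta optimizer is only reached when fleet's AMP strategy is enabled. A hedged usage sketch of how a collective job would pick up AMPOptimizer and, with worker_num > 1 and sharding disabled, exercise the new _set_distributed(True) path; the exact amp_configs keys may vary by Paddle version, and the loss comes from the user's own static graph:

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.amp_configs = {
    "init_loss_scaling": 32768,
    "use_dynamic_loss_scaling": True,
}

opt = fluid.optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
# fleet wraps the optimizer with AMPOptimizer; with more than one worker
# (and no sharding) check_finite_and_unscale is split per parameter.
opt = fleet.distributed_optimizer(opt, strategy)
# opt.minimize(avg_cost)  # avg_cost is the user's loss variable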
python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -61,6 +61,7 @@ class OptimizerWithMixedPrecision(object):
         self._param_grads = None
         self._train_program = None
+        self._is_distributed = False

         self._scaled_loss = None
         self._loss_scaling = None
         self._init_loss_scaling = init_loss_scaling
@@ -73,6 +74,12 @@ class OptimizerWithMixedPrecision(object):
         self._num_good_steps = None
         self._num_bad_steps = None

+    def _set_distributed(self, flag):
+        # if distributed, all cards will communication with each other,
+        # overlap communication and computation by split the
+        # check_finite_and_unscale op.
+        self._is_distributed = flag
+
     def get_loss_scaling(self):
         """Return the real-time loss scaling factor.
         """
@@ -168,13 +175,28 @@ class OptimizerWithMixedPrecision(object):
         """
         grads = [g for _, g in params_grads]
-        with self._train_program._optimized_guard(grads):
-            grads, found_inf = check_finite_and_unscale(
-                grads, self._loss_scaling, name="find_infinite_scale")
+        if not self._is_distributed:
+            with self._train_program._optimized_guard(grads):
+                grads, found_inf = check_finite_and_unscale(
+                    grads, self._loss_scaling, name="find_infinite_scale")
+        else:
+            # if distributed, split check_finite_and_unscale to overlap
+            # unscale with communication
+            found_infs = []
+            for p, g in params_grads:
+                with self._train_program._optimized_guard([p, g]):
+                    _, found_inf = check_finite_and_unscale(
+                        [g, ], self._loss_scaling, name="find_infinite_scale")
+                found_infs.append(found_inf)

         if self._use_dynamic_loss_scaling:
-            with self._train_program._optimized_guard(grads):
-                grads = update_loss_scaling(
+            if self._is_distributed:
+                with self._train_program._optimized_guard([]):
+                    all_infs = layers.concat(found_infs)
+                    found_inf = layers.reduce_any(all_infs)
+
+            with self._train_program._optimized_guard([]):
+                update_loss_scaling(
                     grads, found_inf, self._loss_scaling,
@@ -186,13 +208,7 @@ class OptimizerWithMixedPrecision(object):
                     self._decr_ratio,
                     name="update_loss_scaling")

-        params_unscaled_grads = []
-        for pg, new_g in zip(params_grads, grads):
-            params_unscaled_grads.append((pg[0], new_g))
-        # apply_gradient append all ops in global block, thus we shouldn't
-        # apply gradient in the switch branch.
-        optimize_ops = self._optimizer.apply_gradients(
-            params_unscaled_grads)
+        optimize_ops = self._optimizer.apply_gradients(params_grads)
         return optimize_ops

     def apply_optimize(self, loss, startup_program, params_grads):
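For reference, the update_loss_scaling call that closes the block above follows the usual dynamic loss-scaling policy driven by incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio and decr_ratio. A minimal Python sketch of that generic policy, given as an illustration of the scheme rather than Paddle's exact implementation:

def update_loss_scaling_ref(found_inf, scaling, good_steps, bad_steps,
                            incr_every_n_steps=1000,
                            decr_every_n_nan_or_inf=2,
                            incr_ratio=2.0, decr_ratio=0.5):
    # Shrink the scale after repeated overflow, grow it after a long run
    # of finite steps; gradients of an overflow step are dropped upstream.
    if found_inf:
        good_steps, bad_steps = 0, bad_steps + 1
        if bad_steps == decr_every_n_nan_or_inf:
            scaling = max(scaling * decr_ratio, 1.0)
            bad_steps = 0
    else:
        good_steps, bad_steps = good_steps + 1, 0
        if good_steps == incr_every_n_steps:
            scaling *= incr_ratio
            good_steps = 0
    return scaling, good_steps, bad_steps

print(update_loss_scaling_ref(True, 32768.0, 5, 1))  # (16384.0, 0, 0)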
python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
@@ -19,6 +19,7 @@ import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_optimizers import AMPOptimizer
 import os
 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
+import paddle.distributed.fleet.base.role_maker as role_maker

 paddle.enable_static()
@@ -32,7 +33,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
+        opt.user_defined_strategy = strategy
         self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)

         ops = [op.type for op in avg_cost.block.ops]
@@ -47,7 +51,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
+        opt.user_defined_strategy = strategy
         self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)

         with fluid.program_guard(train_prog, startup_prog):
             opt.apply_gradients(params_grads)
@@ -64,7 +71,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
+        opt.user_defined_strategy = strategy
         self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)

         opt.apply_optimize(avg_cost, startup_prog, params_grads)
@@ -83,6 +93,22 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         self.assertIn('cast', ops)
         self.assertIn('check_finite_and_unscale', ops)

+    def test_amp_distributed_optimizer(self):
+        """ test amp when distributed """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        check_count = 0
+        for name in ops:
+            if name == 'check_finite_and_unscale':
+                check_count += 1
+        self.assertEqual(check_count, len(train_prog.all_parameters()))

     def test_amp_recompute_optimizer(self):
         """ test amp + recompute """
         train_prog, startup_prog = fluid.Program(), fluid.Program()