Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PLSC
提交
1cf2e6d4
P
PLSC
项目概览
PaddlePaddle
/
PLSC
通知
10
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
5
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PLSC
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
5
Issue
5
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
1cf2e6d4
编写于
12月 31, 2019
作者:
D
danleifeng
提交者:
lilong12
12月 31, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
edit mixed precision user interface (#20)
* edit fp16 user interface * edit fp16 doc
上级
f0978c37
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
94 addition
and
55 deletion
+94
-55
docs/mixed_precision.md
docs/mixed_precision.md
+24
-7
plsc/entry.py
plsc/entry.py
+66
-45
plsc/models/dist_algo.py
plsc/models/dist_algo.py
+4
-3
未找到文件。
docs/mixed_precision.md
浏览文件 @
1cf2e6d4
...
...
@@ -10,25 +10,42 @@ PLSC支持混合精度训练。使用混合精度训练可以提升训练的速
from
plsc
import
Entry
def
main
():
ins
=
Entry
()
ins
.
set_mixed_precision
(
True
,
1.0
)
ins
.
set_mixed_precision
(
True
)
ins
.
train
()
if
__name__
==
"__main__"
:
main
()
```
其中,
`set_mixed_precision`
函数介绍如下:
| API | 描述
| 参数说明
|
|
:------------------- | :--------------------| :----------------------
|
| set_mixed_precision
(use_fp16, loss_scaling) | 设置混合精度训练 |
`use_fp16`
为是否开启混合精度训练,默认为False;
`loss_scaling`
为初始的损失缩放值,默认为1.0|
| API | 描述 |
|
--- | ---
|
| set_mixed_precision
| 设置混合精度训练
-
`use_fp16`
:bool类型,当想要开启混合精度训练时,可将此参数设为True即可。
-
`loss_scaling`
:float类型,为初始的损失缩放值,这个值有可能会影响混合精度训练的精度,建议设为默认值1.0。
## 参数说明
set_mixed_precision 函数提供7个参数,其中use_fp16为必选项,决定是否开启混合精度训练,其他6个参数均有默认值,具体说明如下:
为了提高混合精度训练的稳定性和精度,默认开启了动态损失缩放机制。更多关于混合精度训练的介绍可参考:
[
混合精度训练
](
https://arxiv.org/abs/1710.03740
)
| 参数 | 类型 | 默认值| 说明
| --- | --- | ---|---|
|use_fp16| bool | 无,需用户设定| 是否开启混合精度训练,设为True为开启混合精度训练
|init_loss_scaling| float | 1.0|初始的损失缩放值,这个值有可能会影响混合精度训练的精度,建议设为默认值
|incr_every_n_steps | int | 2000|累计迭代
`incr_every_n_steps`
步都没出现FP16的越界,loss_scaling则会增加
`incr_ratio`
倍,建议设为默认值
|decr_every_n_nan_or_inf| int | 2|累计迭代
`decr_every_n_nan_or_inf`
步出现了FP16的越界,loss_scaling则会缩小为原来的
`decr_ratio`
倍,建议设为默认值
|incr_ratio |float|2.0|扩大loss_scaling的倍数,建议设为默认值
|decr_ratio| float |0.5| 缩小loss_scaling的倍数,建议设为默认值
|use_dynamic_loss_scaling | bool | True| 是否使用动态损失缩放机制。如果开启,才会用到
`incr_every_n_steps`
,
`decr_every_n_nan_or_inf`
,
`incr_ratio`
,
`decr_ratio`
四个参数,开启会提高混合精度训练的稳定性和精度,建议设为默认值
|amp_lists|AutoMixedPrecisionLists类|None|自动混合精度列表类,可以指定具体使用fp16计算的operators列表,建议设为默认值
更多关于混合精度训练的介绍可参考:
-
Paper:
[
MIXED PRECISION TRAINING
](
https://arxiv.org/abs/1710.03740
)
-
Nvidia Introduction:
[
Training With Mixed Precision
](
https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
)
## 训练性能
配置: Nvidia Tesla v100 GPU 单机8卡
| 模型
\速
度 | FP32训练 | 混合精度训练 | 加速比 |
| --- | --- | --- | --- |
| ResNet50 | 2567.96 images/s | 3643.11 images/s | 1.42 |
备注:上述模型训练使用的loss_type均为'dist_arcface'。
plsc/entry.py
浏览文件 @
1cf2e6d4
...
...
@@ -100,7 +100,6 @@ class Entry(object):
self
.
fs_dir
=
None
self
.
use_fp16
=
False
self
.
init_loss_scaling
=
1.0
self
.
fp16_user_dict
=
None
self
.
val_targets
=
self
.
config
.
val_targets
...
...
@@ -149,16 +148,30 @@ class Entry(object):
self
.
global_train_batch_size
=
batch_size
*
self
.
num_trainers
logger
.
info
(
"Set train batch size to {}."
.
format
(
batch_size
))
def
set_mixed_precision
(
self
,
use_fp16
,
loss_scaling
):
def
set_mixed_precision
(
self
,
use_fp16
,
init_loss_scaling
=
1.0
,
incr_every_n_steps
=
2000
,
decr_every_n_nan_or_inf
=
2
,
incr_ratio
=
2.0
,
decr_ratio
=
0.5
,
use_dynamic_loss_scaling
=
True
,
amp_lists
=
None
):
"""
Whether to use mixed precision training.
"""
self
.
use_fp16
=
use_fp16
self
.
init_loss_scaling
=
loss_scaling
self
.
fp16_user_dict
=
dict
()
self
.
fp16_user_dict
[
'init_loss_scaling'
]
=
self
.
init_loss_scaling
self
.
fp16_user_dict
[
'init_loss_scaling'
]
=
init_loss_scaling
self
.
fp16_user_dict
[
'incr_every_n_steps'
]
=
incr_every_n_steps
self
.
fp16_user_dict
[
'decr_every_n_nan_or_inf'
]
=
decr_every_n_nan_or_inf
self
.
fp16_user_dict
[
'incr_ratio'
]
=
incr_ratio
self
.
fp16_user_dict
[
'decr_ratio'
]
=
decr_ratio
self
.
fp16_user_dict
[
'use_dynamic_loss_scaling'
]
=
use_dynamic_loss_scaling
self
.
fp16_user_dict
[
'amp_lists'
]
=
amp_lists
logger
.
info
(
"Use mixed precision training: {}."
.
format
(
use_fp16
))
logger
.
info
(
"Set init loss scaling to {}."
.
format
(
loss_scaling
))
for
key
in
self
.
fp16_user_dict
:
logger
.
info
(
"Set init {} to {}."
.
format
(
key
,
self
.
fp16_user_dict
[
key
]))
def
set_test_batch_size
(
self
,
batch_size
):
self
.
test_batch_size
=
batch_size
...
...
@@ -313,7 +326,15 @@ class Entry(object):
fp16_user_dict
=
self
.
fp16_user_dict
)
elif
self
.
use_fp16
:
self
.
optimizer
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
optimizer
=
optimizer
,
init_loss_scaling
=
self
.
init_loss_scaling
)
optimizer
=
optimizer
,
init_loss_scaling
=
self
.
fp16_user_dict
[
'init_loss_scaling'
],
incr_every_n_steps
=
self
.
fp16_user_dict
[
'incr_every_n_steps'
],
decr_every_n_nan_or_inf
=
self
.
fp16_user_dict
[
'decr_every_n_nan_or_inf'
],
incr_ratio
=
self
.
fp16_user_dict
[
'incr_ratio'
],
decr_ratio
=
self
.
fp16_user_dict
[
'decr_ratio'
],
use_dynamic_loss_scaling
=
self
.
fp16_user_dict
[
'use_dynamic_loss_scaling'
],
amp_lists
=
self
.
fp16_user_dict
[
'amp_lists'
]
)
return
self
.
optimizer
def
build_program
(
self
,
...
...
plsc/models/dist_algo.py
浏览文件 @
1cf2e6d4
...
...
@@ -41,13 +41,13 @@ class DistributedClassificationOptimizer(Optimizer):
def
init_fp16_params
(
self
,
loss_type
,
fp16_user_dict
):
# set default value for fp16_params_dict
fp16_params_dict
=
dict
()
fp16_params_dict
[
'amp_lists'
]
=
None
fp16_params_dict
[
'init_loss_scaling'
]
=
1.0
fp16_params_dict
[
'incr_every_n_steps'
]
=
1000
fp16_params_dict
[
'decr_every_n_nan_or_inf'
]
=
2
fp16_params_dict
[
'incr_ratio'
]
=
2.0
fp16_params_dict
[
'decr_ratio'
]
=
0.5
fp16_params_dict
[
'use_dynamic_loss_scaling'
]
=
True
fp16_params_dict
[
'amp_lists'
]
=
None
if
fp16_user_dict
is
not
None
:
# update fp16_params_dict
for
key
in
fp16_user_dict
:
...
...
@@ -56,8 +56,9 @@ class DistributedClassificationOptimizer(Optimizer):
else
:
logging
.
warning
(
"Can't find name '%s' in our fp16_params_dict. "
"Please check your dict key. You can set fp16 params only "
"in [amp_lists, init_loss_scaling, decr_every_n_nan_or_inf, "
"incr_ratio, decr_ratio, use_dynamic_loss_scaling]."
%
(
key
))
"in [init_loss_scaling, incr_every_n_steps, "
"decr_every_n_nan_or_inf, incr_ratio, decr_ratio, "
"use_dynamic_loss_scaling, amp_lists]"
%
(
key
))
self
.
_amp_lists
=
fp16_params_dict
[
'amp_lists'
]
if
self
.
_amp_lists
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录