PaddlePaddle / PLSC — Commit 1cf2e6d4
Authored Dec 31, 2019 by danleifeng; committed by lilong12 on Dec 31, 2019
edit mixed precision user interface (#20)

* edit fp16 user interface
* edit fp16 doc

Parent: f0978c37
Showing 3 changed files with 94 additions and 55 deletions (+94 -55)
docs/mixed_precision.md       +24 -7
plsc/entry.py                 +66 -45
plsc/models/dist_algo.py      +4  -3
docs/mixed_precision.md
@@ -10,25 +10,42 @@ PLSC supports mixed precision training. Using mixed precision training can improve the training speed ...
 from plsc import Entry
 
 def main():
     ins = Entry()
-    ins.set_mixed_precision(True, 1.0)
+    ins.set_mixed_precision(True)
     ins.train()
 
 if __name__ == "__main__":
     main()
 ```
 
 The `set_mixed_precision` function is described below:
 
-| API | Description | Parameters |
-| :------------------- | :-------------------- | :---------------------- |
-| set_mixed_precision(use_fp16, loss_scaling) | Configure mixed precision training | `use_fp16`: whether to enable mixed precision training, default False; `loss_scaling`: the initial loss scaling value, default 1.0 |
-
-- `use_fp16`: bool. Set this parameter to True to enable mixed precision training.
-- `loss_scaling`: float. The initial loss scaling value; it may affect the accuracy of mixed precision training, and the default of 1.0 is recommended.
-
-To improve the stability and accuracy of mixed precision training, dynamic loss scaling is enabled by default. For more about mixed precision training, see: [Mixed Precision Training](https://arxiv.org/abs/1710.03740)
+| API | Description |
+| --- | --- |
+| set_mixed_precision | Configure mixed precision training |
+
+## Parameters
+
+The set_mixed_precision function takes 7 parameters. use_fp16 is required and controls whether mixed precision training is enabled; the other 6 parameters all have default values:
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| use_fp16 | bool | none, must be set by the user | Whether to enable mixed precision training; set it to True to enable |
+| init_loss_scaling | float | 1.0 | The initial loss scaling value; it may affect the accuracy of mixed precision training, and the default is recommended |
+| incr_every_n_steps | int | 2000 | If no FP16 overflow occurs for `incr_every_n_steps` consecutive iterations, loss_scaling is increased by a factor of `incr_ratio`; the default is recommended |
+| decr_every_n_nan_or_inf | int | 2 | If FP16 overflow occurs in `decr_every_n_nan_or_inf` accumulated iterations, loss_scaling is shrunk to `decr_ratio` times its value; the default is recommended |
+| incr_ratio | float | 2.0 | The factor by which loss_scaling is increased; the default is recommended |
+| decr_ratio | float | 0.5 | The factor by which loss_scaling is decreased; the default is recommended |
+| use_dynamic_loss_scaling | bool | True | Whether to use dynamic loss scaling. Only when it is enabled are `incr_every_n_steps`, `decr_every_n_nan_or_inf`, `incr_ratio`, and `decr_ratio` used; enabling it improves the stability and accuracy of mixed precision training, and the default is recommended |
+| amp_lists | AutoMixedPrecisionLists | None | The auto mixed precision lists object, which can specify the operators to compute in FP16; the default is recommended |
+
+For more about mixed precision training, see:
+
+- Paper: [MIXED PRECISION TRAINING](https://arxiv.org/abs/1710.03740)
+- Nvidia Introduction: [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
 
 ## Training performance
 
 Setup: Nvidia Tesla V100 GPU, single node with 8 GPUs
 
 | Model \ Speed | FP32 training | Mixed precision training | Speedup |
 | --- | --- | --- | --- |
 | ResNet50 | 2567.96 images/s | 3643.11 images/s | 1.42 |
 
 Note: the above results were trained with loss_type 'dist_arcface'.
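Note: for readers who want to override the defaults, a minimal sketch of the new keyword interface follows. Only `use_fp16` is required; the override values below are illustrative rather than taken from this commit.

```python
from plsc import Entry


def main():
    ins = Entry()
    # Enable FP16 training and override two of the dynamic loss-scaling
    # defaults; any argument that is left out keeps the default value
    # listed in the parameter table above.
    ins.set_mixed_precision(True,
                            init_loss_scaling=128.0,
                            incr_every_n_steps=1000)
    ins.train()


if __name__ == "__main__":
    main()
```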
plsc/entry.py
@@ -48,7 +48,7 @@ logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s',
     datefmt='%d %b %Y %H:%M:%S')
 logger = logging.getLogger(__name__)
 
 
 class Entry(object):
     """
@@ -100,7 +100,6 @@ class Entry(object):
         self.fs_dir = None
 
         self.use_fp16 = False
-        self.init_loss_scaling = 1.0
         self.fp16_user_dict = None
 
         self.val_targets = self.config.val_targets
@@ -149,16 +148,30 @@ class Entry(object):
         self.global_train_batch_size = batch_size * self.num_trainers
         logger.info("Set train batch size to {}.".format(batch_size))
 
-    def set_mixed_precision(self, use_fp16, loss_scaling):
+    def set_mixed_precision(self,
+                            use_fp16,
+                            init_loss_scaling=1.0,
+                            incr_every_n_steps=2000,
+                            decr_every_n_nan_or_inf=2,
+                            incr_ratio=2.0,
+                            decr_ratio=0.5,
+                            use_dynamic_loss_scaling=True,
+                            amp_lists=None):
         """
         Whether to use mixed precision training.
         """
         self.use_fp16 = use_fp16
-        self.init_loss_scaling = loss_scaling
         self.fp16_user_dict = dict()
-        self.fp16_user_dict['init_loss_scaling'] = self.init_loss_scaling
+        self.fp16_user_dict['init_loss_scaling'] = init_loss_scaling
+        self.fp16_user_dict['incr_every_n_steps'] = incr_every_n_steps
+        self.fp16_user_dict['decr_every_n_nan_or_inf'] = decr_every_n_nan_or_inf
+        self.fp16_user_dict['incr_ratio'] = incr_ratio
+        self.fp16_user_dict['decr_ratio'] = decr_ratio
+        self.fp16_user_dict['use_dynamic_loss_scaling'] = use_dynamic_loss_scaling
+        self.fp16_user_dict['amp_lists'] = amp_lists
         logger.info("Use mixed precision training: {}.".format(use_fp16))
-        logger.info("Set init loss scaling to {}.".format(loss_scaling))
+        for key in self.fp16_user_dict:
+            logger.info("Set init {} to {}.".format(key,
+                                                    self.fp16_user_dict[key]))
 
     def set_test_batch_size(self, batch_size):
         self.test_batch_size = batch_size
@@ -168,7 +181,7 @@ class Entry(object):
     def set_hdfs_info(self, fs_name, fs_ugi, directory):
         """
         Set the info to download from or upload to hdfs filesystems.
         If the information is provided, we will download pretrained
         model from hdfs at the begining and upload pretrained models
         to hdfs at the end automatically.
         """
@@ -281,13 +294,13 @@ class Entry(object):
         if not self.optimizer:
             bd = [step for step in self.lr_steps]
             start_lr = self.lr
 
             global_batch_size = self.global_train_batch_size
             train_image_num = self.train_image_num
             images_per_trainer = int(math.ceil(
                 train_image_num * 1.0 / self.num_trainers))
             steps_per_pass = int(math.ceil(
                 images_per_trainer * 1.0 / self.train_batch_size))
             logger.info("Steps per epoch: %d" % steps_per_pass)
             warmup_steps = steps_per_pass * self.warmup_epochs
             batch_denom = 1024
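Note: a quick worked example of the step arithmetic in the hunk above, using hypothetical dataset and batch sizes (not values from this commit):

```python
import math

# Hypothetical values, only to illustrate the computation above.
train_image_num = 5822653   # size of an illustrative training set
num_trainers = 8            # one GPU per trainer
train_batch_size = 128      # per-trainer batch size
warmup_epochs = 2

images_per_trainer = int(math.ceil(train_image_num * 1.0 / num_trainers))
steps_per_pass = int(math.ceil(images_per_trainer * 1.0 / train_batch_size))
warmup_steps = steps_per_pass * warmup_epochs
print(steps_per_pass, warmup_steps)  # 5687 11374
```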
@@ -300,12 +313,12 @@ class Entry(object):
                     values=lr),
                 warmup_steps, start_lr, base_lr)
             else:
                 lr_val = fluid.layers.piecewise_decay(boundaries=bd,
                                                       values=lr)
 
             optimizer = fluid.optimizer.Momentum(
                 learning_rate=lr_val, momentum=0.9,
                 regularization=fluid.regularizer.L2Decay(5e-4))
         self.optimizer = optimizer
 
         if self.loss_type in ["dist_softmax", "dist_arcface"]:
             self.optimizer = DistributedClassificationOptimizer(
                 self.optimizer, global_batch_size, use_fp16=self.use_fp16,
@@ -313,7 +326,15 @@ class Entry(object):
                 fp16_user_dict=self.fp16_user_dict)
         elif self.use_fp16:
             self.optimizer = fluid.contrib.mixed_precision.decorate(
-                optimizer=optimizer, init_loss_scaling=self.init_loss_scaling)
+                optimizer=optimizer,
+                init_loss_scaling=self.fp16_user_dict['init_loss_scaling'],
+                incr_every_n_steps=self.fp16_user_dict['incr_every_n_steps'],
+                decr_every_n_nan_or_inf=self.fp16_user_dict['decr_every_n_nan_or_inf'],
+                incr_ratio=self.fp16_user_dict['incr_ratio'],
+                decr_ratio=self.fp16_user_dict['decr_ratio'],
+                use_dynamic_loss_scaling=self.fp16_user_dict['use_dynamic_loss_scaling'],
+                amp_lists=self.fp16_user_dict['amp_lists']
+            )
         return self.optimizer
 
     def build_program(self,
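Note: for the non-distributed loss types, the branch above is just Paddle's standard AMP wrapper. A minimal sketch of decorating a plain optimizer with the same user dictionary (the helper name is illustrative; the `fluid.contrib.mixed_precision.decorate` keyword arguments are the ones visible in this diff):

```python
import paddle.fluid as fluid


def decorate_with_fp16(optimizer, fp16_user_dict):
    # Wrap a plain optimizer so training runs with automatic mixed precision,
    # forwarding the settings collected by Entry.set_mixed_precision().
    return fluid.contrib.mixed_precision.decorate(
        optimizer=optimizer,
        init_loss_scaling=fp16_user_dict['init_loss_scaling'],
        incr_every_n_steps=fp16_user_dict['incr_every_n_steps'],
        decr_every_n_nan_or_inf=fp16_user_dict['decr_every_n_nan_or_inf'],
        incr_ratio=fp16_user_dict['incr_ratio'],
        decr_ratio=fp16_user_dict['decr_ratio'],
        use_dynamic_loss_scaling=fp16_user_dict['use_dynamic_loss_scaling'],
        amp_lists=fp16_user_dict['amp_lists'])
```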
@@ -428,7 +449,7 @@ class Entry(object):
                                    stdout=sys.stdout,
                                    stderr=subprocess.STDOUT)
         process.wait()
 
         for file in os.listdir(local_dir):
             if "dist@" in file and "@rank@" in file:
                 file = os.path.join(local_dir, file)
@@ -442,10 +463,10 @@ class Entry(object):
     def _append_broadcast_ops(self, program):
         """
         Before test, we broadcast bathnorm-related parameters to all
         other trainers from trainer-0.
         """
         bn_vars = [var for var in program.list_vars()
                    if 'batch_norm' in var.name and var.persistable]
         block = program.current_block()
         for var in bn_vars:
@@ -475,7 +496,7 @@ class Entry(object):
                 shutil.rmtree(checkpoint_dir)
             os.makedirs(checkpoint_dir)
 
         # sync all trainers to avoid loading checkpoints before
         # parameters are downloaded
         file_name = os.path.join(checkpoint_dir, '.lock')
         if self.trainer_id == 0:
@@ -483,14 +504,14 @@ class Entry(object):
             with open(file_name, 'w') as f:
                 pass
             time.sleep(10)
             os.remove(file_name)
         else:
             while True:
                 if not os.path.exists(file_name):
                     time.sleep(1)
                 else:
                     break
 
         # Preporcess distributed parameters.
         file_name = os.path.join(checkpoint_dir, '.lock')
         distributed = self.loss_type in ["dist_softmax", "dist_arcface"]
@@ -499,7 +520,7 @@ class Entry(object):
             with open(file_name, 'w') as f:
                 pass
             time.sleep(10)
             os.remove(file_name)
         elif load_for_train and distributed:
             # wait trainer_id (0) to complete
             while True:
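Note: the two hunks above use a simple lock-file handshake to keep trainers in step while checkpoints are downloaded and preprocessed. A standalone sketch of the same pattern (the function name and timing values are illustrative, not part of PLSC):

```python
import os
import time


def signal_and_wait(checkpoint_dir, trainer_id, hold_seconds=10):
    """Trainer 0 creates a .lock file to signal that its step is done, keeps
    it visible long enough for the others to notice, then removes it; the
    remaining trainers poll until the signal file appears."""
    lock_path = os.path.join(checkpoint_dir, '.lock')
    if trainer_id == 0:
        with open(lock_path, 'w'):
            pass                      # signal: trainer 0 has finished
        time.sleep(hold_seconds)      # keep the signal visible to other trainers
        os.remove(lock_path)          # clean up for the next synchronization
    else:
        while not os.path.exists(lock_path):
            time.sleep(1)             # wait for trainer 0's signal
```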
@@ -600,7 +621,7 @@ class Entry(object):
         feeder = fluid.DataFeeder(place=place,
                                   feed_list=['image', 'label'],
                                   program=main_program)
         fetch_list = [emb.name]
         for data in predict_reader():
             emb = exe.run(main_program,
                           feed=feeder.feed(data),
@@ -674,33 +695,33 @@ class Entry(object):
                         embeddings = np.zeros((data.shape[0],
                                                _embeddings.shape[1]))
                     embeddings[start:start + real_test_batch_size, :] = _embeddings[:, :]
                 beg = parallel_test_steps * real_test_batch_size
 
                 while beg < data.shape[0]:
                     end = min(beg + self.test_batch_size, data.shape[0])
                     count = end - beg
                     _data = []
                     for k in xrange(end - self.test_batch_size, end):
                         _data.append((data[k], 0))
                     [_embeddings] = exe.run(test_program,
                                             fetch_list=fetch_list,
                                             feed=feeder.feed(_data),
                                             use_program_cache=True)
                     _embeddings = _embeddings[0:self.test_batch_size, :]
                     embeddings[beg:end, :] = _embeddings[(self.test_batch_size - count):, :]
                     beg = end
                 embeddings_list.append(embeddings)
 
             xnorm = 0.0
             xnorm_cnt = 0
             for embed in embeddings_list:
                 xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
                 xnorm_cnt += embed.shape[0]
             xnorm /= xnorm_cnt
 
             embeddings = embeddings_list[0] + embeddings_list[1]
             embeddings = sklearn.preprocessing.normalize(embeddings)
             _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10)
             acc, std = np.mean(accuracy), np.std(accuracy)
             print('[%s][%d]XNorm: %f' % (test_name_list[i], pass_id, xnorm))
             print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (test_name_list[i], pass_id, acc, std))
             sys.stdout.flush()
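Note: the XNorm value printed above is the mean L2 norm of the un-normalized embeddings over both passes of the test set. A small standalone numpy sketch of the same computation, using random data purely for illustration:

```python
import numpy as np

# Two embedding matrices standing in for embeddings_list above; random data.
rng = np.random.default_rng(0)
embeddings_list = [rng.normal(size=(6, 512)), rng.normal(size=(6, 512))]

xnorm = 0.0
xnorm_cnt = 0
for embed in embeddings_list:
    # L2 norm of every row, summed over the batch
    xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
    xnorm_cnt += embed.shape[0]
xnorm /= xnorm_cnt  # average embedding norm across both matrices
print("XNorm: %f" % xnorm)
```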
@@ -712,7 +733,7 @@ class Entry(object):
         trainer_id = self.trainer_id
         num_trainers = self.num_trainers
 
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
         fleet.init(role)
 
         strategy = DistributedStrategy()
@@ -720,7 +741,7 @@ class Entry(object):
         strategy.collective_mode = "grad_allreduce"
         self.fleet = fleet
         self.strategy = strategy
 
         train_emb, train_loss, train_acc1, train_acc5, optimizer = \
             self.build_program(True, False)
         if self.with_test:
@@ -730,10 +751,10 @@ class Entry(object):
                 self.dataset_dir, self.val_targets)
             test_program = self.test_program
             self._append_broadcast_ops(test_program)
 
         global_lr = optimizer._global_learning_rate(
             program=self.train_program)
 
         origin_prog = fleet._origin_program
         train_prog = fleet.main_program
         if trainer_id == 0:
@@ -745,25 +766,25 @@ class Entry(object):
                 program_to_code(origin_prog, fout, True)
             with open('test.program', 'w') as fout:
                 program_to_code(test_program, fout, True)
 
         gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
         place = fluid.CUDAPlace(gpu_id)
         exe = fluid.Executor(place)
         exe.run(self.startup_program)
 
         if self.with_test:
             test_feeder = fluid.DataFeeder(place=place,
                                            feed_list=['image', 'label'],
                                            program=test_program)
             fetch_list_test = [test_emb.name]
             real_test_batch_size = self.global_test_batch_size
 
         if self.checkpoint_dir:
             load_checkpoint = True
         else:
             load_checkpoint = False
         if load_checkpoint:
             self.load_checkpoint(executor=exe, main_program=origin_prog)
 
         if self.train_reader is None:
             train_reader = paddle.batch(reader.arc_train(
                 self.dataset_dir, self.num_classes),
@@ -773,13 +794,13 @@ class Entry(object):
         feeder = fluid.DataFeeder(place=place,
                                   feed_list=['image', 'label'],
                                   program=origin_prog)
 
         if self.calc_train_acc:
             fetch_list = [train_loss.name, global_lr.name,
                           train_acc1.name, train_acc5.name]
         else:
             fetch_list = [train_loss.name, global_lr.name]
 
         local_time = 0.0
         nsamples = 0
         inspect_steps = 200
@@ -818,7 +839,7 @@ class Entry(object):
                     local_time = 0
                     nsamples = 0
                     local_train_info = [[], [], [], []]
 
             train_loss = np.array(train_info[0]).mean()
             print("End pass {0}, train_loss {1}".format(pass_id, train_loss))
             sys.stdout.flush()
@@ -850,45 +871,45 @@ class Entry(object):
                             embeddings = np.zeros((data.shape[0],
                                                    _embeddings.shape[1]))
                         embeddings[start:start + real_test_batch_size, :] = _embeddings[:, :]
                     beg = parallel_test_steps * real_test_batch_size
 
                     while beg < data.shape[0]:
                         end = min(beg + self.test_batch_size, data.shape[0])
                         count = end - beg
                         _data = []
                         for k in xrange(end - self.test_batch_size, end):
                             _data.append((data[k], 0))
                         [_embeddings] = exe.run(test_program,
                                                 fetch_list=fetch_list_test,
                                                 feed=test_feeder.feed(_data),
                                                 use_program_cache=True)
                         _embeddings = _embeddings[0:self.test_batch_size, :]
                         embeddings[beg:end, :] = _embeddings[(self.test_batch_size - count):, :]
                         beg = end
                     embeddings_list.append(embeddings)
 
                 xnorm = 0.0
                 xnorm_cnt = 0
                 for embed in embeddings_list:
                     xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
                     xnorm_cnt += embed.shape[0]
                 xnorm /= xnorm_cnt
 
                 embeddings = embeddings_list[0] + embeddings_list[1]
                 embeddings = sklearn.preprocessing.normalize(embeddings)
                 _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10)
                 acc, std = np.mean(accuracy), np.std(accuracy)
                 print('[%s][%d]XNorm: %f' % (test_name_list[i], pass_id, xnorm))
                 print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (test_name_list[i], pass_id, acc, std))
                 sys.stdout.flush()
                 test_end = time.time()
                 print("test time: {}".format(test_end - test_start))
 
             #save model
             if self.model_save_dir:
                 model_save_dir = os.path.join(self.model_save_dir,
                                               str(pass_id))
                 if not os.path.exists(model_save_dir):
                     # may be more than one processes trying
                     # to create the directory
                     try:
                         os.makedirs(model_save_dir)
@@ -921,7 +942,7 @@ class Entry(object):
         #upload model
         if self.model_save_dir and self.fs_name and trainer_id == 0:
             self.put_files_to_hdfs(self.model_save_dir)
 
 
 if __name__ == '__main__':
     ins = Entry()
plsc/models/dist_algo.py
@@ -41,13 +41,13 @@ class DistributedClassificationOptimizer(Optimizer):
     def init_fp16_params(self, loss_type, fp16_user_dict):
         # set default value for fp16_params_dict
         fp16_params_dict = dict()
-        fp16_params_dict['amp_lists'] = None
         fp16_params_dict['init_loss_scaling'] = 1.0
         fp16_params_dict['incr_every_n_steps'] = 1000
         fp16_params_dict['decr_every_n_nan_or_inf'] = 2
         fp16_params_dict['incr_ratio'] = 2.0
         fp16_params_dict['decr_ratio'] = 0.5
         fp16_params_dict['use_dynamic_loss_scaling'] = True
+        fp16_params_dict['amp_lists'] = None
         if fp16_user_dict is not None:
             # update fp16_params_dict
             for key in fp16_user_dict:
@@ -56,8 +56,9 @@ class DistributedClassificationOptimizer(Optimizer):
                 else:
                     logging.warning("Can't find name '%s' in our fp16_params_dict. "
                                     "Please check your dict key. You can set fp16 params only "
-                                    "in [amp_lists, init_loss_scaling, decr_every_n_nan_or_inf, "
-                                    "incr_ratio, decr_ratio, use_dynamic_loss_scaling]." % (key))
+                                    "in [init_loss_scaling, incr_every_n_steps, "
+                                    "decr_every_n_nan_or_inf, incr_ratio, decr_ratio, "
+                                    "use_dynamic_loss_scaling, amp_lists]" % (key))
 
         self._amp_lists = fp16_params_dict['amp_lists']
         if self._amp_lists is None:
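Note: the update loop above follows a defaults-plus-overrides pattern: start from the built-in fp16 defaults, apply only the keys the user supplied, and warn about anything unrecognized. A standalone sketch of the same idea (function and variable names are illustrative, not part of PLSC):

```python
import logging


def merge_fp16_settings(user_settings=None):
    """Return the fp16 defaults with any user-supplied keys applied,
    warning about keys that are not recognized."""
    defaults = {
        'init_loss_scaling': 1.0,
        'incr_every_n_steps': 1000,
        'decr_every_n_nan_or_inf': 2,
        'incr_ratio': 2.0,
        'decr_ratio': 0.5,
        'use_dynamic_loss_scaling': True,
        'amp_lists': None,
    }
    for key, value in (user_settings or {}).items():
        if key in defaults:
            defaults[key] = value
        else:
            logging.warning("Can't find name '%s' in the fp16 defaults; "
                            "valid keys are %s", key, sorted(defaults))
    return defaults


print(merge_fp16_settings({'init_loss_scaling': 128.0, 'bogus_key': 1}))
```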