PaddlePaddle / models

Commit 1d494f87 — authored Sep 09, 2019 by Yibing Liu, committed by GitHub on Sep 09, 2019. Parent: 2a14983c

Release auto mixed-precision training for bert (#3292)

* Release auto mixed-precision training for bert
* Update README
Changes: 8 changed files with 293 additions and 135 deletions (+293 −135)
- PaddleNLP/language_representations_kit/BERT/README.md (+6 −10)
- PaddleNLP/language_representations_kit/BERT/model/classifier.py (+0 −3)
- PaddleNLP/language_representations_kit/BERT/optimization.py (+55 −30)
- PaddleNLP/language_representations_kit/BERT/run_classifier.py (+25 −16)
- PaddleNLP/language_representations_kit/BERT/run_squad.py (+24 −18)
- PaddleNLP/language_representations_kit/BERT/train.py (+39 −18)
- PaddleNLP/language_representations_kit/BERT/train.sh (+1 −1)
- PaddleNLP/language_representations_kit/BERT/utils/fp16.py (+143 −39)
PaddleNLP/language_representations_kit/BERT/README.md

@@ -48,7 +48,7 @@
 - [inference 接口调用示例](#inference-接口调用示例)

 ## Installation

-This project depends on Paddle Fluid **1.5.1**; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it. If you also need to convert TensorFlow models to Paddle Fluid parameters, TensorFlow 1.12 must be installed as well.
+This project depends on Paddle Fluid **1.5.1** or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it. If you also need to convert TensorFlow models to Paddle Fluid parameters, TensorFlow 1.12 must be installed as well.

 ## Pretraining

@@ -59,7 +59,7 @@
 We provide tokenized plain-text sample data in [`demo_wiki_tokens.txt`](./data/demo_wiki_tokens.txt); each line contains two tab-separated sentences, for example:
 ```
 1 . 雏 凤 鸣 剧 团
 2 . 古 典 之 门 : 帝 女 花 3 . 戏 曲 之 旅 : 第 155 期 心 系 唐 氏 慈 善 戏 曲 晚 会 4 . 区 文 凤 , 郑 燕 虹 1999 编 , 香 港 当 代 粤 剧 人 名 录 , 中 大 音 乐 系 5 . 王 胜 泉 , 张 文 珊 2011 编 , 香 港 当 代 粤 剧 人 名 录 , 中 大 音 乐 系
 ```
 We also provide part of the id-converted training data in [`demo_wiki_train.gz`](./data/train/demo_wiki_train.gz) and test data in [`demo_wiki_validation.gz`](./data/validation/demo_wiki_validation.gz); each line is one training sample, for example:

@@ -294,21 +294,17 @@ python ${SQUAD_PATH}/evaluate-v2.0.py ${SQUAD_PATH}/dev-v2.0.json ${CHECKPOINT_P
 The reported `best_f1_thresh` is the best threshold; you can retrain with this threshold, or re-extract the final `prediction` from `nbest_predictions.json`.
 Training is largely the same as before; just set `--null_score_diff_threshold` to the `best_f1_thresh` reported by the evaluation, which usually lies between -1.0 and -5.0.

-## Mixed-precision training
+## Dynamic mixed-precision training

-Both pretraining and fine-tuning support FP16/FP32 mixed-precision training. To enable mixed-precision training, simply add the following argument to the launch commands described above:
+Both pretraining and fine-tuning support dynamic FP16/FP32 mixed-precision training (Auto Mixed-Precision, AMP). On GPUs with Tensor Cores, such as the V100 and T4, AMP significantly accelerates training. To enable AMP, simply add the following argument to the launch commands described above:
 ```
 --use_fp16=true \
 ```
-To reduce the accuracy loss of mixed-precision training, the loss is usually multiplied by a factor greater than 1.0 when back-propagating the error; this factor can be set as follows:
-```
---loss_scaling=8.0 \
-```
-Experiments show that for BERT-related tasks, training accuracy shows no significant loss when `loss_scaling` lies between 8.0 and 128.0, and on a V100 GPU mixed-precision training achieves roughly a 1.7x speedup over FP32 training.
+To reduce the accuracy loss of mixed-precision training, the loss is usually multiplied by a factor greater than 1.0, `loss_scaling`, when back-propagating the error. During dynamic mixed-precision training, `loss_scaling` is adjusted automatically so that training loses as little accuracy as possible relative to FP32.
+Experiments show that on a V100 GPU, AMP training of BERT BASE achieves about a 1.7x speedup over FP32 training, and BERT LARGE about 2.0x.
 For more details, see the [reference paper](https://arxiv.org/abs/1710.03740).
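To make the loss-scaling idea in the README concrete: FP16 cannot represent very small gradient values, so without scaling they underflow to zero during the backward pass. The snippet below is a small self-contained NumPy illustration, not code from this repository; the factor `2.0**15` is only an example value.

```
import numpy as np

grad = np.float32(1e-8)             # a "true" small gradient value
loss_scaling = np.float32(2.0**15)  # example scaling factor

unscaled_fp16 = np.float16(grad)                    # underflows to 0.0 in FP16
scaled_fp16 = np.float16(grad * loss_scaling)       # representable in FP16
recovered = np.float32(scaled_fp16) / loss_scaling  # unscaled again in FP32

print(unscaled_fp16)  # 0.0
print(recovered)      # ~1e-08
```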
PaddleNLP/language_representations_kit/BERT/model/classifier.py

@@ -77,9 +77,6 @@ def create_model(args, bert_config, num_labels, is_prediction=False):
         logits=logits, label=labels, return_softmax=True)
     loss = fluid.layers.mean(x=ce_loss)

-    if args.use_fp16 and args.loss_scaling > 1.0:
-        loss *= args.loss_scaling
-
     num_seqs = fluid.layers.create_tensor(dtype='int64')
     accuracy = fluid.layers.accuracy(input=probs, label=labels, total=num_seqs)
PaddleNLP/language_representations_kit/BERT/optimization.py

@@ -19,7 +19,7 @@ from __future__ import print_function
 import numpy as np
 import paddle.fluid as fluid

-from utils.fp16 import create_master_params_grads, master_param_to_train_param
+from utils.fp16 import create_master_params_grads, master_param_to_train_param, apply_dynamic_loss_scaling


 def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):

@@ -59,32 +59,42 @@ def optimization(loss,
                  weight_decay,
                  scheduler='linear_warmup_decay',
                  use_fp16=False,
-                 loss_scaling=1.0):
-    if warmup_steps > 0:
-        if scheduler == 'noam_decay':
-            scheduled_lr = fluid.layers.learning_rate_scheduler\
-                .noam_decay(1 / (warmup_steps * (learning_rate ** 2)),
-                            warmup_steps)
-        elif scheduler == 'linear_warmup_decay':
-            scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
-                                               num_train_steps)
-        else:
-            raise ValueError("Unknown learning rate scheduler, should be "
-                             "'noam_decay' or 'linear_warmup_decay'")
-        optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
-    else:
-        optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
-        scheduled_lr = learning_rate
-
-    clip_norm_thres = 1.0
-    # When using mixed precision training, scale the gradient clip threshold
-    # by loss_scaling
-    if use_fp16 and loss_scaling > 1.0:
-        clip_norm_thres *= loss_scaling
+                 use_dynamic_loss_scaling=False,
+                 init_loss_scaling=1.0,
+                 incr_every_n_steps=1000,
+                 decr_every_n_nan_or_inf=2,
+                 incr_ratio=2.0,
+                 decr_ratio=0.8):
+
+    scheduled_lr, loss_scaling = None, None
+    if scheduler == 'noam_decay':
+        if warmup_steps > 0:
+            scheduled_lr = fluid.layers.learning_rate_scheduler\
+                .noam_decay(1 / (warmup_steps * (learning_rate ** 2)),
+                            warmup_steps)
+        else:
+            print("WARNING: noam decay should have positive warmup steps, "
+                  "using constant learning rate instead!")
+            scheduled_lr = fluid.layers.create_global_var(
+                name=fluid.unique_name.generate("learning_rate"),
+                shape=[1],
+                value=learning_rate,
+                dtype='float32',
+                persistable=True)
+    elif scheduler == 'linear_warmup_decay':
+        scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
+                                           num_train_steps)
+    else:
+        raise ValueError("Unknown learning rate scheduler, should be "
+                         "'noam_decay' or 'linear_warmup_decay'")
+
+    optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)

     fluid.clip.set_gradient_clip(
-        clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
+        clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))

-    def exclude_from_weight_decay(name):
+    def exclude_from_weight_decay(param):
+        name = param.name.rstrip(".master")
         if name.find("layer_norm") > -1:
             return True
         bias_suffix = ["_bias", "_b", ".b_0"]

@@ -96,19 +106,33 @@ def optimization(loss,
     param_list = dict()

     if use_fp16:
+        loss_scaling = fluid.layers.create_global_var(
+            name=fluid.unique_name.generate("loss_scaling"),
+            shape=[1],
+            value=init_loss_scaling,
+            dtype='float32',
+            persistable=True)
+        loss *= loss_scaling
+
         param_grads = optimizer.backward(loss)
         master_param_grads = create_master_params_grads(
             param_grads, train_program, startup_prog, loss_scaling)

-        for param, _ in master_param_grads:
-            param_list[param.name] = param * 1.0
-            param_list[param.name].stop_gradient = True
+        if weight_decay > 0:
+            for param, _ in master_param_grads:
+                param_list[param.name] = param * 1.0
+                param_list[param.name].stop_gradient = True
+
+        if use_dynamic_loss_scaling:
+            apply_dynamic_loss_scaling(loss_scaling, master_param_grads,
+                                       incr_every_n_steps,
+                                       decr_every_n_nan_or_inf, incr_ratio,
+                                       decr_ratio)

         optimizer.apply_gradients(master_param_grads)

         if weight_decay > 0:
             for param, grad in master_param_grads:
-                if exclude_from_weight_decay(param.name.rstrip(".master")):
+                if exclude_from_weight_decay(param):
                     continue
                 with param.block.program._optimized_guard(
                         [param, grad]), fluid.framework.name_scope("weight_decay"):

@@ -120,15 +144,16 @@ def optimization(loss,
                 train_program)
     else:
-        for param in train_program.global_block().all_parameters():
-            param_list[param.name] = param * 1.0
-            param_list[param.name].stop_gradient = True
+        if weight_decay > 0:
+            for param in train_program.global_block().all_parameters():
+                param_list[param.name] = param * 1.0
+                param_list[param.name].stop_gradient = True

         _, param_grads = optimizer.minimize(loss)

         if weight_decay > 0:
             for param, grad in param_grads:
-                if exclude_from_weight_decay(param.name):
+                if exclude_from_weight_decay(param):
                     continue
                 with param.block.program._optimized_guard(
                         [param, grad]), fluid.framework.name_scope("weight_decay"):

@@ -136,4 +161,4 @@ def optimization(loss,
                         param.name] * weight_decay * scheduled_lr
                     fluid.layers.assign(output=param, input=updated_param)

-    return scheduled_lr
+    return scheduled_lr, loss_scaling
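Taken together with `utils/fp16.py` below, the FP16 path that `optimization()` now assembles scales the loss before the backward pass, unscales the FP32 master gradients, skips the update when the scaled gradients overflow, and applies updates to FP32 master weights. The eager-style sketch below only illustrates that flow; `fp16_sgd_step` and its arguments are illustrative names, not functions from this repository, which builds equivalent graph ops instead.

```
import numpy as np

def fp16_sgd_step(master_param, grad_fp32, lr, loss_scaling):
    """One illustrative parameter update mirroring the assembled FP16 flow."""
    # 1. Backward runs on loss * loss_scaling, so the FP16 gradient is scaled.
    scaled_grad_fp16 = np.float16(grad_fp32 * loss_scaling)
    # 2. The gradient is cast to FP32 and divided by loss_scaling
    #    (as in create_master_params_grads).
    master_grad = np.float32(scaled_grad_fp16) / loss_scaling
    # 3. If the scaled gradient overflowed, dynamic loss scaling zeroes the
    #    gradients, which effectively skips this update.
    if not np.isfinite(master_grad):
        return master_param
    # 4. The FP32 master weight is updated; the FP16 training copy is then
    #    refreshed from it (master_param_to_train_param).
    return master_param - lr * master_grad

print(fp16_sgd_step(np.float32(0.5), np.float32(1e-8), 0.1, 2.0**15))
```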
PaddleNLP/language_representations_kit/BERT/run_classifier.py

@@ -59,8 +59,16 @@ train_g.add_arg("warmup_proportion", float, 0.1,
 train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
 train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
 train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
-train_g.add_arg("loss_scaling", float, 1.0,
-                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling in mixed precision training.")
+train_g.add_arg("init_loss_scaling", float, 2**32,
+                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("incr_every_n_steps", int, 1000, "Increases loss scaling every n consecutive.")
+train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
+                "Decreases loss scaling every n accumulated steps with nan or inf gradients.")
+train_g.add_arg("incr_ratio", float, 2.0, "The multiplier to use when increasing the loss scaling.")
+train_g.add_arg("decr_ratio", float, 0.8, "The less-than-one-multiplier to use when decreasing.")
 log_g = ArgumentGroup(parser, "logging", "logging related.")
 log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")

@@ -195,7 +203,7 @@ def main(args):
                 args, bert_config=bert_config, num_labels=num_labels)
-            scheduled_lr = optimization(
+            scheduled_lr, loss_scaling = optimization(
                 loss=loss,
                 warmup_steps=warmup_steps,
                 num_train_steps=max_train_steps,

@@ -205,7 +213,12 @@ def main(args):
                 weight_decay=args.weight_decay,
                 scheduler=args.lr_scheduler,
                 use_fp16=args.use_fp16,
-                loss_scaling=args.loss_scaling)
+                use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
+                init_loss_scaling=args.init_loss_scaling,
+                incr_every_n_steps=args.incr_every_n_steps,
+                decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
+                incr_ratio=args.incr_ratio,
+                decr_ratio=args.decr_ratio)

     if args.verbose:
         if args.in_tokens:

@@ -311,23 +324,20 @@ def main(args):
         ce_info = []
         while True:
             try:
                 steps += 1
                 if steps % args.skip_steps == 0:
-                    if warmup_steps <= 0:
-                        fetch_list = [loss.name, accuracy.name, num_seqs.name]
+                    if args.use_fp16:
+                        fetch_list = [
+                            loss.name, accuracy.name, scheduled_lr.name,
+                            num_seqs.name, loss_scaling.name
+                        ]
                     else:
                         fetch_list = [
                             loss.name, accuracy.name, scheduled_lr.name,
                             num_seqs.name
                         ]
                 else:
                     fetch_list = []

                 outputs = exe.run(train_compiled_program, fetch_list=fetch_list)

                 if steps % args.skip_steps == 0:
-                    if warmup_steps <= 0:
-                        np_loss, np_acc, np_num_seqs = outputs
+                    if args.use_fp16:
+                        np_loss, np_acc, np_lr, np_num_seqs, np_scaling = outputs
                     else:
                         np_loss, np_acc, np_lr, np_num_seqs = outputs

@@ -338,9 +348,9 @@ def main(args):
                     if args.verbose:
                         verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size()
-                        verbose += "learning rate: %f" % (
-                            np_lr[0]
-                            if warmup_steps > 0 else args.learning_rate)
+                        verbose += "learning rate: %f" % np_lr[0]
+                        if args.use_fp16:
+                            verbose += ", loss scaling: %f" % np_scaling[0]
                         print(verbose)

                     current_example, current_epoch = processor.get_train_progress(

@@ -362,7 +372,6 @@ def main(args):
                         total_cost, total_acc, total_num_seqs = [], [], []
                         time_begin = time.time()
-                steps += 1
                 if steps % args.save_steps == 0:
                     save_path = os.path.join(args.checkpoints, "step_" + str(steps))
PaddleNLP/language_representations_kit/BERT/run_squad.py

@@ -52,8 +52,16 @@ train_g.add_arg("warmup_proportion", float, 0.1,
                 "Proportion of training steps to perform linear learning rate warmup for.")
 train_g.add_arg("save_steps", int, 1000, "The steps interval to save checkpoints.")
 train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
-train_g.add_arg("loss_scaling", float, 1.0,
-                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling in mixed precision training.")
+train_g.add_arg("init_loss_scaling", float, 2**32,
+                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("incr_every_n_steps", int, 1000, "Increases loss scaling every n consecutive.")
+train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
+                "Decreases loss scaling every n accumulated steps with nan or inf gradients.")
+train_g.add_arg("incr_ratio", float, 2.0, "The multiplier to use when increasing the loss scaling.")
+train_g.add_arg("decr_ratio", float, 0.8, "The less-than-one-multiplier to use when decreasing.")
 log_g = ArgumentGroup(parser, "logging", "logging related.")
 log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")

@@ -164,9 +172,6 @@ def create_model(bert_config, is_training=False):
         start_loss = compute_loss(start_logits, start_positions)
         end_loss = compute_loss(end_logits, end_positions)
         total_loss = (start_loss + end_loss) / 2.0
-        if args.use_fp16 and args.loss_scaling > 1.0:
-            total_loss = total_loss * args.loss_scaling
-
         return pyreader, total_loss, num_seqs
     else:
         return pyreader, unique_id, start_logits, end_logits, num_seqs

@@ -274,7 +279,7 @@ def train(args):
                 bert_config=bert_config, is_training=True)

-            scheduled_lr = optimization(
+            scheduled_lr, loss_scaling = optimization(
                 loss=loss,
                 warmup_steps=warmup_steps,
                 num_train_steps=max_train_steps,

@@ -284,8 +289,12 @@ def train(args):
                 weight_decay=args.weight_decay,
                 scheduler=args.lr_scheduler,
                 use_fp16=args.use_fp16,
-                loss_scaling=args.loss_scaling)
+                use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
+                init_loss_scaling=args.init_loss_scaling,
+                incr_every_n_steps=args.incr_every_n_steps,
+                decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
+                incr_ratio=args.incr_ratio,
+                decr_ratio=args.decr_ratio)

     if args.verbose:
         if args.in_tokens:

@@ -306,7 +315,6 @@ def train(args):
                 bert_config=bert_config, is_training=False)

             test_prog = test_prog.clone(for_test=True)

     exe.run(startup_prog)

@@ -357,20 +365,18 @@ def train(args):
             try:
                 steps += 1
                 if steps % args.skip_steps == 0:
-                    if warmup_steps <= 0:
-                        fetch_list = [loss.name, num_seqs.name]
+                    if args.use_fp16:
+                        fetch_list = [
+                            loss.name, scheduled_lr.name, num_seqs.name,
+                            loss_scaling.name
+                        ]
                     else:
                         fetch_list = [loss.name, scheduled_lr.name, num_seqs.name]
                 else:
                     fetch_list = []

                 outputs = exe.run(train_compiled_program, fetch_list=fetch_list)

                 if steps % args.skip_steps == 0:
-                    if warmup_steps <= 0:
-                        np_loss, np_num_seqs = outputs
+                    if args.use_fp16:
+                        np_loss, np_lr, np_num_seqs, np_scaling = outputs
                     else:
                         np_loss, np_lr, np_num_seqs = outputs

                 total_cost.extend(np_loss * np_num_seqs)

@@ -379,9 +385,9 @@ def train(args):
                     if args.verbose:
                         verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size()
-                        verbose += "learning rate: %f" % (
-                            np_lr[0]
-                            if warmup_steps > 0 else args.learning_rate)
+                        verbose += "learning rate: %f" % np_lr[0]
+                        if args.use_fp16:
+                            verbose += ", loss scaling: %f" % np_scaling[0]
                         print(verbose)

                     time_end = time.time()
PaddleNLP/language_representations_kit/BERT/train.py

@@ -52,8 +52,16 @@ train_g.add_arg("warmup_steps", int, 4000, "Total steps to perform wa
 train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
 train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
 train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
-train_g.add_arg("loss_scaling", float, 1.0,
-                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling in mixed precision training.")
+train_g.add_arg("init_loss_scaling", float, 2**32,
+                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
+train_g.add_arg("incr_every_n_steps", int, 1000, "Increases loss scaling every n consecutive.")
+train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
+                "Decreases loss scaling every n accumulated steps with nan or inf gradients.")
+train_g.add_arg("incr_ratio", float, 2.0, "The multiplier to use when increasing the loss scaling.")
+train_g.add_arg("decr_ratio", float, 0.8, "The less-than-one-multiplier to use when decreasing.")
 log_g = ArgumentGroup(parser, "logging", "logging related.")
 log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")

@@ -113,9 +121,6 @@ def create_model(bert_config):
     next_sent_acc, mask_lm_loss, total_loss = bert.get_pretraining_output(
         mask_label, mask_pos, labels)

-    if args.use_fp16 and args.loss_scaling > 1.0:
-        total_loss *= args.loss_scaling
-
     return pyreader, next_sent_acc, mask_lm_loss, total_loss

@@ -220,7 +225,7 @@ def train(args):
         with fluid.unique_name.guard():
             train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                 bert_config=bert_config)
-            scheduled_lr = optimization(
+            scheduled_lr, loss_scaling = optimization(
                 loss=total_loss,
                 warmup_steps=args.warmup_steps,
                 num_train_steps=args.num_train_steps,

@@ -230,7 +235,12 @@ def train(args):
                 weight_decay=args.weight_decay,
                 scheduler=args.lr_scheduler,
                 use_fp16=args.use_fp16,
-                loss_scaling=args.loss_scaling)
+                use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
+                init_loss_scaling=args.init_loss_scaling,
+                incr_every_n_steps=args.incr_every_n_steps,
+                decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
+                incr_ratio=args.incr_ratio,
+                decr_ratio=args.decr_ratio)

     test_prog = fluid.Program()
     with fluid.program_guard(test_prog, startup_prog):

@@ -341,7 +351,7 @@ def train(args):
     time_begin = time.time()
     while steps < args.num_train_steps:
         try:
-            steps += nccl2_num_trainers
+            steps += 1
             skip_steps = args.skip_steps * nccl2_num_trainers

             if nccl2_trainer_id != 0:

@@ -351,34 +361,45 @@ def train(args):
                     exe.run(fetch_list=[], program=train_compiled_program)
                 continue

-            if steps % skip_steps != 0:
+            if steps % args.skip_steps != 0:
                 if use_ngraph:
                     exe.run(fetch_list=[], program=train_program)
                 else:
                     exe.run(fetch_list=[], program=train_compiled_program)

             else:
+                fetch_list = [
+                    next_sent_acc.name, mask_lm_loss.name, total_loss.name,
+                    scheduled_lr.name
+                ]
+                if args.use_fp16:
+                    fetch_list.append(loss_scaling.name)
+
                 if use_ngraph:
-                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = exe.run(
-                        fetch_list=[
-                            next_sent_acc.name, mask_lm_loss.name,
-                            total_loss.name, scheduled_lr.name
-                        ],
-                        program=train_program)
+                    outputs = exe.run(fetch_list=fetch_list,
+                                      program=train_program)
                 else:
-                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = exe.run(
-                        fetch_list=[
-                            next_sent_acc.name, mask_lm_loss.name,
-                            total_loss.name, scheduled_lr.name
-                        ],
-                        program=train_compiled_program)
+                    outputs = exe.run(fetch_list=fetch_list,
+                                      program=train_compiled_program)
+
+                if args.use_fp16:
+                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr, np_scaling = outputs
+                else:
+                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = outputs

                 acc.extend(each_next_acc)
                 lm_cost.extend(each_mask_lm_cost)
                 cost.extend(each_total_cost)
-                print("feed_queue size", train_pyreader.queue.size())
                 time_end = time.time()
                 used_time = time_end - time_begin
                 epoch, current_file_index, total_file, current_file = data_reader.get_progress()
-                print("current learning_rate:%f" % np_lr[0])
+
+                if args.verbose:
+                    verbose = "feed_queue size: %d, " % train_pyreader.queue.size()
+                    verbose += "current learning_rate: %f, " % np_lr[0]
+                    if args.use_fp16:
+                        verbose += "loss scaling: %f" % np_scaling[0]
+                    print(verbose)
+
                 print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                       "ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s"
                       % (epoch, current_file_index, total_file, steps,
PaddleNLP/language_representations_kit/BERT/train.sh

@@ -51,5 +51,5 @@ python -u ./train.py ${is_distributed}\
         --validation_steps 1000 \
         --num_iteration_per_drop_scope 10 \
         --use_fp16 false \
-        --loss_scaling 8.0
+        --verbose true
PaddleNLP/language_representations_kit/BERT/utils/fp16.py

@@ -17,26 +17,20 @@ import paddle
 import paddle.fluid as fluid


-def cast_fp16_to_fp32(i, o, prog):
+def append_cast_op(i, o, prog):
+    """
+    Append a cast op in a given Program to cast input `i` to data type `o.dtype`.
+
+    Args:
+        i (Variable): The input Variable.
+        o (Variable): The output Variable.
+        prog (Program): The Program to append cast op.
+    """
     prog.global_block().append_op(
         type="cast",
         inputs={"X": i},
         outputs={"Out": o},
-        attrs={
-            "in_dtype": fluid.core.VarDesc.VarType.FP16,
-            "out_dtype": fluid.core.VarDesc.VarType.FP32
-        })
-
-
-def cast_fp32_to_fp16(i, o, prog):
-    prog.global_block().append_op(
-        type="cast",
-        inputs={"X": i},
-        outputs={"Out": o},
-        attrs={
-            "in_dtype": fluid.core.VarDesc.VarType.FP32,
-            "out_dtype": fluid.core.VarDesc.VarType.FP16
-        })
+        attrs={"in_dtype": i.dtype,
+               "out_dtype": o.dtype})


 def copy_to_master_param(p, block):

@@ -59,32 +53,66 @@ def copy_to_master_param(p, block):
     return new_p


+def apply_dynamic_loss_scaling(loss_scaling, master_params_grads,
+                               incr_every_n_steps, decr_every_n_nan_or_inf,
+                               incr_ratio, decr_ratio):
+    _incr_every_n_steps = fluid.layers.fill_constant(
+        shape=[1], dtype='int32', value=incr_every_n_steps)
+    _decr_every_n_nan_or_inf = fluid.layers.fill_constant(
+        shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
+
+    _num_good_steps = fluid.layers.create_global_var(
+        name=fluid.unique_name.generate("num_good_steps"),
+        shape=[1],
+        value=0,
+        dtype='int32',
+        persistable=True)
+    _num_bad_steps = fluid.layers.create_global_var(
+        name=fluid.unique_name.generate("num_bad_steps"),
+        shape=[1],
+        value=0,
+        dtype='int32',
+        persistable=True)
+
+    grads = [fluid.layers.reduce_sum(g) for [_, g] in master_params_grads]
+    all_grads = fluid.layers.concat(grads)
+    all_grads_sum = fluid.layers.reduce_sum(all_grads)
+    is_overall_finite = fluid.layers.isfinite(all_grads_sum)
+
+    update_loss_scaling(is_overall_finite, loss_scaling, _num_good_steps,
+                        _num_bad_steps, _incr_every_n_steps,
+                        _decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
+
+    # apply_gradient append all ops in global block, thus we shouldn't
+    # apply gradient in the switch branch.
+    with fluid.layers.Switch() as switch:
+        with switch.case(is_overall_finite):
+            pass
+        with switch.default():
+            for _, g in master_params_grads:
+                fluid.layers.assign(fluid.layers.zeros_like(g), g)
+
+
 def create_master_params_grads(params_grads, main_prog, startup_prog,
                                loss_scaling):
     master_params_grads = []
-    tmp_role = main_prog._current_role
-    OpRole = fluid.core.op_proto_and_checker_maker.OpRole
-    main_prog._current_role = OpRole.Backward
     for p, g in params_grads:
-        # create master parameters
-        master_param = copy_to_master_param(p, main_prog.global_block())
-        startup_master_param = startup_prog.global_block()._clone_variable(
-            master_param)
-        startup_p = startup_prog.global_block().var(p.name)
-        cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
-        # cast fp16 gradients to fp32 before apply gradients
-        if g.name.find("layer_norm") > -1:
-            if loss_scaling > 1:
-                scaled_g = g / float(loss_scaling)
-            else:
-                scaled_g = g
-            master_params_grads.append([p, scaled_g])
-            continue
-        master_grad = fluid.layers.cast(g, "float32")
-        if loss_scaling > 1:
-            master_grad = master_grad / float(loss_scaling)
-        master_params_grads.append([master_param, master_grad])
-    main_prog._current_role = tmp_role
+        with main_prog._optimized_guard([p, g]):
+            # create master parameters
+            master_param = copy_to_master_param(p, main_prog.global_block())
+            startup_master_param = startup_prog.global_block()._clone_variable(
+                master_param)
+            startup_p = startup_prog.global_block().var(p.name)
+            append_cast_op(startup_p, startup_master_param, startup_prog)
+            # cast fp16 gradients to fp32 before apply gradients
+            if g.name.find("layer_norm") > -1:
+                scaled_g = g / loss_scaling
+                master_params_grads.append([p, scaled_g])
+                continue
+            master_grad = fluid.layers.cast(g, "float32")
+            master_grad = master_grad / loss_scaling
+            master_params_grads.append([master_param, master_grad])

     return master_params_grads

@@ -94,4 +122,80 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
         if train_p.name.find("layer_norm") > -1:
             continue
         with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
-            cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)
+            append_cast_op(m_p_g[0], train_p, main_prog)
+
+
+def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
+                        num_bad_steps, incr_every_n_steps,
+                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
+    """
+    Update loss scaling according to overall gradients. If all gradients are
+    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
+    Otherwise, loss scaling will decrease by decr_ratio after accumulating
+    decr_every_n_nan_or_inf steps in which some gradients are infinite.
+
+    Args:
+        is_overall_finite (Variable): A boolean variable indicating whether
+                                      all gradients are finite.
+        prev_loss_scaling (Variable): Previous loss scaling.
+        num_good_steps (Variable): A variable accumulating good steps in which
+                                   all gradients are finite.
+        num_bad_steps (Variable): A variable accumulating bad steps in which
+                                  some gradients are infinite.
+        incr_every_n_steps (Variable): A variable representing increasing loss
+                                       scaling every n consecutive steps with
+                                       finite gradients.
+        decr_every_n_nan_or_inf (Variable): A variable representing decreasing
+                                            loss scaling every n accumulated
+                                            steps with nan or inf gradients.
+        incr_ratio (float): The multiplier to use when increasing the loss
+                            scaling.
+        decr_ratio (float): The less-than-one-multiplier to use when decreasing
+                            loss scaling.
+    """
+    zero_steps = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
+    with fluid.layers.Switch() as switch:
+        with switch.case(is_overall_finite):
+            should_incr_loss_scaling = fluid.layers.less_than(
+                incr_every_n_steps, num_good_steps + 1)
+            with fluid.layers.Switch() as switch1:
+                with switch1.case(should_incr_loss_scaling):
+                    new_loss_scaling = prev_loss_scaling * incr_ratio
+                    loss_scaling_is_finite = fluid.layers.isfinite(
+                        new_loss_scaling)
+                    with fluid.layers.Switch() as switch2:
+                        with switch2.case(loss_scaling_is_finite):
+                            fluid.layers.assign(new_loss_scaling,
+                                                prev_loss_scaling)
+                        with switch2.default():
+                            pass
+                    fluid.layers.assign(zero_steps, num_good_steps)
+                    fluid.layers.assign(zero_steps, num_bad_steps)
+                with switch1.default():
+                    fluid.layers.increment(num_good_steps)
+                    fluid.layers.assign(zero_steps, num_bad_steps)
+        with switch.default():
+            should_decr_loss_scaling = fluid.layers.less_than(
+                decr_every_n_nan_or_inf, num_bad_steps + 1)
+            with fluid.layers.Switch() as switch3:
+                with switch3.case(should_decr_loss_scaling):
+                    new_loss_scaling = prev_loss_scaling * decr_ratio
+                    static_loss_scaling = \
+                        fluid.layers.fill_constant(
+                            shape=[1], dtype='float32', value=1.0)
+                    less_than_one = fluid.layers.less_than(new_loss_scaling,
+                                                           static_loss_scaling)
+                    with fluid.layers.Switch() as switch4:
+                        with switch4.case(less_than_one):
+                            fluid.layers.assign(static_loss_scaling,
+                                                prev_loss_scaling)
+                        with switch4.default():
+                            fluid.layers.assign(new_loss_scaling,
+                                                prev_loss_scaling)
+                    fluid.layers.assign(zero_steps, num_good_steps)
+                    fluid.layers.assign(zero_steps, num_bad_steps)
+                with switch3.default():
+                    fluid.layers.assign(zero_steps, num_good_steps)
+                    fluid.layers.increment(num_bad_steps)
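The nested `Switch` blocks above are dense to read as control flow. The pure-Python sketch below mirrors the same counter-based policy outside the Fluid graph; the `DynamicLossScaler` class and its method names are illustrative only and are not part of this repository.

```
import math

class DynamicLossScaler(object):
    """Illustrative mirror of update_loss_scaling (not part of this repo)."""

    def __init__(self, init_loss_scaling=2.0**32, incr_every_n_steps=1000,
                 decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.8):
        self.loss_scaling = init_loss_scaling
        self.incr_every_n_steps = incr_every_n_steps
        self.decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self.incr_ratio = incr_ratio
        self.decr_ratio = decr_ratio
        self.num_good_steps = 0
        self.num_bad_steps = 0

    def update(self, all_grads_finite):
        if all_grads_finite:
            if self.num_good_steps + 1 > self.incr_every_n_steps:
                # Enough consecutive overflow-free steps: try a larger factor
                # (keep the old one if it would overflow) and reset counters.
                candidate = self.loss_scaling * self.incr_ratio
                if math.isfinite(candidate):
                    self.loss_scaling = candidate
                self.num_good_steps = 0
                self.num_bad_steps = 0
            else:
                self.num_good_steps += 1
                self.num_bad_steps = 0
        else:
            if self.num_bad_steps + 1 > self.decr_every_n_nan_or_inf:
                # Too many steps with nan/inf gradients: back off, floored at 1.0.
                self.loss_scaling = max(1.0, self.loss_scaling * self.decr_ratio)
                self.num_good_steps = 0
                self.num_bad_steps = 0
            else:
                self.num_good_steps = 0
                self.num_bad_steps += 1


scaler = DynamicLossScaler()
for finite in [True] * 1001:   # 1001 overflow-free steps -> one increase
    scaler.update(finite)
for finite in [False] * 3:     # 3 overflowing steps -> one decrease
    scaler.update(finite)
print(scaler.loss_scaling)     # 2.0**32 * 2.0 * 0.8
```

On an overflow step the graph-side implementation additionally zeroes the scaled gradients (the `switch.default()` branch of `apply_dynamic_loss_scaling`), so the corresponding parameter update is effectively skipped.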