PaddlePaddle / PaddleHub

Commit 377f2496
Authored Mar 28, 2019 by Zeyu Chen

remove in_tokens arguments of berts

Parent: 57dbe135

Showing 5 changed files with 108 additions and 190 deletions
demo/bert-cls/finetune_with_hub.py      +24  -21
demo/bert-cls/run_fintune_with_hub.sh    +2   -2
paddle_hub/finetune/config.py           +17   -5
paddle_hub/finetune/finetune.py         +65  -73
paddle_hub/finetune/optimization.py      +0  -89
demo/bert-cls/finetune_with_hub.py

@@ -70,8 +70,21 @@ run_type_g.add_arg("use_cuda", bool, True, "If set, use G
 args = parser.parse_args()
 # yapf: enable.

-def test_hub_api(args, config):
+if __name__ == '__main__':
+    print_arguments(args)
+    config = FinetuneConfig(
+        log_interval=10,
+        eval_interval=100,
+        save_ckpt_interval=200,
+        use_cuda=True,
+        checkpoint_dir="./bert_cls_ckpt",
+        learning_rate=args.learning_rate,
+        num_epoch=args.epoch,
+        batch_size=args.batch_size,
+        max_seq_len=args.max_seq_len,
+        weight_decay=args.weight_decay,
+        in_tokens=args.in_tokens,
+        warmup_proportion=args.warmup_proportion)
+
     processor = reader.ChnsenticorpProcessor(
         data_dir=args.data_dir,
 ...
@@ -86,38 +99,28 @@ def test_hub_api(args, config):
     # loading paddlehub BERT
     module = hub.Module(module_dir="./chinese_L-12_H-768_A-12.hub_module")

     # bert's input tensor, output tensor and forward graph
     # If you want to fine-tune the pretrain model parameter, please set
     # trainable to True
     input_dict, output_dict, train_program = module.context(
         sign_name="pooled_output", trainable=True)

-    startup_program = fluid.Program()
-    with fluid.program_guard(train_program, startup_program):
+    with fluid.program_guard(train_program):
         label = fluid.layers.data(name="label", shape=[1], dtype='int64')

         pooled_output = output_dict["pooled_output"]

-        # setup feed list for data feeder
+        # Setup feed list for data feeder
         # Must feed all the tensor of bert's module need
         feed_list = [
             input_dict["src_ids"].name, input_dict["pos_ids"].name,
             input_dict["sent_ids"].name, input_dict["input_mask"].name,
             label.name
         ]

         # Define a classfication finetune task by PaddleHub's API
         cls_task = hub.append_mlp_classifier(
             pooled_output, label, num_classes=num_labels)

         # Finetune and evaluate by PaddleHub's API
         # will finish training, evaluation, testing, save model automatically
         hub.finetune_and_eval(cls_task, feed_list, processor, config)
-
-if __name__ == '__main__':
-    print_arguments(args)
-    config = FinetuneConfig(
-        stat_interval=args.skip_steps,
-        eval_interval=args.validation_steps,
-        use_cuda=True,
-        learning_rate=args.learning_rate,
-        weight_decay=args.weight_decay,
-        in_tokens=args.in_tokens,
-        epoch=args.epoch,
-        batch_size=args.batch_size,
-        max_seq_len=args.max_seq_len,
-        warmup_proportion=args.warmup_proportion)
-    test_hub_api(args, config)
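For orientation, the following is a condensed sketch of the flow the rewritten demo now follows: load the pretrained BERT module, attach a classification head, and hand everything to PaddleHub's fine-tune loop with the new interval-based FinetuneConfig. It is a sketch, not the demo itself: the import paths, the hard-coded config values, and num_classes=2 are illustrative assumptions, and the Chnsenticorp processor construction is left as a comment because its full argument list is not shown in this diff.

import paddle.fluid as fluid
import paddle_hub as hub                                  # import alias assumed
from paddle_hub.finetune.config import FinetuneConfig     # import path assumed

# Illustrative values; the demo fills these from argparse.
config = FinetuneConfig(
    log_interval=10, eval_interval=100, save_ckpt_interval=200,
    use_cuda=True, learning_rate=5e-5, checkpoint_dir="./bert_cls_ckpt",
    num_epoch=3, batch_size=32, max_seq_len=128,
    weight_decay=0.01, warmup_proportion=0.1, in_tokens=False)

# Load the pretrained BERT module and get its inputs, outputs and program.
module = hub.Module(module_dir="./chinese_L-12_H-768_A-12.hub_module")
input_dict, output_dict, train_program = module.context(
    sign_name="pooled_output", trainable=True)

with fluid.program_guard(train_program):
    label = fluid.layers.data(name="label", shape=[1], dtype='int64')
    pooled_output = output_dict["pooled_output"]

    # Feed everything BERT needs, plus the label.
    feed_list = [
        input_dict["src_ids"].name, input_dict["pos_ids"].name,
        input_dict["sent_ids"].name, input_dict["input_mask"].name,
        label.name
    ]

    # Attach a classification head; num_classes=2 is illustrative.
    cls_task = hub.append_mlp_classifier(pooled_output, label, num_classes=2)

    # processor = reader.ChnsenticorpProcessor(...)   # built as in the demo
    # hub.finetune_and_eval(cls_task, feed_list, processor, config)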
demo/bert-cls/run_fintune_with_hub.sh

@@ -7,8 +7,8 @@ DATA_PATH=chnsenticorp_data
 rm -rf $CKPT_PATH
 python -u finetune_with_hub.py \
                    --use_cuda true \
-                   --batch_size 4096 \
-                   --in_tokens true \
+                   --batch_size 32 \
+                   --in_tokens false \
                    --data_dir ${DATA_PATH} \
                    --vocab_path ${BERT_BASE_PATH}/vocab.txt \
                    --weight_decay 0.01 \
 ...
paddle_hub/finetune/config.py

@@ -14,8 +14,20 @@
 import collections

-FinetuneConfig = collections.namedtuple(
-    'FinetuneConfig',
-    ['stat_interval', 'eval_interval', 'use_cuda', 'learning_rate',
-     'weight_decay', 'in_tokens', 'epoch', 'batch_size', 'max_seq_len',
-     'warmup_proportion'])
+FinetuneConfig = collections.namedtuple(
+    'FinetuneConfig',
+    [
+        'log_interval',        # print training log every n step
+        'eval_interval',       # evalution the model every n steps
+        'save_ckpt_interval',  # save the model checkpoint every n steps
+        'use_cuda',            # use gpu or not
+        'learning_rate',
+        'checkpoint_dir',      # model checkpoint directory
+        'num_epoch',           # number of finetune epoch
+        'batch_size',          # for bert parameter
+        'max_seq_len',         # for bert
+        'weight_decay',        # for bert
+        'warmup_proportion',   # for bert
+        'in_tokens'            # for bert
+    ])
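Because FinetuneConfig stays a plain namedtuple, all twelve fields have to be supplied at construction time and are then read-only attributes. A minimal, self-contained sketch of the new interface; the concrete values below are only illustrative:

import collections

# Mirrors the new definition in paddle_hub/finetune/config.py.
FinetuneConfig = collections.namedtuple('FinetuneConfig', [
    'log_interval', 'eval_interval', 'save_ckpt_interval', 'use_cuda',
    'learning_rate', 'checkpoint_dir', 'num_epoch', 'batch_size',
    'max_seq_len', 'weight_decay', 'warmup_proportion', 'in_tokens'
])

config = FinetuneConfig(
    log_interval=10,
    eval_interval=100,
    save_ckpt_interval=200,
    use_cuda=True,
    learning_rate=5e-5,              # illustrative value
    checkpoint_dir="./bert_cls_ckpt",
    num_epoch=3,
    batch_size=32,
    max_seq_len=128,
    weight_decay=0.01,
    warmup_proportion=0.1,
    in_tokens=False)

print(config.num_epoch, config.log_interval)   # namedtuple fields are immutable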
paddle_hub/finetune/finetune.py

@@ -30,22 +30,7 @@ def finetune_and_eval(task, feed_list, data_processor, config=None):
     else:
         place = fluid.CPUPlace()
         dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

-    # data generator
-    data_generator = {
-        'train':
-        data_processor.data_generator(
-            batch_size=config.batch_size,
-            phase='train',
-            epoch=config.epoch,
-            shuffle=False),
-        'test':
-        data_processor.data_generator(
-            batch_size=config.batch_size, phase='test', shuffle=False),
-        'dev':
-        data_processor.data_generator(
-            batch_size=config.batch_size, phase='dev', shuffle=False)
-    }
     exe = fluid.Executor(place)

     # hub.finetune_and_eval start here
     #TODO: to simplify
 ...
@@ -56,10 +41,10 @@ def finetune_and_eval(task, feed_list, data_processor, config=None):
     num_train_examples = data_processor.get_num_examples(phase='train')

     if config.in_tokens:
-        max_train_steps = config.epoch * num_train_examples // (
+        max_train_steps = config.num_epoch * num_train_examples // (
             config.batch_size // config.max_seq_len) // dev_count
     else:
-        max_train_steps = config.epoch * num_train_examples // config.batch_size // dev_count
+        max_train_steps = config.num_epoch * num_train_examples // config.batch_size // dev_count

     warmup_steps = int(max_train_steps * config.warmup_proportion)
 ...
@@ -83,73 +68,80 @@ def finetune_and_eval(task, feed_list, data_processor, config=None):
         num_example.name
     ])

     place = fluid.CUDAPlace(0)
     exe = fluid.Executor(place)
     exe.run(startup_program)

     feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

-    # Traning block
-    # prepare training dataset
-    train_data_generator = data_generator['train']
     total_loss, total_acc, total_num_example = [], [], []
     step = 0
     time_begin = time.time()
     train_time_used = 0.0
-    for example in train_data_generator():
-        step += 1
-        train_time_begin = time.time()
-        np_loss, np_acc, np_num_example = exe.run(
-            program=train_program,
-            feed=feeder.feed([example]),
-            fetch_list=[loss, accuracy, num_example])
-        train_time_used += time.time() - train_time_begin
+    for epoch in range(1, config.num_epoch + 1):
+        print("Epoch {}".format(epoch))
+        train_data_generator = data_processor.data_generator(
+            batch_size=config.batch_size, phase='train', shuffle=False)
+        for example in train_data_generator():
+            step += 1
+            train_time_begin = time.time()
+            np_loss, np_acc, np_num_example = exe.run(
+                program=train_program,
+                feed=feeder.feed([example]),
+                fetch_list=[loss, accuracy, num_example])
+            train_time_used += time.time() - train_time_begin
+
+            # Statistic Block
+            total_loss.extend(np_loss * np_num_example)
+            total_acc.extend(np_acc * np_num_example)
+            total_num_example.extend(np_num_example)
+            if step % config.log_interval == 0:
+                # get training progress
+                accum_num_example = np.sum(total_num_example)
+                print("step {}: loss={:.5f} acc={:.5f} [step/sec: {:.2f}]".format(
+                    step, np.sum(total_loss) / accum_num_example,
+                    np.sum(total_acc) / accum_num_example,
+                    config.log_interval / train_time_used))
+                # reset statistic variables
+                total_loss, total_acc, total_num_example = [], [], []
+                train_time_used = 0.0
+
+            # Evaluation block
+            if step % config.eval_interval == 0:
+                test_data_generator = data_processor.data_generator(
+                    batch_size=config.batch_size, phase='test', shuffle=False)
+                dev_data_generator = data_processor.data_generator(
+                    batch_size=config.batch_size, phase='dev', shuffle=False)
+                evaluate(task, test_program, exe, feeder, dev_data_generator)
+                evaluate(task, test_program, exe, feeder, test_data_generator)
+
+            # Save model checkpoint
+            if step % config.save_ckpt_interval == 0:
+                save_checkpoint(exe, train_program, step, config.checkpoint_dir)
+
+    # finish final evaluation on testset
+    test_data_generator = data_processor.data_generator(
+        batch_size=config.batch_size, phase='test', shuffle=False)
+    evaluate(task, test_program, exe, feeder, test_data_generator)
+
+
+def save_checkpoint(exe, train_program, step, ckpt_dir):
+    #TODO: add global step variable for restore checkpoint like tensorflow
+    ckpt_step_dir = os.path.join(ckpt_dir, "step_{}".format(step))
+    fluid.io.save_persistables(exe, ckpt_step_dir, train_program)
+
+
+def evaluate(task, test_program, exe, feeder, data_generator):
+    loss = task.variable("loss")
+    probs = task.variable("probs")
+    accuracy = task.variable("accuracy")
+    num_example = task.variable("num_example")

-        # Statistic Block
-        total_loss.extend(np_loss * np_num_example)
-        total_acc.extend(np_acc * np_num_example)
-        total_num_example.extend(np_num_example)
-        if step % config.stat_interval == 0:
-            # get training progress
-            accum_num_example = np.sum(total_num_example)
-            print("step {}: loss={:.5f} acc={:.5f} [step/sec: {:.2f}]".format(
-                step, np.sum(total_loss) / accum_num_example,
-                np.sum(total_acc) / accum_num_example,
-                config.stat_interval / train_time_used))
-            # reset statistic variables
-            total_loss, total_acc, total_num_example = [], [], []
-            train_time_used = 0.0

-        # Evaluation block
-        if step % config.eval_interval == 0:
-            evaluate(test_program, exe, data_generator)

-    if step % config.eval_interval == 0:
-        # Final Test Block
-        total_loss, total_acc, total_num_example = [], [], []
-        test_data_generator = data_generator['test']
-        for example in test_data_generator():
-            np_loss, np_acc, np_num_example = exe.run(
-                program=test_program,
-                feed=feeder.feed([example]),
-                fetch_list=[loss, accuracy, num_example])
-            total_loss.extend(np_loss * np_num_example)
-            total_acc.extend(np_acc * np_num_example)
-            total_num_example.extend(np_num_example)
-        accum_num_example = np.sum(total_num_example)
-        print("[Final Test] loss={:.5f} acc={:.5f}".format(
-            np.sum(total_loss) / accum_num_example,
-            np.sum(total_acc) / accum_num_example))

-def evaluate(test_program, exe, feeder, data_generator):
     print("Evaluation start")
     total_loss, total_acc, total_num_example = [], [], []
-    dev_data_generator = data_generator['dev']
     eval_step = 0
     eval_time_begin = time.time()
-    for example in dev_data_generator():
+    for example in data_generator():
         eval_step += 1
         np_loss, np_acc, np_num_example = exe.run(
             program=test_program,
 ...
@@ -160,6 +152,6 @@ def evaluate(test_program, exe, feeder, data_generator):
         total_num_example.extend(np_num_example)
     eval_time_used = time.time() - eval_time_begin
     accum_num_example = np.sum(total_num_example)
-    print("[Evaluation] loss={:.5f} acc={:.5f} [step/sec: {:.2f}]".format(
+    print("[evaluation] loss={:.5f} acc={:.5f} [step/sec: {:.2f}]".format(
         np.sum(total_loss) / accum_num_example,
         np.sum(total_acc) / accum_num_example, eval_step / eval_time_used))
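The restructured loop drives everything from a single global step counter inside an explicit epoch loop: running statistics are printed every log_interval steps, dev/test evaluation runs every eval_interval steps, a checkpoint is written every save_ckpt_interval steps, and one final test evaluation happens after all epochs. A framework-free sketch of just that control flow; the batch count and the stubbed-out actions are made up for illustration, the real code runs exe.run(...) and evaluate(...) at these points:

# Standalone illustration of the epoch/step scheduling used above.
def run_loop(num_epoch=2, batches_per_epoch=450,
             log_interval=10, eval_interval=100, save_ckpt_interval=200):
    step = 0
    for epoch in range(1, num_epoch + 1):
        print("Epoch {}".format(epoch))
        for _ in range(batches_per_epoch):      # stands in for the data generator
            step += 1
            if step % log_interval == 0:
                pass                            # print running loss/acc, then reset
            if step % eval_interval == 0:
                pass                            # evaluate on dev and test
            if step % save_ckpt_interval == 0:
                print("save checkpoint at step_{}".format(step))
    # final evaluation on the test set happens once, after all epochs
    print("final test evaluation after {} steps".format(step))

run_loop()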
paddle_hub/finetune/optimization.py

@@ -49,95 +49,6 @@ def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
     return lr

-def optimization(loss,
-                 warmup_steps,
-                 num_train_steps,
-                 learning_rate,
-                 train_program,
-                 startup_prog,
-                 weight_decay,
-                 scheduler='linear_warmup_decay',
-                 use_fp16=False,
-                 loss_scaling=1.0):
-    if warmup_steps > 0:
-        if scheduler == 'noam_decay':
-            scheduled_lr = fluid.layers.learning_rate_scheduler\
-                .noam_decay(1/(warmup_steps *(learning_rate ** 2)),
-                            warmup_steps)
-        elif scheduler == 'linear_warmup_decay':
-            scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
-                                               num_train_steps)
-        else:
-            raise ValueError("Unkown learning rate scheduler, should be "
-                             "'noam_decay' or 'linear_warmup_decay'")
-        optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
-    else:
-        optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
-        scheduled_lr = learning_rate
-
-    clip_norm_thres = 1.0
-    # When using mixed precision training, scale the gradient clip threshold
-    # by loss_scaling
-    if use_fp16 and loss_scaling > 1.0:
-        clip_norm_thres *= loss_scaling
-    fluid.clip.set_gradient_clip(
-        clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
-
-    def exclude_from_weight_decay(name):
-        if name.find("layer_norm") > -1:
-            return True
-        bias_suffix = ["_bias", "_b", ".b_0"]
-        for suffix in bias_suffix:
-            if name.endswith(suffix):
-                return True
-        return False
-
-    param_list = dict()
-
-    if use_fp16:
-        param_grads = optimizer.backward(loss)
-        master_param_grads = create_master_params_grads(
-            param_grads, train_program, startup_prog, loss_scaling)
-
-        for param, _ in master_param_grads:
-            param_list[param.name] = param * 1.0
-            param_list[param.name].stop_gradient = True
-
-        optimizer.apply_gradients(master_param_grads)
-
-        if weight_decay > 0:
-            for param, grad in master_param_grads:
-                if exclude_from_weight_decay(param.name.rstrip(".master")):
-                    continue
-                with param.block.program._optimized_guard(
-                    [param, grad]), fluid.framework.name_scope("weight_decay"):
-                    updated_param = param - param_list[
-                        param.name] * weight_decay * scheduled_lr
-                    fluid.layers.assign(output=param, input=updated_param)
-
-        master_param_to_train_param(master_param_grads, param_grads,
-                                    train_program)
-    else:
-        for param in train_program.global_block().all_parameters():
-            param_list[param.name] = param * 1.0
-            param_list[param.name].stop_gradient = True
-
-        _, param_grads = optimizer.minimize(loss)
-
-        if weight_decay > 0:
-            for param, grad in param_grads:
-                if exclude_from_weight_decay(param.name):
-                    continue
-                with param.block.program._optimized_guard(
-                    [param, grad]), fluid.framework.name_scope("weight_decay"):
-                    updated_param = param - param_list[
-                        param.name] * weight_decay * scheduled_lr
-                    fluid.layers.assign(output=param, input=updated_param)
-
-    return scheduled_lr
-
 def bert_optimization(loss,
                       warmup_steps,
                       num_train_steps,
 ...
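For reference, the deleted optimization() applied weight decay by hand: it skipped LayerNorm and bias parameters and updated every remaining parameter as param <- param - param * weight_decay * lr. A small self-contained sketch of that rule, with hypothetical parameter names and NumPy arrays standing in for the fluid variables:

import numpy as np

def exclude_from_weight_decay(name):
    # Same rule as the removed helper: skip LayerNorm and bias parameters.
    if name.find("layer_norm") > -1:
        return True
    for suffix in ["_bias", "_b", ".b_0"]:
        if name.endswith(suffix):
            return True
    return False

# Hypothetical parameters standing in for train_program's parameters.
params = {
    "fc_0.w_0": np.ones(3),
    "fc_0.b_0": np.ones(3),
    "post_att_layer_norm_scale": np.ones(3),
}

weight_decay, scheduled_lr = 0.01, 5e-5
for name, value in params.items():
    if exclude_from_weight_decay(name):
        continue                                        # no decay for these
    params[name] = value - value * weight_decay * scheduled_lr  # decayed update

print({name: value[0] for name, value in params.items()})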