PaddlePaddle / ERNIE
Commit 7d7dadd1
Authored by Meiyim on May 27, 2020; committed via GitHub on May 27, 2020.

fix #470 (#471)

* Update README.md
* fix grad_clip
* fix distill
* up distill

Parent: 241c0282

Changes: 5 files changed, 42 additions and 38 deletions (+42 −38)
* demo/finetune_mrc_dygraph.py (+2, −2)
* demo/finetune_ner_dygraph.py (+0, −1)
* demo/finetune_sentiment_analysis_dygraph.py (+5, −3)
* distill/README.md (+1, −7)
* distill/distill.py (+34, −25)
demo/finetune_mrc_dygraph.py
```diff
@@ -82,8 +82,8 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokeniz
         model = D.parallel.DataParallel(model, ctx)

     max_steps = len(train_features) * args.epoch // args.bsz
-    opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd)
     g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+    opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
     train_dataset = train_dataset \
                     .repeat() \
@@ -97,7 +97,7 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokeniz
             scaled_loss = model.scale_loss(loss)
             scaled_loss.backward()
             model.apply_collective_grads()
-            opt.minimize(scaled_loss, grad_clip=g_clip)
+            opt.minimize(scaled_loss)
             model.clear_gradients()
             if D.parallel.Env().dev_id == 0 and step % 10 == 0:
                 log.debug('[step %d] train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
```
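These two hunks are the core of the grad_clip fix: the global-norm clip object is now handed to the optimizer at construction time instead of being passed to `minimize()`, which is the pattern expected by PaddlePaddle 1.8-era dygraph APIs. Below is a minimal, self-contained sketch of the new pattern, assuming a PaddlePaddle 1.8.x install; the `Linear` layer and random input are placeholders and not part of the repo.

```python
import numpy as np
import paddle.fluid as F
import paddle.fluid.dygraph as D

with D.guard():
    model = D.Linear(4, 2)                          # placeholder model
    g_clip = F.clip.GradientClipByGlobalNorm(1.0)   # clip gradients to global norm 1.0

    # old pattern (removed by this commit): opt.minimize(loss, grad_clip=g_clip)
    # new pattern: attach the clip when the optimizer is built
    opt = F.optimizer.AdamOptimizer(
        learning_rate=1e-3,
        parameter_list=model.parameters(),
        grad_clip=g_clip)

    x = D.to_variable(np.random.rand(8, 4).astype('float32'))
    loss = F.layers.reduce_mean(model(x))
    loss.backward()
    opt.minimize(loss)          # no grad_clip kwarg here any more
    model.clear_gradients()
```

The repo's own `AdamW` (from `ernie.optimization`) accepts the same `grad_clip` keyword, as the diff shows; the fluid `AdamOptimizer` is used here only to keep the sketch free of repo-specific imports.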
demo/finetune_ner_dygraph.py
```diff
@@ -26,7 +26,6 @@ from functools import reduce, partial
 import numpy as np
 import multiprocessing
 import pickle
-import jieba
 import logging

 from sklearn.metrics import f1_score
```
demo/finetune_sentiment_analysis_dygraph.py
```diff
@@ -95,12 +95,14 @@ if __name__ == '__main__':
     dev_ds.data_shapes = shapes
     dev_ds.data_types = types

+    g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
     opt = AdamW(learning_rate=LinearDecay(
             args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps),
         parameter_list=model.parameters(),
-        weight_decay=args.wd)
-    g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+        weight_decay=args.wd,
+        grad_clip=g_clip)

     for epoch in range(args.epoch):
         for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
             ids, sids, label = d
@@ -108,7 +110,7 @@ if __name__ == '__main__':
             loss.backward()
             if step % 10 == 0:
                 log.debug('train loss %.5f lr %.3e' % (loss.numpy(), opt.current_step_lr()))
-            opt.minimize(loss, grad_clip=g_clip)
+            opt.minimize(loss)
             model.clear_gradients()
             if step % 100 == 0:
                 acc = []
```
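For reference, the `LinearDecay(args.lr, warmup_steps, max_steps)` schedule used above warms the learning rate up and then decays it. A rough pure-Python sketch of that shape follows, assuming the usual linear-warmup-then-linear-decay behaviour; the exact implementation lives in `ernie.optimization` and may differ in detail.

```python
def linear_warmup_decay(base_lr, warmup_steps, max_steps, step):
    """Approximate shape of a linear warmup + linear decay schedule (illustrative assumption)."""
    if step < warmup_steps:
        # ramp up from 0 to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    # then decay linearly from base_lr towards 0 over the remaining steps
    remaining = max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, (max_steps - step) / remaining)

# e.g. lr=5e-5 with 10% warmup over 1000 total steps
print([round(linear_warmup_decay(5e-5, 100, 1000, s), 8) for s in (0, 50, 100, 500, 1000)])
```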
distill/README.md
````diff
@@ -5,7 +5,6 @@
 * [Result verification](#效果验证)
 * [Case#1: the user provides "unlabeled data"](#case1)
 * [Case#2: the user provides no "unlabeled data"](#case2)
-* [FAQ](#faq)

 # ERNIE Slim data distillation
 Behind ERNIE's strong semantic understanding is an equally strong demand for compute to train and serve a model of this scale. Many industrial scenarios have strict performance requirements, and without effective compression the model cannot be deployed in practice.
@@ -37,7 +36,7 @@
 # Tutorial
-We used the 3 augmentation strategies above to build augmented data for chnsenticorp. The augmented set is 10x the original training data (96,000 lines) and can be downloaded from [here](https://ernie.bj.bcebos.com/distill_data.tar.gz). You can then run the script below to start distillation.
+We used the 3 augmentation strategies above to build augmented data for chnsenticorp. The augmented set is 10x the original training data (96,000 lines) and can be downloaded from [here](https://ernie-github.cdn.bcebos.com/data-chnsenticorp-distill.tar.gz). You can then run the script below to start distillation.

 ```shell
 python ./distill/distill.py
@@ -64,8 +63,3 @@ python ./distill/distill.py
 |Non-ERNIE baseline (LSTM)|91.2%|
 |**+ data distillation**|93.9%|
-
-# FAQ
-
-### FAQ1: `Client call failed` error when serving predictions and distilling at the same time
-The error printed in the terminal is the client-side log; the server-side log appears before it. This is usually caused by the server exceeding GPU memory. In that case, pass `--server_batch_size` in the student-model finetune script to explicitly control the batch size of requests sent to the server.
````
distill/distill.py
```diff
@@ -30,12 +30,13 @@ from ernie.optimization import AdamW, LinearDecay
 # This example uses the chnsenticorp Chinese sentiment classification task as a demonstration; the unsupervised data needed for distillation was expanded beforehand via data augmentation.
 #
-# Please download the data from "" and place the data under ./chnsenticorp-data/
+# Download the data and place it under ./chnsenticorp-data/
 # The data has 3 columns: raw text; space-tokenized text; sentiment label
 # The first column is the input to ERNIE; the second column is the input to the BoW (bag-of-words) model
 # The pre-computed BoW vocabulary is at ./chnsenticorp-data/vocab.bow.txt

 # Hyperparameters needed to finetune the teacher model
+DATA_DIR = './chnsenticorp-data/'
 SEQLEN = 256
 BATCH = 32
 EPOCH = 10
@@ -43,7 +44,7 @@ LR=5e-5
 tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

-student_vocab = {i.strip(): l for l, i in enumerate(open('./chnsenticorp-data/vocab.bow.txt').readlines())}
+student_vocab = {i.strip(): l for l, i in enumerate(open(os.path.join(DATA_DIR, 'vocab.bow.txt')).readlines())}

 def space_tokenizer(i):
     return i.decode('utf8').split()
@@ -63,11 +64,17 @@ def map_fn(seg_a, seg_a_student, label):
     return seg_a_student, sentence, segments, label

-train_ds = feature_column.build_dataset('train', data_dir='./chnsenticorp-data/train/', shuffle=True, repeat=False, use_gz=False).map(map_fn).padded_batch(BATCH,)
+train_ds = feature_column.build_dataset('train', data_dir=os.path.join(DATA_DIR, 'train/'), shuffle=True, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH)

-train_ds_unlabel = feature_column.build_dataset('train-da', data_dir='./chnsenticorp-data/train-data-augmented/', shuffle=True, repeat=False, use_gz=False).map(map_fn).padded_batch(BATCH,)
+train_ds_unlabel = feature_column.build_dataset('train-da', data_dir=os.path.join(DATA_DIR, 'train-data-augmented/'), shuffle=True, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH)

-dev_ds = feature_column.build_dataset('dev', data_dir='./chnsenticorp-data/dev/', shuffle=False, repeat=False, use_gz=False).map(map_fn).padded_batch(BATCH,)
+dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(DATA_DIR, 'dev/'), shuffle=False, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH,)

 shapes = ([-1, SEQLEN], [-1, SEQLEN], [-1, SEQLEN], [-1])
 types = ('int64', 'int64', 'int64', 'int64')
@@ -99,15 +106,15 @@ def evaluate_teacher(model, dataset):
 teacher_model = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=2)
 teacher_model.train()
 if not os.path.exists('./teacher_model.pdparams'):
-    opt = AdamW(learning_rate=LinearDecay(LR, 9600*EPOCH*0.1/BATCH, 9600*EPOCH/BATCH), parameter_list=teacher_model.parameters(), weight_decay=0.01)
     g_clip = F.clip.GradientClipByGlobalNorm(1.0)
+    opt = AdamW(learning_rate=LinearDecay(LR, 9600*EPOCH*0.1/BATCH, 9600*EPOCH/BATCH), parameter_list=teacher_model.parameters(), weight_decay=0.01, grad_clip=g_clip)
     for epoch in range(EPOCH):
         for step, (ids_student, ids, sids, labels) in enumerate(train_ds.start(place)):
             loss, logits = teacher_model(ids, labels=labels)
             loss.backward()
             if step % 10 == 0:
                 print('[step %03d] teacher train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-            opt.minimize(loss, grad_clip=g_clip)
+            opt.minimize(loss)
             teacher_model.clear_gradients()
             if step % 100 == 0:
                 f1 = evaluate_teacher(teacher_model, dev_ds)
@@ -199,32 +206,34 @@ def KL(pred, target):
 teacher_model.eval()
 model = BOW()
-opt = AdamW(learning_rate=LR, parameter_list=model.parameters(), weight_decay=0.01)
 g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+opt = AdamW(learning_rate=LR, parameter_list=model.parameters(), weight_decay=0.01, grad_clip=g_clip)
 model.train()
 for epoch in range(EPOCH):
-    for step, (ids_student, ids, sids, _) in enumerate(train_ds.start(place)):
+    for step, (ids_student, ids, sids, label) in enumerate(train_ds.start(place)):
         _, logits_t = teacher_model(ids, sids)  # the teacher model outputs logits
         logits_t.stop_gradient = True
         _, logits_s = model(ids_student)  # the student model outputs logits
-        loss = KL(logits_s, logits_t)  # KL divergence measures the distance between the two distributions
+        loss_ce, _ = model(ids_student, labels=label)
+        loss_kd = KL(logits_s, logits_t)  # KL divergence measures the distance between the two distributions
+        loss = loss_ce + loss_kd
         loss.backward()
         if step % 10 == 0:
-            print('[step %03d] 无监督 train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-        opt.minimize(loss, grad_clip=g_clip)
+            print('[step %03d] distill train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
+        opt.minimize(loss)
         model.clear_gradients()
     f1 = evaluate_student(model, dev_ds)
-    print('f1 %.5f' % f1)
-    for step, (ids_student, ids, sids, label) in enumerate(train_ds.start(place)):
-        loss, _ = model(ids_student, labels=label)
-        loss.backward()
-        if step % 10 == 0:
-            print('[step %03d] 监督 train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-        opt.minimize(loss, grad_clip=g_clip)
-        model.clear_gradients()
-    f1 = evaluate_student(model, dev_ds)
-    print('f1 %.5f' % f1)
+    print('student f1 %.5f' % f1)
+
+# finally run one more round of hard-label training to consolidate the result
+for step, (ids_student, ids, sids, label) in enumerate(train_ds.start(place)):
+    loss, _ = model(ids_student, labels=label)
+    loss.backward()
+    if step % 10 == 0:
+        print('[step %03d] train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
+    opt.minimize(loss)
+    model.clear_gradients()
+f1 = evaluate_student(model, dev_ds)
+print('final f1 %.5f' % f1)
```
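The last hunk also changes the distillation objective: instead of training the BoW student on the KL term alone, the student now optimizes a hard-label cross-entropy loss plus the KL divergence to the teacher's logits, followed by one extra pass of plain hard-label training. A small NumPy sketch of that combined objective for a single batch is shown below; it is illustrative only, the real losses come from the Paddle models and the repo's `KL` helper, whose KL direction may differ from the convention used here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def kl_to_teacher(student_logits, teacher_logits):
    # KL divergence between the student and teacher class distributions
    # (direction here is one common convention, not necessarily the repo's)
    p_s, p_t = softmax(student_logits), softmax(teacher_logits)
    return np.mean(np.sum(p_s * (np.log(p_s + 1e-12) - np.log(p_t + 1e-12)), axis=-1))

def cross_entropy(student_logits, labels):
    p_s = softmax(student_logits)
    return -np.mean(np.log(p_s[np.arange(len(labels)), labels] + 1e-12))

# toy batch: 4 examples, 2 classes
student_logits = np.random.randn(4, 2)
teacher_logits = np.random.randn(4, 2)
labels = np.array([0, 1, 1, 0])

loss_ce = cross_entropy(student_logits, labels)           # hard-label supervision
loss_kd = kl_to_teacher(student_logits, teacher_logits)   # match the teacher's distribution
loss = loss_ce + loss_kd                                   # combined objective after this commit
print(loss)
```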