PaddlePaddle / ERNIE
Commit cce162f8: "+ distill"
Authored Sep 29, 2019 by chenxuyi
Parent commit: 8c1d6e85
Showing 7 changed files with 1133 additions and 0 deletions
* distill/README.md (+135, -0)
* distill/distill_chnsentocorp.py (+193, -0)
* distill/distill_chnsentocorp_with_propeller_server.py (+223, -0)
* distill/finetune_chnsenticorp.py (+223, -0)
* distill/script/distill_chnsenticorp.sh (+103, -0)
* distill/script/distill_chnsenticorp_with_propeller_server.sh (+85, -0)
* utils/data.py (+171, -0)
distill/README.md (new file, mode 100644)
* [ERNIE Slim data distillation](#ernie-slim-data-distillation)
* [Three steps of ERNIE data distillation](#three-steps-of-ernie-data-distillation)
* [Data augmentation](#data-augmentation)
* [Usage](#usage)
* [Offline distillation](#offline-distillation)
* [Online distillation](#online-distillation)
* [Results](#results)
* [Case #1: the user provides unlabeled data](#case1)
* [Case #2: the user provides no unlabeled data](#case2)
* [FAQ](#faq)
# ERNIE Slim data distillation

Behind ERNIE's strong semantic-understanding ability is an equally demanding amount of compute: training and serving a model of this scale is expensive. Many industrial applications have strict performance requirements, so the model cannot be deployed in practice unless it is compressed effectively.

<img src="http://agroup-bos.cdn.bcebos.com/ae16a29d6a334c74107cebcf56bc2419d385b364" title="ERNIE data distillation diagram" width="900">

Therefore, as shown in the figure above, we built the **ERNIE Slim data distillation system** on top of [data distillation](https://arxiv.org/pdf/1712.04440.pdf). It uses data as a bridge to transfer the knowledge of the ERNIE model into a small model, gaining a prediction speed-up of up to several thousand times at the cost of only a small loss in accuracy.
### Three steps of ERNIE data distillation

- **Step 1**. Fine-tune ERNIE on the labeled input data to obtain the teacher model.
- **Step 2**. Use the ERNIE service to predict on the following unsupervised data:
    1. large-scale unlabeled data provided by the user, which must come from the same source as the labeled data;
    2. data obtained by augmenting the labeled data (the augmentation strategies are described in the next section);
    3. a mixture of the unlabeled and augmented data at a chosen ratio.
- **Step 3.** Train the student model on the data produced in Step 2.
### Data augmentation

We currently use three [data augmentation strategies](https://arxiv.org/pdf/1903.12136.pdf), which can be mixed at task-specific ratios (an illustrative sketch follows the list). The three strategies are:

1. Noise injection: each token of the original sample is replaced by the "UNK" token with some probability (e.g. 0.1).
2. Same-POS replacement: each token of the original sample is replaced, with some probability (e.g. 0.1), by a random token of the same part of speech drawn from the dataset.
3. N-sampling: a segment of length m is cut from a random position of the original sample and used as a new sample, where m is a random value between 0 and the original sample length.
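The sketch below is a minimal, illustrative implementation of the three strategies on a token list; it is not part of the repository, and the helper names and the part-of-speech lookup structures are hypothetical.

```python
# Illustrative sketch only; assumes samples are already tokenized and that a
# part-of-speech lookup built from the training set is available.
import random

def add_noise(tokens, p=0.1, unk="UNK"):
    """Replace each token with the UNK token with probability p."""
    return [unk if random.random() < p else t for t in tokens]

def same_pos_replace(tokens, pos_of, tokens_by_pos, p=0.1):
    """With probability p, replace a token by a random token of the same POS.

    pos_of maps token -> POS tag and tokens_by_pos maps POS tag -> candidate
    tokens; both are hypothetical structures derived from the dataset."""
    out = []
    for t in tokens:
        if random.random() < p and t in pos_of:
            out.append(random.choice(tokens_by_pos[pos_of[t]]))
        else:
            out.append(t)
    return out

def n_sampling(tokens):
    """Cut a random-length segment from a random position as a new sample."""
    m = random.randint(0, len(tokens))           # segment length in [0, len]
    start = random.randint(0, len(tokens) - m)   # random start position
    return tokens[start:start + m]
```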
# Usage

Using the three augmentation strategies above, we built augmented data for ChnSentiCorp: the augmented data is 10x the original training set (96,000 lines) and can be downloaded [here](https://ernie.bj.bcebos.com/distill_data.tar.gz). Put the downloaded `distill` folder under `${TASK_DATA_PATH}` and then run the scripts below to start distillation.
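For reference, fetching and placing the data can look like the following; this assumes the archive unpacks into a top-level `distill/` directory and that `${TASK_DATA_PATH}` is already set:

```shell
wget https://ernie.bj.bcebos.com/distill_data.tar.gz
tar xzf distill_data.tar.gz
mv distill ${TASK_DATA_PATH}/
```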
### Offline distillation

In offline distillation, the fine-tuned ERNIE model first predicts labels for the unsupervised data, and the student model then learns from those labels. Simply run

```script
sh ./distill/script/distill_chnsenticorp.sh
```

to start offline distillation.

The script performs the three steps described above: 1. fine-tune on the task data; 2. load the fine-tuned model and score the augmented data; 3. train the student model. The script uses hard-label distillation, so in step 2 it directly outputs the label predicted by ERNIE.

The script involves two Python files: `./distill/finetune_chnsenticorp.py` fine-tunes the teacher model and runs its predictions, and `distill/distill_chnsentocorp.py` trains the student model. The pre-built augmented data is placed in `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug`.

In step 2, the script uses the `--do_predict` flag to enter prediction mode:

```script
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |python3 -u ./distill/finetune_chnsenticorp.py \
    --do_predict \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
    --warm_start_from ${MODEL_PATH}/params \
    --vocab_file ${MODEL_PATH}/vocab.txt \
    ...
```

The script reads plain text from standard input and writes the scores to standard output; this is how the augmented, unsupervised training corpus is labeled. The final labeling result is stored in `prediction_output/part.0` and contains two columns: the first column is the plain text, the second is the predicted label.
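For orientation, `distill/script/distill_chnsenticorp.sh` (included later in this commit) assembles that two-column file roughly as follows (abridged):

```shell
# score the augmented text with the fine-tuned teacher (the text is column 2 of part.0)
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 \
    | awk -F"\t" '{print $2}' \
    | python3 -u ./distill/finetune_chnsenticorp.py --do_predict ... > prediction_label

# join the original text with the teacher labels as "text<TAB>label"
mkdir prediction_output
paste ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 prediction_label \
    | awk -F"\t" '{print $2"\t"$3}' > prediction_output/part.0
```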
Step 3 then trains the student model:

```script
python3 ./distill/distill_chnsentocorp.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
    --vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
    --unsupervise_data_dir ./prediction_output/ \
    --max_seqlen 128 \
    ...
```

The training procedure is the same as in step 1: `--data_dir` points to the supervised data and `--unsupervise_data_dir` points to the ERNIE-labeled data. The student model is a simple BOW model defined in `distill/distill_chnsentocorp.py`; to build a customized distilled model you only need to rewrite its model part (a sketch of such a change is shown below).

If you already have unsupervised data of your own, simply place it in `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug`.
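The following sketch (not part of the repository) shows the shape of such a customization: it keeps the `propeller.train.Model` interface used by `distill/distill_chnsentocorp.py`, inherits `loss`/`backward`/`metrics` from `ClassificationBowModel`, and only swaps the encoder inside `forward`; the layer choices are illustrative.

```python
# Sketch: add this next to ClassificationBowModel in distill/distill_chnsentocorp.py
# and pass it to propeller.train.train_and_eval instead of ClassificationBowModel.
class MyStudentModel(ClassificationBowModel):
    def forward(self, features):
        text_ids_a, = features
        # any encoder can go here; this toy one sums word embeddings and stacks two FC layers
        # (padding ids are not masked for brevity; see ClassificationBowModel for the masked version)
        embed = L.embedding(
            input=text_ids_a,
            size=[self.config.vocab_size, self.config.emb_size],
            dtype='float32',
            param_attr=F.ParamAttr(
                name='word_embedding', initializer=self._param_initializer))
        pooled = L.softsign(L.reduce_sum(embed, dim=1))
        hidden = L.fc(input=pooled, size=self.config.emb_size, act='tanh')
        logits = L.fc(input=hidden, size=self.config.num_label)
        if self.mode is propeller.RunMode.PREDICT:
            return L.softmax(logits)
        return logits
```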
### Online distillation

In some scenarios the unsupervised data is so large that the prediction pass becomes very time-consuming, or the distributions predicted by ERNIE are too large to be stored on disk in advance. For these scenarios we provide an **online distillation** scheme. When fine-tuning with `propeller` and using the `BestInferenceModelExporter`, `propeller` automatically saves the model with the best metric in Paddle inference-model format, and a prediction service is then started from it. While the student model trains, it queries this service in real time to obtain ERNIE's prediction scores. Simply run

```
sh ./distill/script/distill_chnsenticorp_with_propeller_server.sh
```

to run the whole pipeline.

The pipeline consists of three steps: 1. fine-tune the ERNIE model; 2. start a `propeller` service from the ERNIE model with the best metric; 3. query that service during student training to obtain the teacher's labels.

The pipeline involves two Python files: `distill/finetune_chnsenticorp.py` and `distill/distill_chnsentocorp_with_propeller_server.py`. Step 1 is used exactly as in offline distillation.

Step 2 uses

```script
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
```

to start an ERNIE prediction service.

Step 3 starts the student training, which runs alongside the service:

```script
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
    --vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
    --teacher_vocab_file ${MODEL_PATH}/vocab.txt \
    --max_seqlen 128 \
    --teacher_max_seqlen 128 \
    --server_batch_size 64 \
    --teacher_host tcp://localhost:8113 \
    --num_coroutine 10
```

This script tokenizes the augmented data under `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug` into characters and sends requests to the `propeller` service. `--num_coroutine` sets the request concurrency, `--teacher_host` sets the IP and port of the service, and `--server_batch_size` sets the request batch size: in practice each batch of data is split into chunks of `--server_batch_size` examples before being sent to the service. A minimal sketch of the client-side call is shown below.
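For reference, the client side of this request path inside `distill/distill_chnsentocorp_with_propeller_server.py` (included later in this commit) boils down to the following; the surrounding dataset plumbing and `expand_dims` calls are omitted.

```python
from propeller.service.client import InferenceClient

# one client per training process; requests are chunked into --server_batch_size
# sized pieces and issued over --num_coroutine concurrent coroutines
client = InferenceClient(args.teacher_host,
                         batch_size=args.server_batch_size,
                         num_coroutine=args.num_coroutine)

def ask_teacher_for_label(sentence_a, teacher_sentence, teacher_segments):
    # send the ERNIE-tokenized text to the teacher service; keep the student
    # features as-is and pair them with the returned soft labels
    teacher_label, = client(teacher_sentence, teacher_segments)
    return sentence_a, teacher_label
```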
# Results

We group real application scenarios into two cases:

### Case #1: the user provides unlabeled data <a name="case1"></a>

| Model | Low-quality comment detection (classification, ACC) | Chinese sentiment (classification, ACC) | Question detection (classification, ACC) | Search Q&A matching (matching, positive/negative order ratio) |
|---|---|---|---|---|
| ERNIE fine-tune | 90.6% | 96.2% | 97.5% | 4.25 |
| Non-ERNIE baseline (BOW) | 80.8% | 94.7% | 93.0% | 1.83 |
| **+ data distillation** | 87.2% | 95.8% | 96.3% | 3.30 |

### Case #2: the user provides no unlabeled data (data generated via augmentation) <a name="case2"></a>

| Model | ChnSentiCorp |
|---|---|
| ERNIE fine-tune | 95.4% |
| Non-ERNIE baseline (BOW) | 90.1% |
| **+ data distillation** | 91.4% |
| Non-ERNIE baseline (LSTM) | 91.2% |
| **+ data distillation** | 93.9% |
# FAQ

### FAQ 1: predicting and distilling at the same time fails with `Client call failed`

The error printed in the terminal is the client-side log; the server-side log appears earlier in the output. This is usually caused by the server exceeding its GPU memory. In that case, use `--server_batch_size` in the student fine-tuning script to explicitly control the batch size of the requests sent to the service.
distill/distill_chnsentocorp.py (new file, mode 100644)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset

from optimization import optimization
import utils.data

log.setLevel(logging.DEBUG)


class ClassificationBowModel(propeller.train.Model):
    """propeller Model wrapper for paddle-ERNIE """
    def __init__(self, config, mode, run_config):
        self.config = config
        self.mode = mode
        self.run_config = run_config
        self._param_initializer = F.initializer.TruncatedNormal(
            scale=config.initializer_range)
        self._emb_dtype = "float32"
        self._word_emb_name = "word_embedding"

    def forward(self, features):
        text_ids_a, = features

        def bow(ids):
            embed = L.embedding(
                input=ids,
                size=[self.config.vocab_size, self.config.emb_size],
                dtype=self._emb_dtype,
                param_attr=F.ParamAttr(
                    name=self._word_emb_name,
                    initializer=self._param_initializer),
                is_sparse=False)
            # mask out padding ids (pad id == 0) before pooling
            zero = L.fill_constant(shape=[1], dtype='int64', value=0)
            pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
            sumed = L.reduce_sum(embed * pad, dim=1)
            sumed = L.softsign(sumed)
            return sumed

        sumed = bow(text_ids_a)

        fced = L.fc(
            input=sumed,
            size=self.config.emb_size,
            act='tanh',
            param_attr=F.ParamAttr(
                name="middle_fc.w_0", initializer=self._param_initializer),
            bias_attr="middle_fc.b_0")
        logits = L.fc(
            input=fced,
            size=self.config.num_label,
            act=None,
            param_attr=F.ParamAttr(
                name="pooler_fc.w_0", initializer=self._param_initializer),
            bias_attr="pooler_fc.b_0")

        if self.mode is propeller.RunMode.PREDICT:
            probs = L.softmax(logits)
            return probs
        else:
            return logits

    def loss(self, predictions, labels):
        labels = L.softmax(labels)
        loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
        loss = L.mean(loss)
        return loss

    def backward(self, loss):
        scheduled_lr, _ = optimization(
            loss=loss,
            warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
            num_train_steps=self.run_config.max_steps,
            learning_rate=self.config.learning_rate,
            train_program=F.default_main_program(),
            startup_prog=F.default_startup_program(),
            weight_decay=self.config.weight_decay,
            scheduler="linear_warmup_decay",)
        propeller.summary.scalar('lr', scheduled_lr)

    def metrics(self, predictions, labels):
        predictions = L.argmax(predictions, axis=1)
        labels = L.argmax(labels, axis=1)
        #predictions = L.unsqueeze(predictions, axes=[1])
        acc = propeller.metrics.Acc(labels, predictions)
        #auc = propeller.metrics.Auc(labels, predictions)
        return {'acc': acc}


if __name__ == '__main__':
    parser = propeller.ArgumentParser('DAN model with Paddle')
    parser.add_argument('--max_seqlen', type=int, default=128)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--unsupervise_data_dir', type=str, required=True)
    parser.add_argument('--data_dir', type=str)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)

    vocab = {j.strip().split(b'\t')[0].decode('utf8'): i
             for i, j in enumerate(open(args.vocab_file, 'rb'))}
    unk_id = vocab['[UNK]']

    char_tokenizer = utils.data.CharTokenizer(vocab.keys())
    space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())

    supervise_feature_column = propeller.data.FeatureColumns([
        propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
        propeller.data.LabelColumn('label'),
    ])

    def before(text_a, label):
        sentence_a = text_a[:args.max_seqlen]
        return sentence_a, label

    def after(sentence_a, label):
        # hard labels become scaled one-hot targets; loss() runs softmax over them
        batch_size = sentence_a.shape[0]
        onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
        onehot_label[np.arange(batch_size), label] = 9999.
        sentence_a, = utils.data.expand_dims(sentence_a)
        return sentence_a, onehot_label

    train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
                                   .map(before) \
                                   .padded_batch(hparams.batch_size, (0, 0)) \
                                   .map(after)

    unsup_train_ds = supervise_feature_column.build_dataset('unsup_train', data_dir=args.unsupervise_data_dir, shuffle=True, repeat=True, use_gz=False) \
                                   .map(before) \
                                   .padded_batch(hparams.batch_size, (0, 0)) \
                                   .map(after)

    dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                   .map(before) \
                                   .padded_batch(hparams.batch_size, (0, 0)) \
                                   .map(after)

    # alternate supervised batches with teacher-labeled batches
    train_ds = utils.data.interleave(train_ds, unsup_train_ds)

    shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
    types = ('int64', 'float32')

    train_ds.data_shapes = shapes
    train_ds.data_types = types
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types

    '''
    from tqdm import tqdm
    for slots in tqdm(train_ds):
        pass
    '''

    best_exporter = propeller.train.exporter.BestExporter(
        os.path.join(run_config.model_dir, 'best'),
        cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
    propeller.train.train_and_eval(
        model_class_or_model_fn=ClassificationBowModel,
        params=hparams,
        run_config=run_config,
        train_dataset=train_ds,
        eval_dataset={'dev': dev_ds},
        exporters=[best_exporter])
    print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
```
distill/distill_chnsentocorp_with_propeller_server.py (new file, mode 100644)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset
from propeller.service.client import InferenceClient

from optimization import optimization
import utils.data

log.setLevel(logging.DEBUG)


class ClassificationBowModel(propeller.train.Model):
    """propeller Model wrapper for paddle-ERNIE """
    def __init__(self, config, mode, run_config):
        self.config = config
        self.mode = mode
        self.run_config = run_config
        self._param_initializer = F.initializer.TruncatedNormal(
            scale=config.initializer_range)
        self._emb_dtype = "float32"
        self._word_emb_name = "word_embedding"

    def forward(self, features):
        text_ids_a, = features

        def bow(ids):
            embed = L.embedding(
                input=ids,
                size=[self.config.vocab_size, self.config.emb_size],
                dtype=self._emb_dtype,
                param_attr=F.ParamAttr(
                    name=self._word_emb_name,
                    initializer=self._param_initializer),
                is_sparse=False)
            # mask out padding ids (pad id == 0) before pooling
            zero = L.fill_constant(shape=[1], dtype='int64', value=0)
            pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
            sumed = L.reduce_sum(embed * pad, dim=1)
            sumed = L.softsign(sumed)
            return sumed

        sumed = bow(text_ids_a)

        fced = L.fc(
            input=sumed,
            size=self.config.emb_size,
            act='tanh',
            param_attr=F.ParamAttr(
                name="middle_fc.w_0", initializer=self._param_initializer),
            bias_attr="middle_fc.b_0")
        logits = L.fc(
            input=fced,
            size=self.config.num_label,
            act=None,
            param_attr=F.ParamAttr(
                name="pooler_fc.w_0", initializer=self._param_initializer),
            bias_attr="pooler_fc.b_0")

        if self.mode is propeller.RunMode.PREDICT:
            probs = L.softmax(logits)
            return probs
        else:
            return logits

    def loss(self, predictions, labels):
        labels = L.softmax(labels)
        loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
        loss = L.mean(loss)
        return loss

    def backward(self, loss):
        scheduled_lr, _ = optimization(
            loss=loss,
            warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
            num_train_steps=self.run_config.max_steps,
            learning_rate=self.config.learning_rate,
            train_program=F.default_main_program(),
            startup_prog=F.default_startup_program(),
            weight_decay=self.config.weight_decay,
            scheduler="linear_warmup_decay",)
        propeller.summary.scalar('lr', scheduled_lr)

    def metrics(self, predictions, labels):
        predictions = L.argmax(predictions, axis=1)
        labels = L.argmax(labels, axis=1)
        #predictions = L.unsqueeze(predictions, axes=[1])
        acc = propeller.metrics.Acc(labels, predictions)
        #auc = propeller.metrics.Auc(labels, predictions)
        return {'acc': acc}


if __name__ == '__main__':
    parser = propeller.ArgumentParser('DAN model with Paddle')
    parser.add_argument('--max_seqlen', type=int, default=128)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--teacher_vocab_file', type=str, required=True)
    parser.add_argument('--teacher_max_seqlen', type=int, default=128)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--server_batch_size', type=int, default=64)
    parser.add_argument('--num_coroutine', type=int, default=1)
    parser.add_argument('--teacher_host', type=str, required=True)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)

    teacher_vocab = {j.strip().split(b'\t')[0].decode('utf8'): i
                     for i, j in enumerate(open(args.teacher_vocab_file, 'rb'))}
    vocab = {j.strip().split(b'\t')[0].decode('utf8'): i
             for i, j in enumerate(open(args.vocab_file, 'rb'))}

    teacher_sep_id = teacher_vocab['[SEP]']
    teacher_cls_id = teacher_vocab['[CLS]']
    teacher_unk_id = teacher_vocab['[UNK]']
    unk_id = vocab['[UNK]']

    char_tokenizer = utils.data.CharTokenizer(vocab.keys())
    space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())

    supervise_feature_column = propeller.data.FeatureColumns([
        propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
        propeller.data.LabelColumn('label'),
    ])

    unsupervise_feature_column = propeller.data.FeatureColumns([
        propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
        propeller.data.TextColumn('teacher_text_a', unk_id=teacher_unk_id, vocab_dict=teacher_vocab, tokenizer=char_tokenizer),
    ])

    def before(text_a, label):
        sentence_a = text_a[:args.max_seqlen]
        return sentence_a, label

    def after(sentence_a, label):
        # hard labels become scaled one-hot targets; loss() runs softmax over them
        batch_size = sentence_a.shape[0]
        onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
        onehot_label[np.arange(batch_size), label] = 9999.
        sentence_a, = utils.data.expand_dims(sentence_a)
        return sentence_a, onehot_label

    train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
                                   .map(before) \
                                   .padded_batch(hparams.batch_size, (0, 0)) \
                                   .map(after)

    dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                   .map(before) \
                                   .padded_batch(hparams.batch_size, (0, 0)) \
                                   .map(after)

    def unsuperve_before(text_a, teacher_text_a):
        teacher_sentence, teacher_segments = utils.data.build_1_pair(
            teacher_text_a, max_seqlen=args.teacher_max_seqlen, cls_id=teacher_cls_id, sep_id=teacher_sep_id)
        sentence_a = text_a[:args.max_seqlen]
        return sentence_a, teacher_sentence, teacher_segments

    client = InferenceClient(args.teacher_host, batch_size=args.server_batch_size, num_coroutine=args.num_coroutine)
    log.info('teacher host %s' % args.teacher_host)

    def ask_teacher_for_label(sentence_a, teacher_sentence, teacher_segments):
        sentence_a, teacher_sentence, teacher_segments = utils.data.expand_dims(
            sentence_a, teacher_sentence, teacher_segments)
        teacher_label, = client(teacher_sentence, teacher_segments)
        teacher_label = teacher_label[:, :]
        return sentence_a, teacher_label

    unsup_train_ds = unsupervise_feature_column.build_dataset('unsup_train', data_dir=os.path.join(args.data_dir, 'unsup_train_aug'), shuffle=True, repeat=True, use_gz=False) \
                                   .buffered(100) \
                                   .map(unsuperve_before) \
                                   .padded_batch(hparams.batch_size, (0, 0, 0)) \
                                   .map(ask_teacher_for_label)

    # alternate supervised batches with teacher-labeled batches
    train_ds = utils.data.interleave(train_ds, unsup_train_ds)

    shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
    types = ('int64', 'float32')

    train_ds.data_shapes = shapes
    train_ds.data_types = types
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types

    '''
    from tqdm import tqdm
    for slots in tqdm(train_ds):
        pass
    '''

    best_exporter = propeller.train.exporter.BestExporter(
        os.path.join(run_config.model_dir, 'best'),
        cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
    propeller.train.train_and_eval(
        model_class_or_model_fn=ClassificationBowModel,
        params=hparams,
        run_config=run_config,
        train_dataset=train_ds,
        eval_dataset={'dev': dev_ds},
        exporters=[best_exporter])
    print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
```
distill/finetune_chnsenticorp.py (new file, mode 100644)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import logging
from random import random
from functools import reduce, partial

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

from model.ernie import ErnieModel
from optimization import optimization
import utils.data

from propeller import log
import propeller.paddle as propeller

log.setLevel(logging.DEBUG)


class ClassificationErnieModel(propeller.train.Model):
    """propeller Model wrapper for paddle-ERNIE """
    def __init__(self, hparam, mode, run_config):
        self.hparam = hparam
        self.mode = mode
        self.run_config = run_config

    def forward(self, features):
        src_ids, sent_ids = features
        dtype = 'float16' if self.hparam['fp16'] else 'float32'
        zero = L.fill_constant([1], dtype='int64', value=0)
        input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), dtype)  # assume pad id == 0
        #input_mask = L.unsqueeze(input_mask, axes=[2])
        d_shape = L.shape(src_ids)
        seqlen = d_shape[1]
        batch_size = d_shape[0]
        pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
        pos_ids = L.expand(pos_ids, [batch_size, 1])
        pos_ids = L.unsqueeze(pos_ids, axes=[2])
        pos_ids = L.cast(pos_ids, 'int64')
        pos_ids.stop_gradient = True
        input_mask.stop_gradient = True
        task_ids = L.zeros_like(src_ids) + self.hparam.task_id  # task ids are not used at the moment
        task_ids.stop_gradient = True

        bert = ErnieModel(
            src_ids=src_ids,
            position_ids=pos_ids,
            sentence_ids=sent_ids,
            task_ids=task_ids,
            input_mask=input_mask,
            config=self.hparam,
            use_fp16=self.hparam['fp16']
        )

        cls_feats = bert.get_pooled_output()

        cls_feats = L.dropout(
            x=cls_feats,
            dropout_prob=0.1,
            dropout_implementation="upscale_in_train"
        )

        logits = L.fc(
            input=cls_feats,
            size=self.hparam['num_label'],
            param_attr=F.ParamAttr(
                name="cls_out_w",
                initializer=F.initializer.TruncatedNormal(scale=0.02)),
            bias_attr=F.ParamAttr(
                name="cls_out_b", initializer=F.initializer.Constant(0.))
        )

        propeller.summary.histogram('pred', logits)

        if self.mode is propeller.RunMode.PREDICT:
            probs = L.softmax(logits)
            return probs
        else:
            return logits

    def loss(self, predictions, labels):
        ce_loss, probs = L.softmax_with_cross_entropy(
            logits=predictions, label=labels, return_softmax=True)
        #L.Print(ce_loss, message='per_example_loss')
        loss = L.mean(x=ce_loss)
        return loss

    def backward(self, loss):
        scheduled_lr, _ = optimization(
            loss=loss,
            warmup_steps=int(self.run_config.max_steps * self.hparam['warmup_proportion']),
            num_train_steps=self.run_config.max_steps,
            learning_rate=self.hparam['learning_rate'],
            train_program=F.default_main_program(),
            startup_prog=F.default_startup_program(),
            weight_decay=self.hparam['weight_decay'],
            scheduler="linear_warmup_decay",)
        propeller.summary.scalar('lr', scheduled_lr)

    def metrics(self, predictions, label):
        predictions = L.argmax(predictions, axis=1)
        predictions = L.unsqueeze(predictions, axes=[1])
        acc = propeller.metrics.Acc(label, predictions)
        #auc = propeller.metrics.Auc(label, predictions)
        return {'acc': acc}


if __name__ == '__main__':
    parser = propeller.ArgumentParser('DAN model with Paddle')
    parser.add_argument('--max_seqlen', type=int, default=128)
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--do_predict', action='store_true')
    parser.add_argument('--warm_start_from', type=str)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)

    vocab = {j.strip().split(b'\t')[0].decode('utf8'): i
             for i, j in enumerate(open(args.vocab_file, 'rb'))}
    sep_id = vocab['[SEP]']
    cls_id = vocab['[CLS]']
    unk_id = vocab['[UNK]']

    tokenizer = utils.data.CharTokenizer(vocab.keys())

    def tokenizer_func(inputs):
        '''avoid pickle error'''
        ret = tokenizer(inputs)
        return ret

    if not args.do_predict:
        feature_column = propeller.data.FeatureColumns([
            propeller.data.TextColumn('title', unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
            propeller.data.LabelColumn('label'),
        ])

        def before(seg_a, label):
            sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
            return sentence, segments, label

        def after(sentence, segments, label):
            sentence, segments, label = utils.data.expand_dims(sentence, segments, label)
            return sentence, segments, label

        log.debug(os.path.join(args.data_dir, 'train'))
        train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
                                       .map(before) \
                                       .padded_batch(hparams.batch_size, (0, 0, 0)) \
                                       .map(after)

        dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                       .map(before) \
                                       .padded_batch(hparams.batch_size, (0, 0, 0)) \
                                       .map(after)

        shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1], [-1, 1])
        types = ('int64', 'int64', 'int64')

        train_ds.data_shapes = shapes
        train_ds.data_types = types
        dev_ds.data_shapes = shapes
        dev_ds.data_types = types

        varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
        warm_start_dir = args.warm_start_from
        ws = propeller.WarmStartSetting(
            predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
            from_dir=warm_start_dir
        )

        best_exporter = propeller.train.exporter.BestInferenceModelExporter(
            os.path.join(run_config.model_dir, 'best'),
            cmp_fn=lambda old, new: new['eval']['acc'] > old['eval']['acc'])
        propeller.train.train_and_eval(
            model_class_or_model_fn=ClassificationErnieModel,
            params=hparams,
            run_config=run_config,
            train_dataset=train_ds,
            eval_dataset=dev_ds,
            warm_start_setting=ws,
            exporters=[best_exporter])
        print('dev_acc\t%.5f' % (best_exporter._best['eval']['acc']))
    else:
        feature_column = propeller.data.FeatureColumns([
            propeller.data.TextColumn('title', unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
            propeller.data.LabelColumn('label'),
        ])

        def before(seg_a):
            sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
            return sentence, segments

        def after(sentence, segments):
            sentence, segments = utils.data.expand_dims(sentence, segments)
            return sentence, segments

        predict_ds = feature_column.build_dataset_from_stdin('predict') \
                                       .map(before) \
                                       .padded_batch(hparams.batch_size, (0, 0)) \
                                       .map(after)

        shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1])
        types = ('int64', 'int64')

        predict_ds.data_shapes = shapes
        predict_ds.data_types = types

        finetuned_model = propeller.Learner(ClassificationErnieModel, run_config, hparams)
        for logits, in finetuned_model.predict(predict_ds, ckpt=-1):  # ckpt=-1 means last step
            print(np.argmax(logits))
```
distill/script/distill_chnsenticorp.sh (new file, mode 100755)

```shell
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student

# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
    --warm_start_from ${MODEL_PATH}/params \
    --vocab_file ${MODEL_PATH}/vocab.txt \
    --max_seqlen 128 \
    --run_config '{
        "model_dir": "'${teacher_dir}'",
        "max_steps": '$((10 * 9600 / 32))',
        "save_steps": 100,
        "log_steps": 10,
        "max_ckpt": 1,
        "skip_steps": 0,
        "eval_steps": 100
    }' \
    --hparam ${MODEL_PATH}/ernie_config.json \
    --hparam '{ # model definition
        "sent_type_vocab_size": None,    # default term in official config
        "use_task_id": False,
        "task_id": 0,
    }' \
    --hparam '{ # learn
        "warmup_proportion":  0.1,
        "weight_decay": 0.01,
        "fp16": 0,
        "learning_rate": 0.00005,
        "num_label": 2,
        "batch_size": 32
    }'

(($? != 0)) && echo "Something goes wrong at Step 1, please check" && exit -1

# 2. label the augmented data with the fine-tuned teacher
export CUDA_VISIBLE_DEVICES=0
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |awk -F"\t" '{print $2}' |python3 -u ./distill/finetune_chnsenticorp.py \
    --do_predict \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
    --warm_start_from ${MODEL_PATH}/params \
    --vocab_file ${MODEL_PATH}/vocab.txt \
    --max_seqlen 128 \
    --run_config '{
        "model_dir": "'${teacher_dir}'",
        "log_steps": 10,
    }' \
    --hparam ${MODEL_PATH}/ernie_config.json \
    --hparam '{ # model definition
        "sent_type_vocab_size": None,    # default term in official config
        "use_task_id": False,
        "task_id": 0,
    }' \
    --hparam '{ # learn
        "warmup_proportion":  0.1,
        "weight_decay": 0.01,
        "fp16": 0,
        "learning_rate": 0.00005,
        "num_label": 2,
        "batch_size": 100
    }' > prediction_label

(($? != 0)) && echo "Something goes wrong at Step 2, please check" && exit -1

mkdir prediction_output
paste ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 prediction_label |awk -F"\t" '{print $2"\t"$3}' > prediction_output/part.0

# 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
    --vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
    --unsupervise_data_dir ./prediction_output/ \
    --max_seqlen 128 \
    --run_config '{
        "model_dir": "'${student_dir}'",
        "max_steps": '$((100 * 9600 / 100))',
        "save_steps": 1000,
        "log_steps": 10,
        "max_ckpt": 1,
        "skip_steps": 0,
        "eval_steps": 100
    }' \
    --hparam '{
        "num_label": 2,
        "vocab_size": 35000,
        "emb_size": 128,
        "initializer_range": 0.02,
    }' \
    --hparam '{ # learn
        "warmup_proportion":  0.1,
        "weight_decay": 0.00,
        "fp16": 0,
        "learning_rate": 1e-4,
        "batch_size": 100
    }'

(($? != 0)) && echo "Something goes wrong at Step 3, please check" && exit -1
```
distill/script/distill_chnsenticorp_with_propeller_server.sh (new file, mode 100755)

```shell
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student

# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
    --warm_start_from ${MODEL_PATH}/params \
    --vocab_file ${MODEL_PATH}/vocab.txt \
    --max_seqlen 128 \
    --run_config '{
        "model_dir": "'${teacher_dir}'",
        "max_steps": '$((10 * 9600 / 32))',
        "save_steps": 100,
        "log_steps": 10,
        "max_ckpt": 1,
        "skip_steps": 0,
        "eval_steps": 100
    }' \
    --hparam ${MODEL_PATH}/ernie_config.json \
    --hparam '{ # model definition
        "sent_type_vocab_size": None,    # default term in official config
        "use_task_id": False,
        "task_id": 0,
    }' \
    --hparam '{ # learn
        "warmup_proportion":  0.1,
        "weight_decay": 0.01,
        "fp16": 0,
        "learning_rate": 0.00005,
        "num_label": 2,
        "batch_size": 32
    }'

(($? != 0)) && echo "Something goes wrong at Step 1, please check" && exit -1

# 2. start a prediction server
CUDA_VISIBLE_DEVICES=1 \
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
echo $! > pid.server
sleep 10

# 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
    --data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
    --vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
    --teacher_vocab_file ${MODEL_PATH}/vocab.txt \
    --max_seqlen 128 \
    --teacher_max_seqlen 128 \
    --server_batch_size 64 \
    --teacher_host tcp://localhost:8113 \
    --num_coroutine 10 \
    --run_config '{
        "model_dir": "'${student_dir}'",
        "max_steps": '$((100 * 9600 / 100))',
        "save_steps": 1000,
        "log_steps": 10,
        "max_ckpt": 1,
        "skip_steps": 0,
        "eval_steps": 100
    }' \
    --hparam '{ # model definition
        "num_label": 2,
        "vocab_size": 35000,
        "emb_size": 128,
        "initializer_range": 0.02,
    }' \
    --hparam '{ # learn
        "warmup_proportion":  0.1,
        "weight_decay": 0.00,
        "fp16": 0,
        "learning_rate": 1e-4,
        "batch_size": 100
    }'

(($? != 0)) && echo "Something goes wrong at Step 3, please check" && exit -1

ps -ef|grep 'propeller.tools.start_server' |awk '{print $2}'|xargs kill -9
```
utils/data.py (new file, mode 100644)

```python
import sys
import re
import itertools

import numpy as np
import six

from propeller import log
from propeller.paddle.data import Dataset

if six.PY2:
    import operator
    def accumulate(iterable, func=operator.add, initial=None):
        'Return running totals'
        # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
        # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
        it = iter(iterable)
        total = initial
        if initial is None:
            try:
                total = next(it)
            except StopIteration:
                return
        yield total
        for element in it:
            total = func(total, element)
            yield total
else:
    from itertools import accumulate

max_input_chars_per_word = 100


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


def wordpiece(token, vocab, unk_token, sentencepiece_style_vocab=False):
    """call with single word"""
    chars = list(token)
    if len(chars) > max_input_chars_per_word:
        return [unk_token], [(0, len(chars))]

    is_bad = False
    start = 0
    sub_tokens = []
    sub_pos = []
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0 and sentencepiece_style_vocab:
                substr = u'\u2581' + substr
            if start > 0 and not sentencepiece_style_vocab:
                substr = "##" + substr
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            is_bad = True
            break
        sub_tokens.append(cur_substr)
        sub_pos.append((start, end))
        start = end
    if is_bad:
        return [unk_token], [(0, len(chars))]
    else:
        return sub_tokens, sub_pos


class SpaceTokenizer(object):
    def __init__(self, vocab, lower=True):
        """
        char tokenizer (wordpiece english)
        normed txt (space separated or not) => list of word-piece
        """
        self.vocab = set(vocab)
        self.lower = lower

    def __call__(self, sen):
        if len(sen) == 0:
            return []  # empty line
        sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for s in sen.split(' '):
            if s == ' ':
                continue
            if s in self.vocab:
                res.append(s)
            else:
                res.append('[UNK]')
        return res


class CharTokenizer(object):
    def __init__(self, vocab, lower=True):
        """
        char tokenizer (wordpiece english)
        normed txt (space separated or not) => list of word-piece
        """
        self.vocab = set(vocab)
        #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
        self.pat = re.compile(r'\S')
        self.lower = lower

    def __call__(self, sen):
        if len(sen) == 0:
            return []  # empty line
        sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for match in self.pat.finditer(sen):
            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
            res.extend(words)
        return res


def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
    token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
    token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id], seg_b, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0], token_type_b, [1]], 0)
    seqlen = sen_emb.shape[0]
    #random truncate
    random_begin = 0  #np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1,)
    sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
    token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]
    return sen_emb, token_type_emb


def build_1_pair(seg_a, max_seqlen, cls_id, sep_id):
    token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0]], 0)
    seqlen = sen_emb.shape[0]
    #random truncate
    random_begin = 0  #np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1,)
    sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
    token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]
    return sen_emb, token_type_emb


def expand_dims(*args):
    func = lambda i: np.expand_dims(i, -1)
    ret = [func(i) for i in args]
    return ret


def interleave(ds1, ds2):
    def gen():
        for i, j in six.moves.zip_longest(iter(ds1), iter(ds2)):
            if i is not None:
                yield i
            if j is not None:
                yield j
    return Dataset.from_generator_func(gen)
```