提交 109a3c75 编写于 作者: X Xing Wu 提交者: JesseyXujin

LAC dygraph implementation (#4190)

* lac dygraph for version 1.7

* lac dygraph for version 1.7

* add eval, predict and README

* remove unused links
上级 4a2fb50e
# 中文词法分析
## 1. 简介
Lexical Analysis of Chinese,简称 LAC,是一个联合的词法分析模型,在单个模型中完成中文分词、词性标注、专名识别任务。我们在自建的数据集上对分词、词性标注、专名识别进行整体的评估效果,具体数值见下表;此外,我们在百度开放的 [ERNIE](https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE) 模型上 finetune,并对比基线模型、BERT finetuned 和 ERNIE finetuned 的效果,可以看出会有显著的提升。可通过 [AI开放平台-词法分析](http://ai.baidu.com/tech/nlp/lexical) 线上体验百度的词法分析服务。
这里的是LAC的动态图实现,相同网络结构的静态图实现可以参照:[LAC静态图实现](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/lexical_analysis)
|模型|Precision|Recall|F1-score|
|:-:|:-:|:-:|:-:|
|Lexical Analysis|89.2%|89.4%|89.3%|
## 2. 快速开始
### 安装说明
#### 1.PaddlePaddle 安装
本项目依赖 PaddlePaddle 1.7.0 及以上版本和PaddleHub 1.0.0及以上版本 ,PaddlePaddle安装请参考官网 [快速安装](http://www.paddlepaddle.org/paddle#quick-start),PaddleHub安装参考 [PaddleHub](https://github.com/PaddlePaddle/PaddleHub)
> Warning: GPU 和 CPU 版本的 PaddlePaddle 分别是 paddlepaddle-gpu 和 paddlepaddle,请安装时注意区别。
#### 2. 克隆代码
克隆工具集代码库到本地
```bash
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/lexical_analysis
```
#### 3. 环境依赖
PaddlePaddle的版本要求是:Python 2 版本是 2.7.15+、Python 3 版本是 3.5.1+/3.6/3.7。LAC的代码可支持Python2/3,无具体版本限制
### 数据准备
#### 训练数据集
下载数据集文件,解压后会生成 `./data/` 文件夹
```bash
python downloads.py dataset
```
### 模型训练
基于示例的数据集,可通过下面的命令,在训练集 `./data/train.tsv` 上进行训练
```bash
bash run.sh
```
### 模型评估
我们基于自建的数据集训练了一个词法分析的模型,可以直接用这个模型对测试集 `./data/test.tsv` 进行验证,
```bash
# baseline model
sh eval.sh
```
### 模型预测
加载已有的模型,对未知的数据进行预测
```bash
# baseline model
sh predict.sh
```
## 3. 进阶使用
### 任务定义与建模
词法分析任务的输入是一个字符串(我们后面使用『句子』来指代它),而输出是句子中的词边界和词性、实体类别。序列标注是词法分析的经典建模方式。我们使用基于 GRU 的网络结构学习特征,将学习到的特征接入 CRF 解码层完成序列标注。CRF 解码层本质上是将传统 CRF 中的线性模型换成了非线性神经网络,基于句子级别的似然概率,因而能够更好的解决标记偏置问题。模型要点如下,具体细节请参考 `run_sequence_labeling.py` 代码。
1. 输入采用 one-hot 方式表示,每个字以一个 id 表示
2. one-hot 序列通过字表,转换为实向量表示的字向量序列;
3. 字向量序列作为双向 GRU 的输入,学习输入序列的特征表示,得到新的特性表示序列,我们堆叠了两层双向GRU以增加学习能力;
4. CRF 以 GRU 学习到的特征为输入,以标记序列为监督信号,实现序列标注。
词性和专名类别标签集合如下表,其中词性标签 24 个(小写字母),专名类别标签 4 个(大写字母)。这里需要说明的是,人名、地名、机构名和时间四个类别,在上表中存在两套标签(PER / LOC / ORG / TIME 和 nr / ns / nt / t),被标注为第二套标签的词,是模型判断为低置信度的人名、地名、机构名和时间词。开发者可以基于这两套标签,在四个类别的准确、召回之间做出自己的权衡。
| 标签 | 含义 | 标签 | 含义 | 标签 | 含义 | 标签 | 含义 |
| ---- | -------- | ---- | -------- | ---- | -------- | ---- | -------- |
| n | 普通名词 | f | 方位名词 | s | 处所名词 | t | 时间 |
| nr | 人名 | ns | 地名 | nt | 机构名 | nw | 作品名 |
| nz | 其他专名 | v | 普通动词 | vd | 动副词 | vn | 名动词 |
| a | 形容词 | ad | 副形词 | an | 名形词 | d | 副词 |
| m | 数量词 | q | 量词 | r | 代词 | p | 介词 |
| c | 连词 | u | 助词 | xc | 其他虚词 | w | 标点符号 |
| PER | 人名 | LOC | 地名 | ORG | 机构名 | TIME | 时间 |
### 模型原理介绍
上面介绍的模型原理如下图所示:<br />
![GRU-CRF-MODEL](./gru-crf-model.png)
### 数据格式
训练使用的数据可以由用户根据实际的应用场景,自己组织数据。除了第一行是 `text_a\tlabel` 固定的开头,后面的每行数据都是由两列组成,以制表符分隔,第一列是 utf-8 编码的中文文本,以 `\002` 分割,第二列是对应每个字的标注,以 `\002` 分隔。我们采用 IOB2 标注体系,即以 X-B 作为类型为 X 的词的开始,以 X-I 作为类型为 X 的词的持续,以 O 表示不关注的字(实际上,在词性、专名联合标注中,不存在 O )。示例如下:
```text
除\002了\002他\002续\002任\002十\002二\002届\002政\002协\002委\002员\002,\002马\002化\002腾\002,\002雷\002军\002,\002李\002彦\002宏\002也\002被\002推\002选\002为\002新\002一\002届\002全\002国\002人\002大\002代\002表\002或\002全\002国\002政\002协\002委\002员 p-B\002p-I\002r-B\002v-B\002v-I\002m-B\002m-I\002m-I\002ORG-B\002ORG-I\002n-B\002n-I\002w-B\002PER-B\002PER-I\002PER-I\002w-B\002PER-B\002PER-I\002w-B\002PER-B\002PER-I\002PER-I\002d-B\002p-B\002v-B\002v-I\002v-B\002a-B\002m-B\002m-I\002ORG-B\002ORG-I\002ORG-I\002ORG-I\002n-B\002n-I\002c-B\002n-B\002n-I\002ORG-B\002ORG-I\002n-B\002n-I
```
+ 我们随同代码一并发布了完全版的模型和相关的依赖数据。但是,由于模型的训练数据过于庞大,我们没有发布训练数据,仅在`data`目录下放置少数样本用以示例输入数据格式。
+ 模型依赖数据包括:
1. 输入文本的词典,在`conf`目录下,对应`word.dic`
2. 对输入文本中特殊字符进行转换的字典,在`conf`目录下,对应`q2b.dic`
3. 标记标签的词典,在`conf`目录下,对应`tag.dic`
+ 在训练和预测阶段,我们都需要进行原始数据的预处理,具体处理工作包括:
1. 从原始数据文件中抽取出句子和标签,构造句子序列和标签序列
2. 将句子序列中的特殊字符进行转换
3. 依据词典获取词对应的整数索引
## 4. 其他
### 在论文中引用 LAC
如果您的学术工作成果中使用了 LAC,请您增加下述引用。我们非常欣慰 LAC 能够对您的学术工作带来帮助。
```text
@article{jiao2018LAC,
title={Chinese Lexical Analysis with Deep Bi-GRU-CRF Network},
author={Jiao, Zhenyu and Sun, Shuqi and Sun, Ke},
journal={arXiv preprint arXiv:1807.01882},
year={2018},
url={https://arxiv.org/abs/1807.01882}
}
```
### 如何贡献代码
如果你可以修复某个 issue 或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度 进行打分(0-5分,越高越好)。如果你累计获得了 10 分,可以联系我们获得面试机会或为你写推荐信。
model:
word_emb_dim:
val: 128
meaning: "The dimension in which a word is embedded."
grnn_hidden_dim:
val: 128
meaning: "The number of hidden nodes in the GRNN layer."
bigru_num:
val: 2
meaning: "The number of bi_gru layers in the network."
init_checkpoint:
val: ""
meaning: "Path to init model"
inference_save_dir:
val: ""
meaning: "Path to save inference model"
train:
random_seed:
val: 0
meaning: "Random seed for training"
print_steps:
val: 1
meaning: "Print the result per xxx batch of training"
save_steps:
val: 10
meaning: "Save the model once per xxxx batch of training"
validation_steps:
val: 10
meaning: "Do the validation once per xxxx batch of training"
batch_size:
val: 300
meaning: "The number of sequences contained in a mini-batch"
epoch:
val: 10
meaning: "Corpus iteration num"
use_cuda:
val: False
meaning: "If set, use GPU for training."
traindata_shuffle_buffer:
val: 20000
meaning: "The buffer size used in shuffle the training data."
base_learning_rate:
val: 0.001
meaning: "The basic learning rate that affects the entire network."
emb_learning_rate:
val: 2
meaning: "The real learning rate of the embedding layer will be (emb_learning_rate * base_learning_rate)."
crf_learning_rate:
val: 0.2
meaning: "The real learning rate of the embedding layer will be (crf_learning_rate * base_learning_rate)."
enable_ce:
val: false
meaning: 'If set, run the task with continuous evaluation logs.'
cpu_num:
val: 10
meaning: "The number of cpu used to train model, this argument wouldn't be valid if use_cuda=true"
use_data_parallel:
val: False
meaning: "The flag indicating whether to use data parallel mode to train the model."
data:
word_dict_path:
val: "./conf/word.dic"
meaning: "The path of the word dictionary."
label_dict_path:
val: "./conf/tag.dic"
meaning: "The path of the label dictionary."
word_rep_dict_path:
val: "./conf/q2b.dic"
meaning: "The path of the word replacement Dictionary."
train_data:
val: "./data/train.tsv"
meaning: "The folder where the training data is located."
test_data:
val: "./data/test.tsv"
meaning: "The folder where the test data is located."
infer_data:
val: "./data/infer.tsv"
meaning: "The folder where the infer data is located."
model_save_dir:
val: "./models"
meaning: "The model will be saved in this path."
model:
ernie_config_path:
val: "../LARK/ERNIE/config/ernie_config.json"
meaning: "Path to the json file for ernie model config."
init_checkpoint:
val: ""
meaning: "Path to init model"
mode:
val: "train"
meaning: "Setting to train or eval or infer"
init_pretraining_params:
val: "pretrained/params/"
meaning: "Init pre-training params which preforms fine-tuning from. If the arg 'init_checkpoint' has been set, this argument wouldn't be valid."
train:
random_seed:
val: 0
meaning: "Random seed for training"
batch_size:
val: 10
meaning: "The number of sequences contained in a mini-batch"
epoch:
val: 10
meaning: "Corpus iteration num"
use_cuda:
val: True
meaning: "If set, use GPU for training."
base_learning_rate:
val: 0.0002
meaning: "The basic learning rate that affects the entire network."
init_bound:
val: 0.1
meaning: "init bound for initialization."
crf_learning_rate:
val: 0.2
meaning: "The real learning rate of the embedding layer will be (crf_learning_rate * base_learning_rate)."
cpu_num:
val: 10
meaning: "The number of cpu used to train model, it works when use_cuda=False"
print_steps:
val: 1
meaning: "Print the result per xxx batch of training"
save_steps:
val: 10
meaning: "Save the model once per xxxx batch of training"
validation_steps:
val: 5
meaning: "Do the validation once per xxxx batch of training"
data:
vocab_path:
val: "../LARK/ERNIE/config/vocab.txt"
meaning: "The path of the vocabulary."
label_map_config:
val: "./conf/label_map.json"
meaning: "The path of the label dictionary."
num_labels:
val: 57
meaning: "label number"
max_seq_len:
val: 128
meaning: "Number of words of the longest seqence."
do_lower_case:
val: True
meaning: "Whether to lower case the input text. Should be True for uncased models and False for cased models."
train_data:
val: "./data/train.tsv"
meaning: "The folder where the training data is located."
test_data:
val: "./data/test.tsv"
meaning: "The folder where the test data is located."
infer_data:
val: "./data/test.tsv"
meaning: "The folder where the infer data is located."
model_save_dir:
val: "./ernie_models"
meaning: "The model will be saved in this path."
{"d-B": 8, "c-I": 7, "PER-I": 49, "nr-B": 16, "u-B": 36, "c-B": 6, "nr-I": 17, "an-I": 5, "ns-B": 18, "vn-I": 43, "w-B": 44, "an-B": 4, "PER-B": 48, "vn-B": 42, "ns-I": 19, "a-I": 1, "r-B": 30, "xc-B": 46, "LOC-B": 50, "ad-I": 3, "nz-B": 24, "u-I": 37, "a-B": 0, "ad-B": 2, "vd-I": 41, "nw-B": 22, "m-I": 13, "d-I": 9, "n-B": 14, "nz-I": 25, "vd-B": 40, "nw-I": 23, "n-I": 15, "nt-B": 20, "ORG-I": 53, "nt-I": 21, "ORG-B": 52, "LOC-I": 51, "t-B": 34, "TIME-I": 55, "O": 56, "s-I": 33, "f-I": 11, "TIME-B": 54, "t-I": 35, "f-B": 10, "s-B": 32, "r-I": 31, "q-B": 28, "v-I": 39, "v-B": 38, "w-I": 45, "q-I": 29, "p-B": 26, "xc-I": 47, "m-B": 12, "p-I": 27}
\ No newline at end of file
 
、 ,
。 .
— -
~ ~
‖ |
… .
‘ '
’ '
“ "
” "
〔 (
〕 )
〈 <
〉 >
「 '
」 '
『 "
』 "
〖 [
〗 ]
【 [
】 ]
∶ :
$ $
! !
" "
# #
% %
& &
' '
( (
) )
* *
+ +
, ,
- -
. .
/ /
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
; ;
< <
= =
> >
? ?
@ @
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
[ [
\ \
] ]
^ ^
_ _
` `
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
{ {
| |
} }
 ̄ ~
〝 "
〞 "
﹐ ,
﹑ ,
﹒ .
﹔ ;
﹕ :
﹖ ?
﹗ !
﹙ (
﹚ )
﹛ {
﹜ {
﹝ [
﹞ ]
﹟ #
﹠ &
﹡ *
﹢ +
﹣ -
﹤ <
﹥ >
﹦ =
﹨ \
﹩ $
﹪ %
﹫ @
,
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
0 a-B
1 a-I
2 ad-B
3 ad-I
4 an-B
5 an-I
6 c-B
7 c-I
8 d-B
9 d-I
10 f-B
11 f-I
12 m-B
13 m-I
14 n-B
15 n-I
16 nr-B
17 nr-I
18 ns-B
19 ns-I
20 nt-B
21 nt-I
22 nw-B
23 nw-I
24 nz-B
25 nz-I
26 p-B
27 p-I
28 q-B
29 q-I
30 r-B
31 r-I
32 s-B
33 s-I
34 t-B
35 t-I
36 u-B
37 u-I
38 v-B
39 v-I
40 vd-B
41 vd-I
42 vn-B
43 vn-I
44 w-B
45 w-I
46 xc-B
47 xc-I
48 PER-B
49 PER-I
50 LOC-B
51 LOC-I
52 ORG-B
53 ORG-I
54 TIME-B
55 TIME-I
56 O
因为 它太大了无法显示 source diff 。你可以改为 查看blob
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
FILE_INFO = {
'DATA': {
'name': 'lexical_analysis-dataset-2.0.0.tar.gz',
'md5': '71e4a9a36d0f0177929a1bccedca7dba'
},
}
def usage():
desc = ("\nDownload datasets and pretrained models for LAC.\n"
"Usage:\n"
" 1. python download.py dataset\n"
print(desc)
def md5file(fname):
hash_md5 = hashlib.md5()
with io.open(fname, "rb") as fin:
for chunk in iter(lambda: fin.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def _download(url, filename, md5sum):
"""
Download file and check md5
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError(
"Cannot download dataset ({0}) with retry {1} times.".format(
url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print(
'\r[Process ]: %s%.2f%%' %
(pr, float(size / filesize * 100)),
end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download(name, dir_path):
url = FILE_INFO['BASE_URL'] + FILE_INFO[name]['name']
file_path = os.path.join(dir_path, FILE_INFO[name]['name'])
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download data
print("Downloading : %s" % name)
_download(url, file_path, FILE_INFO[name]['md5'])
# extract data
print("Extracting : %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
sys.exit(1)
pwd = os.path.join(os.path.dirname(__file__), './')
if sys.argv[1] == "dataset":
download('DATA', pwd)
else:
usage()
#!/bin/bash
# download dataset file to ./data/
if [ -d ./data/ ]
then
echo "./data/ directory already existed, ignore download"
else
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/lexical_analysis-dataset-2.0.0.tar.gz
tar xvf lexical_analysis-dataset-2.0.0.tar.gz
/bin/rm lexical_analysis-dataset-2.0.0.tar.gz
fi
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import sys
import paddle.fluid as fluid
import paddle
import utils
import reader
import math
from sequence_labeling import lex_net, Chunk_eval
parser = argparse.ArgumentParser(__doc__)
# 1. model parameters
utils.load_yaml(parser, 'conf/args.yaml')
args = parser.parse_args()
def do_eval(args):
dataset = reader.Dataset(args)
if args.use_cuda:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
test_loader = reader.create_dataloader(
args,
file_name=args.test_data,
place=place,
model='lac',
reader=dataset,
mode='test')
model = lex_net(args, dataset.vocab_size, dataset.num_labels)
load_path = args.init_checkpoint
state_dict, _ = fluid.dygraph.load_dygraph(load_path)
#import ipdb; ipdb.set_trace()
state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"]
model.set_dict(state_dict)
model.eval()
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
# test_process(test_loader, chunk_evaluator)
def test_process(reader, chunk_evaluator):
start_time = time.time()
for batch in reader():
words, targets, length = batch
crf_decode = model(words, length=length)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = chunk_eval(
input=crf_decode,
label=targets,
seq_length=length)
chunk_evaluator.update(num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
precision, recall, f1 = chunk_evaluator.eval()
end_time = time.time()
print("[test] P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s" %
(precision, recall, f1, end_time - start_time))
test_process(test_loader, chunk_evaluator)
if __name__ == '__main__':
args = parser.parse_args()
do_eval(args)
#!/bin/bash
export CUDA_VISIBLE_DEVICES=7
python eval.py --batch_size 200 --word_emb_dim 128 --grnn_hidden_dim 128 --bigru_num 2 --use_cuda False --init_checkpoint ./padding_models/step_120000 --test_data ./data/test.tsv --word_dict_path ./conf/word.dic --label_dict_path ./conf/tag.dic --word_rep_dict_path ./conf/q2b.dic
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import sys
import paddle.fluid as fluid
import paddle
import utils
import reader
import math
from sequence_labeling import lex_net, Chunk_eval
parser = argparse.ArgumentParser(__doc__)
# 1. model parameters
utils.load_yaml(parser, 'conf/args.yaml')
args = parser.parse_args()
def do_infer(args):
dataset = reader.Dataset(args)
if args.use_cuda:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
infer_loader = reader.create_dataloader(
args,
file_name=args.infer_data,
place=place,
model='lac',
reader=dataset,
mode='infer')
model = lex_net(args, dataset.vocab_size, dataset.num_labels)
load_path = args.init_checkpoint
state_dict, _ = fluid.dygraph.load_dygraph(load_path)
#import ipdb; ipdb.set_trace()
state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"]
model.set_dict(state_dict)
model.eval()
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
def input_check(data):
if data.lod()[0][-1] == 0:
return data[0]['words']
return None
def infer_process(reader):
results = []
for batch in reader():
# import ipdb; ipdb.set_trace()
words, length = batch
#crf_decode = input_check(words)
#if crf_decode:
# results += utils.parse_result(crf_decode, crf_decode, dataset)
# continue
crf_decode = model(words, length=length)
results += utils.parse_padding_result(words.numpy(), crf_decode.numpy(), length.numpy(), dataset)
return results
result = infer_process(infer_loader)
for sent, tags in result:
result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
print(''.join(result_list))
if __name__ == '__main__':
args = parser.parse_args()
do_infer(args)
#!/bin/bash
export CUDA_VISIBLE_DEVICES=7
python predict.py --batch_size 200 --word_emb_dim 128 --grnn_hidden_dim 128 --bigru_num 2 --use_cuda False --init_checkpoint ./padding_models/step_120000 --infer_data ./data/infer.tsv --word_dict_path ./conf/word.dic --label_dict_path ./conf/tag.dic --word_rep_dict_path ./conf/q2b.dic
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The file_reader converts raw corpus to input.
"""
import os
import argparse
import __future__
import io
import glob
import paddle.fluid as fluid
def load_kv_dict(dict_path,
reverse=False,
delimiter="\t",
key_func=None,
value_func=None):
"""
Load key-value dict from file
"""
result_dict = {}
for line in io.open(dict_path, "r", encoding='utf8'):
terms = line.strip("\n").split(delimiter)
if len(terms) != 2:
continue
if reverse:
value, key = terms
else:
key, value = terms
if key in result_dict:
raise KeyError("key duplicated with [%s]" % (key))
if key_func:
key = key_func(key)
if value_func:
value = value_func(value)
result_dict[key] = value
return result_dict
class Dataset(object):
"""data reader"""
def __init__(self, args, mode="train"):
# read dict
self.word2id_dict = load_kv_dict(
args.word_dict_path, reverse=True, value_func=int)
self.id2word_dict = load_kv_dict(args.word_dict_path)
self.label2id_dict = load_kv_dict(
args.label_dict_path, reverse=True, value_func=int)
self.id2label_dict = load_kv_dict(args.label_dict_path)
self.word_replace_dict = load_kv_dict(args.word_rep_dict_path)
@property
def vocab_size(self):
"""vocabuary size"""
return max(self.word2id_dict.values()) + 1
@property
def num_labels(self):
"""num_labels"""
return max(self.label2id_dict.values()) + 1
def get_num_examples(self, filename):
"""num of line of file"""
return sum(1 for line in io.open(filename, "r", encoding='utf8'))
def word_to_ids(self, words):
"""convert word to word index"""
word_ids = []
for word in words:
word = self.word_replace_dict.get(word, word)
if word not in self.word2id_dict:
word = "OOV"
word_id = self.word2id_dict[word]
word_ids.append(word_id)
return word_ids
def label_to_ids(self, labels):
"""convert label to label index"""
label_ids = []
for label in labels:
if label not in self.label2id_dict:
label = "O"
label_id = self.label2id_dict[label]
label_ids.append(label_id)
return label_ids
def file_reader(self, filename, batch_size=32, _max_seq_len=64, mode="train"):
"""
yield (word_idx, target_idx) one by one from file,
or yield (word_idx, ) in `infer` mode
"""
def wrapper():
fread = io.open(filename, "r", encoding="utf-8")
if mode == "infer":
batch, init_lens = [], []
for line in fread:
words= line.strip()
word_ids = self.word_to_ids(words)
init_lens.append(len(word_ids))
batch.append(word_ids)
if len(batch) == batch_size:
max_seq_len = min(max(init_lens), _max_seq_len)
new_batch = []
for words_len, words in zip(init_lens, batch):
word_ids = words[0:max_seq_len]
words_len = len(word_ids)
# expand to max_seq_len
word_ids += [0 for _ in range(max_seq_len-words_len)]
new_batch.append((word_ids,words_len))
yield new_batch
batch, init_lens = [], []
if len(batch) > 0:
max_seq_len = min(max(init_lens), max_seq_len)
new_batch = []
for words_len, words in zip(init_lens, batch):
word_ids = word_ids[0:max_seq_len]
words_len = len(word_ids)
# expand to max_seq_len
word_ids += [0 for _ in range(max_seq_len-words_len)]
new_batch.append((word_ids,words_len))
yield new_batch
else:
headline = next(fread)
batch, init_lens = [], []
for line in fread:
words, labels = line.strip("\n").split("\t")
if len(words)<1:
continue
word_ids = self.word_to_ids(words.split("\002"))
label_ids = self.label_to_ids(labels.split("\002"))
init_lens.append(len(word_ids))
batch.append((word_ids, label_ids))
if len(batch) == batch_size:
max_seq_len = min(max(init_lens), _max_seq_len)
new_batch = []
for words_len, (word_ids, label_ids) in zip(init_lens, batch):
word_ids = word_ids[0:max_seq_len]
words_len = len(word_ids)
word_ids += [0 for _ in range(max_seq_len-words_len)]
label_ids = label_ids[0:max_seq_len]
label_ids += [0 for _ in range(max_seq_len-words_len)]
assert len(word_ids) == len(label_ids)
new_batch.append((word_ids, label_ids, words_len))
yield new_batch
batch, init_lens = [], []
if len(batch) == batch_size:
max_seq_len = min(max(init_lens), max_seq_len)
new_batch = []
for words_len, (word_ids, label_ids) in zip(init_lens, batch):
max_seq_len = min(max(init_lens), max_seq_len)
word_ids = words[0:max_seq_len]
words_len = len(word_ids)
word_ids += [0 for _ in range(max_seq_len-words_len)]
label_ids = label_ids[0:max_seq_len]
label_ids += [0 for _ in range(max_seq_len-words_len)]
assert len(word_ids) == len(label_ids)
new_batch.append((word_ids, label_ids, words_len))
yield new_batch
fread.close()
return wrapper
def create_dataloader(args,
file_name,
place,
model='lac',
reader=None,
return_reader=False,
mode='train'):
# init reader
if model == 'lac':
data_loader = fluid.io.DataLoader.from_generator(
capacity=50,
use_double_buffer=True,
iterable=True)
if reader == None:
reader = Dataset(args)
# create lac pyreader
if mode == 'train':
#data_loader.set_sample_list_generator(
# fluid.io.batch(
# fluid.io.shuffle(
# reader.file_reader(file_name),
# buf_size=args.traindata_shuffle_buffer),
# batch_size=args.batch_size),
# places=place)
data_loader.set_sample_list_generator(
reader.file_reader(
file_name, batch_size=args.batch_size, _max_seq_len=64, mode=mode),
places=place)
else:
data_loader.set_sample_list_generator(
reader.file_reader(
file_name, batch_size=args.batch_size, _max_seq_len=64, mode=mode),
places=place)
if return_reader:
return data_loader, reader
else:
return data_loader
#!/bin/bash
export FLAGS_fraction_of_gpu_memory_to_use=0.02
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
python -m paddle.distributed.launch --selected_gpus=3,4,6,7 train.py \
--train_data ./data/train.tsv \
--test_data ./data/test.tsv \
--model_save_dir ./padding_models \
--validation_steps 1000 \
--save_steps 10000 \
--print_steps 200 \
--batch_size 400 \
--epoch 10 \
--traindata_shuffle_buffer 20000 \
--word_emb_dim 128 \
--grnn_hidden_dim 128 \
--bigru_num 2 \
--base_learning_rate 1e-3 \
--emb_learning_rate 2 \
--crf_learning_rate 0.2 \
--word_dict_path ./conf/word.dic \
--label_dict_path ./conf/tag.dic \
--word_rep_dict_path ./conf/q2b.dic \
--enable_ce false \
--use_cuda true \
--cpu_num 1 \
--use_data_paralle True
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The function lex_net(args) define the lexical analysis network structure
"""
import sys
import os
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.nn import Embedding, Linear, GRUUnit
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
h_0=None,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[ :, i:i+1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class BiGRU(fluid.dygraph.Layer):
def __init__(self,
input_dim,
grnn_hidden_dim,
init_bound,
h_0=None):
super(BiGRU, self).__init__()
self.pre_gru = Linear(input_dim=input_dim,
output_dim=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))#,
#num_flatten_dims=2)
self.gru = DynamicGRU(size=grnn_hidden_dim,
h_0=h_0,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
self.pre_gru_r = Linear(input_dim=input_dim,
output_dim=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))#,
#num_flatten_dims=2)
self.gru_r = DynamicGRU(size=grnn_hidden_dim,
is_reverse=True,
h_0=h_0,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
def forward(self, input_feature):
res_pre_gru = self.pre_gru(input_feature)
res_gru = self.gru(res_pre_gru)
res_pre_gru_r = self.pre_gru_r(input_feature)
res_gru_r = self.gru_r(res_pre_gru_r)
bi_merge = fluid.layers.concat(input=[res_gru, res_gru_r], axis=-1)
return bi_merge
class Linear_chain_crf(fluid.dygraph.Layer):
def __init__(self,
param_attr,
size=None,
is_test=False,
dtype='float32'):
super(Linear_chain_crf, self).__init__()
self._param_attr = param_attr
self._dtype = dtype
self._size = size
self._is_test=is_test
self._transition = self.create_parameter(
attr=self._param_attr,
shape=[self._size + 2, self._size],
dtype=self._dtype)
@property
def weight(self):
return self._transition
@weight.setter
def weight(self, value):
self._transition = value
def forward(self, input, label, length=None):
alpha = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
emission_exps = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
transition_exps = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
log_likelihood = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
this_inputs = {
"Emission": [input],
"Transition": self._transition,
"Label": [label]
}
if length:
this_inputs['Length'] = [length]
self._helper.append_op(
type='linear_chain_crf',
inputs=this_inputs,
outputs={
"Alpha": [alpha],
"EmissionExps": [emission_exps],
"TransitionExps": transition_exps,
"LogLikelihood": log_likelihood
},
attrs={
"is_test": self._is_test,
})
return log_likelihood
class Crf_decoding(fluid.dygraph.Layer):
def __init__(self,
param_attr,
size=None,
is_test=False,
dtype='float32'):
super(Crf_decoding, self).__init__()
self._dtype = dtype
self._size = size
self._is_test = is_test
self._param_attr = param_attr
self._transition = self.create_parameter(
attr=self._param_attr,
shape=[self._size + 2, self._size],
dtype=self._dtype)
@property
def weight(self):
return self._transition
@weight.setter
def weight(self, value):
self._transition = value
def forward(self, input, label=None, length=None):
viterbi_path = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
this_inputs = {"Emission": [input], "Transition": self._transition, "Label": label}
if length:
this_inputs['Length'] = [length]
self._helper.append_op(
type='crf_decoding',
inputs=this_inputs,
outputs={"ViterbiPath": [viterbi_path]},
attrs={
"is_test": self._is_test,
})
return viterbi_path
class Chunk_eval(fluid.dygraph.Layer):
def __init__(self,
num_chunk_types,
chunk_scheme,
excluded_chunk_types=None):
super(Chunk_eval, self).__init__()
self.num_chunk_types = num_chunk_types
self.chunk_scheme = chunk_scheme
self.excluded_chunk_types = excluded_chunk_types
def forward(self, input, label, seq_length=None):
precision = self._helper.create_variable_for_type_inference(dtype="float32")
recall = self._helper.create_variable_for_type_inference(dtype="float32")
f1_score = self._helper.create_variable_for_type_inference(dtype="float32")
num_infer_chunks = self._helper.create_variable_for_type_inference(dtype="int64")
num_label_chunks = self._helper.create_variable_for_type_inference(dtype="int64")
num_correct_chunks = self._helper.create_variable_for_type_inference(dtype="int64")
this_input = {"Inference": [input], "Label": [label]}
if seq_length:
this_input["SeqLength"] = [seq_length]
self._helper.append_op(
type='chunk_eval',
inputs=this_input,
outputs={
"Precision": [precision],
"Recall": [recall],
"F1-Score": [f1_score],
"NumInferChunks": [num_infer_chunks],
"NumLabelChunks": [num_label_chunks],
"NumCorrectChunks": [num_correct_chunks]
},
attrs={
"num_chunk_types": self.num_chunk_types,
"chunk_scheme": self.chunk_scheme,
"excluded_chunk_types": self.excluded_chunk_types or []
})
return (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks)
class lex_net(fluid.dygraph.Layer):
def __init__(self,
args,
vocab_size,
num_labels,
length=None):
super(lex_net, self).__init__()
"""
define the lexical analysis network structure
word: stores the input of the model
for_infer: a boolean value, indicating if the model to be created is for training or predicting.
return:
for infer: return the prediction
otherwise: return the prediction
"""
self.word_emb_dim = args.word_emb_dim
self.vocab_size = vocab_size
self.num_labels = num_labels
self.grnn_hidden_dim = args.grnn_hidden_dim
self.emb_lr = args.emb_learning_rate if 'emb_learning_rate' in dir(args) else 1.0
self.crf_lr = args.emb_learning_rate if 'crf_learning_rate' in dir(args) else 1.0
self.bigru_num = args.bigru_num
self.init_bound = 0.1
#self.IS_SPARSE = True
self.word_embedding = Embedding(
size=[self.vocab_size, self.word_emb_dim],
dtype='float32',
#is_sparse=self.IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=self.emb_lr,
name="word_emb",
initializer=fluid.initializer.Uniform(
low=-self.init_bound, high=self.init_bound)))
h_0 = np.zeros((args.batch_size, self.grnn_hidden_dim), dtype="float32")
h_0 = to_variable(h_0)
self.bigru_units = []
for i in range(self.bigru_num):
if i == 0:
self.bigru_units.append(
self.add_sublayer("bigru_units%d" % i,
BiGRU(self.grnn_hidden_dim, self.grnn_hidden_dim, self.init_bound, h_0=h_0)
))
else:
self.bigru_units.append(
self.add_sublayer("bigru_units%d" % i,
BiGRU(self.grnn_hidden_dim * 2, self.grnn_hidden_dim, self.init_bound, h_0=h_0)
))
self.fc = Linear(input_dim=self.grnn_hidden_dim * 2,
output_dim=self.num_labels,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-self.init_bound, high=self.init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))#,
#num_flatten_dims=2)
self.linear_chain_crf = Linear_chain_crf(
param_attr=fluid.ParamAttr(
name='crfw', learning_rate=self.crf_lr),
size=self.num_labels)
self.crf_decoding = Crf_decoding(
param_attr=fluid.ParamAttr(
name='crfw', learning_rate=self.crf_lr),
size=self.num_labels)
def forward(self, word, target=None, length=None):
"""
Configure the network
"""
#word = fluid.layers.unsqueeze(word, [2])
word_embed = self.word_embedding(word)
input_feature = word_embed
for i in range(self.bigru_num):
bigru_output = self.bigru_units[i](input_feature)
input_feature = bigru_output
emission = self.fc(bigru_output)
if target is not None:
crf_cost = self.linear_chain_crf(
input=emission,
label=target,
length=length)
avg_cost = fluid.layers.mean(x=crf_cost)
self.crf_decoding.weight = self.linear_chain_crf.weight
crf_decode = self.crf_decoding(
input=emission,
length=length)
return avg_cost, crf_decode#, word_embed, bigru_output, emission
else:
crf_decode = self.crf_decoding(
input=emission,
length=length)
return crf_decode
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import time
import random
import argparse
import multiprocessing
import numpy as np
import paddle
import paddle.fluid as fluid
np.set_printoptions(threshold=np.inf)
import reader
import utils
from sequence_labeling import lex_net, Chunk_eval
#from eval import test_process
# the function to train model
def do_train(args):
dataset = reader.Dataset(args)
if args.use_cuda:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
#fluid.default_startup_program().random_seed = 102
#fluid.default_main_program().random_seed = 102
#np.random.seed(102)
#random.seed(102)
train_loader = reader.create_dataloader(
args,
file_name=args.train_data,
place=place,
model='lac',
reader=dataset)
if args.use_data_parallel:
train_loader = fluid.contrib.reader.distributed_batch_reader(
train_loader)
test_loader = reader.create_dataloader(
args,
file_name=args.test_data,
place=place,
model='lac',
reader=dataset,
mode='test')
model = lex_net(args, dataset.vocab_size, dataset.num_labels)
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=args.base_learning_rate,
parameter_list=model.parameters())
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
num_train_examples = dataset.get_num_examples(args.train_data)
max_train_steps = args.epoch * num_train_examples // args.batch_size
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
step = 0
print_start_time = time.time()
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
def test_process(reader, chunk_evaluator):
model.eval()
chunk_evaluator.reset()
start_time = time.time()
for batch in reader():
words, targets, length = batch
crf_decode = model(words, length=length)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = chunk_eval(
input=crf_decode,
label=targets,
seq_length=length)
chunk_evaluator.update(num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
precision, recall, f1 = chunk_evaluator.eval()
end_time = time.time()
print("[test] P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s" %
(precision, recall, f1, end_time - start_time))
model.train()
for epoch_id in range(args.epoch):
for batch in train_loader():
words, targets, length = batch
start_time = time.time()
avg_cost, crf_decode = model(words, targets, length)
if args.use_data_parallel:
avg_cost = model.scale_loss(avg_cost)
avg_cost.backward()
model.apply_collective_grads()
else:
avg_cost.backward()
optimizer.minimize(avg_cost)
model.clear_gradients()
end_time = time.time()
if step % args.print_steps == 0:
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = chunk_eval(
input=crf_decode,
label=targets,
seq_length=length)
outputs = [avg_cost, precision, recall, f1_score]
avg_cost, precision, recall, f1_score = [np.mean(x.numpy()) for x in outputs]
print("[train] step = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f" % (
step, avg_cost, precision, recall, f1_score, end_time - start_time))
if step % args.validation_steps == 0:
test_process(test_loader, chunk_evaluator)
# save checkpoints
if step % args.save_steps == 0 and step != 0:
save_path = os.path.join(args.model_save_dir, "step_" + str(step))
paddle.fluid.save_dygraph(model.state_dict(), save_path)
step += 1
if __name__ == "__main__":
# 参数控制可以根据需求使用argparse,yaml或者json
# 对NLP任务推荐使用PALM下定义的configure,可以统一argparse,yaml或者json格式的配置文件。
parser = argparse.ArgumentParser(__doc__)
utils.load_yaml(parser, 'conf/args.yaml')
args = parser.parse_args()
print(args)
do_train(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
util tools
"""
from __future__ import print_function
import os
import sys
import numpy as np
import paddle.fluid as fluid
import yaml
import io
def str2bool(v):
"""
argparse does not support True or False in python
"""
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
"""
Put arguments to one group
"""
def __init__(self, parser, title, des):
"""none"""
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
""" Add argument """
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def load_yaml(parser, file_name, **kwargs):
with io.open(file_name, 'r', encoding='utf8') as f:
args = yaml.load(f)
for title in args:
group = parser.add_argument_group(title=title, description='')
for name in args[title]:
_type = type(args[title][name]['val'])
_type = str2bool if _type == bool else _type
group.add_argument(
"--" + name,
default=args[title][name]['val'],
type=_type,
help=args[title][name]['meaning'] +
' Default: %(default)s.',
**kwargs)
def print_arguments(args):
"""none"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def to_str(string, encoding="utf-8"):
"""convert to str for print"""
if sys.version_info.major == 3:
if isinstance(string, bytes):
return string.decode(encoding)
elif sys.version_info.major == 2:
if isinstance(string, unicode):
if os.name == 'nt':
return string
else:
return string.encode(encoding)
return string
def parse_padding_result(words, crf_decode, seq_lens, dataset):
""" parse padding result """
# words = np.squeeze(words)
batch_size = len(seq_lens)
batch_out = []
for sent_index in range(batch_size):
sent = [
dataset.id2word_dict[str(id)]
for id in words[sent_index][1:seq_lens[sent_index] - 1]
]
tags = [
dataset.id2label_dict[str(id)]
for id in crf_decode[sent_index][1:seq_lens[sent_index] - 1]
]
sent_out = []
tags_out = []
parital_word = ""
for ind, tag in enumerate(tags):
# for the first word
if parital_word == "":
parital_word = sent[ind]
tags_out.append(tag.split('-')[0])
continue
# for the beginning of word
if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
sent_out.append(parital_word)
tags_out.append(tag.split('-')[0])
parital_word = sent[ind]
continue
parital_word += sent[ind]
# append the last word, except for len(tags)=0
if len(sent_out) < len(tags_out):
sent_out.append(parital_word)
batch_out.append([sent_out, tags_out])
return batch_out
def init_checkpoint(exe, init_checkpoint_path, main_program):
"""
Init CheckPoint
"""
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
"""
If existed presitabels
"""
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path))
def init_pretraining_params(exe,
pretraining_params_path,
main_program,
use_fp16=False):
"""load params of pretrained model, NOT including moment, learning_rate"""
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def _existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=_existed_params)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册