Commit dc8ecdf1 authored by SYSU_BOND, committed by bbking

Update lexical_analysis code for paddle 1.6 (#3823)

* update PaddleNLP lexical_analysis for Release/1.6 (#3664)

* update for paddle 1.6

* update optimize op in paddle 1.6

* fix ernie based in paddle 1.6

* fix coding for windows

* update downloads.py (#3672)

* Fix infer bug on Release/1.6  

* fix bug on ernie based inferring

* replace open with io.open to be compatible with windows (#3707)

* update README.md
Parent 482f33ff
@@ -45,9 +45,7 @@ def do_save_inference_model(args):
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_loader, probs, feed_target_names = create_model(
                args, num_labels=args.num_labels, is_prediction=True)
    test_prog = test_prog.clone(for_test=True)
    exe = fluid.Executor(place)
@@ -82,10 +80,10 @@ def test_inference_model(args, texts):
    assert (args.inference_model_dir)
    infer_program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname=args.inference_model_dir,
        executor=exe,
        model_filename="model.pdmodel",
        params_filename="params.pdparams")

    data = []
    seq_lens = []
    for query in texts:
@@ -97,13 +95,13 @@ def test_inference_model(args, texts):
    seq_lens = np.array(seq_lens)

    pred = exe.run(infer_program,
                   feed={feed_names[0]: data,
                         feed_names[1]: seq_lens},
                   fetch_list=fetch_targets,
                   return_numpy=True)
    for probs in pred[0]:
        print("%d\t%f\t%f\t%f" %
              (np.argmax(probs), probs[0], probs[1], probs[2]))


if __name__ == "__main__":
@@ -116,4 +114,3 @@ if __name__ == "__main__":
    else:
        texts = [u"我 讨厌 你 , 哼哼 哼 。 。", u"我 喜欢 你 , 爱 你 哟"]
    test_inference_model(args, texts)
@@ -16,7 +16,7 @@ Lexical Analysis of Chinese (LAC) is a joint lexical analysis model

#### 1. PaddlePaddle Installation

This project requires PaddlePaddle 1.6.0 or later and PaddleHub 1.0.0 or later. To install PaddlePaddle, see the official [Quick Install](http://www.paddlepaddle.org/paddle#quick-start) guide; to install PaddleHub, see [PaddleHub](https://github.com/PaddlePaddle/PaddleHub).

> Warning: the GPU and CPU builds of PaddlePaddle are released as separate packages, paddlepaddle-gpu and paddlepaddle; take care to install the right one.
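A quick way to confirm the requirement is met (a minimal sketch; it assumes only that PaddlePaddle is already installed):

```python
# Print the installed PaddlePaddle version; it should be 1.6.0 or newer.
import paddle

print(paddle.__version__)
```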
@@ -26,6 +26,10 @@ Lexical Analysis of Chinese (LAC) is a joint lexical analysis model
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/lexical_analysis
```

#### 3. Environment Dependencies

PaddlePaddle requires Python 2.7.15+ (for Python 2) or Python 3.5.1+/3.6/3.7 (for Python 3). The LAC code itself supports both Python 2 and 3, with no further version restriction.
### Data Preparation

#### 1. Quick Download

@@ -33,45 +37,37 @@ Lexical Analysis of Chinese (LAC) is a joint lexical analysis model

The **datasets** and **pretrained models** used in this project can all be fetched with the script below; if you only need part of them, follow the per-item instructions in the sections that follow.

```bash
python downloads.py all
```

Or, in an environment that can run shell scripts, execute:

```bash
sh downloads.sh
```
#### 2. Training Dataset

Download the dataset archive; extracting it creates the `./data/` directory:

```bash
python downloads.py dataset
```
#### 3. Pretrained Models

We have open-sourced lexical analysis models trained on our in-house dataset; they can be used directly and downloaded as follows:

```bash
# download baseline model
python downloads.py lac

# download ERNIE finetuned model
python downloads.py finetuned

# download ERNIE model for training
python downloads.py ernie
```

Note: to run ERNIE finetuning, first download the released [ERNIE](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz) model; the command `python downloads.py ernie` fetches it for you.
### Model Evaluation

We trained a lexical analysis model on our in-house dataset; it can be evaluated directly on the test set `./data/test.tsv`:

@@ -85,8 +81,9 @@ sh run_ernie.sh eval
### Model Training

With the sample dataset, the commands below train on the training set `./data/train.tsv`; the examples cover single-GPU, multi-GPU, and multi-threaded CPU configurations.

> Warning: to run ERNIE finetuning, first download the released [ERNIE](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz) model; the command `python downloads.py ernie` fetches it for you.

```bash
# baseline model, using single GPU
sh run.sh train_single_gpu
@@ -180,7 +177,7 @@ python inference_model.py \
1. Extract the sentences and labels from the raw data files and build the sentence and label sequences
2. Normalize the special characters in the sentence sequences
3. Look up each word's integer index in the dictionary (a minimal sketch of these steps follows)
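A minimal sketch of steps 2 and 3; the two-column file layout and the helper names (`load_kv`, `words_to_ids`) are assumptions for illustration, not the repository's exact API:

```python
# Illustrative preprocessing sketch; file formats and helper names are assumptions.
import io

def load_kv(path):
    """Load a two-column (key<TAB>value) file into a dict."""
    kv = {}
    for line in io.open(path, "r", encoding="utf-8"):
        cols = line.rstrip("\n").split("\t")
        if len(cols) == 2:
            kv[cols[0]] = cols[1]
    return kv

word_dict = load_kv("./conf/word.dic")  # word -> integer index (as a string)
q2b_dict = load_kv("./conf/q2b.dic")    # special character -> replacement

def words_to_ids(words, oov_id):
    # Step 2: normalize special characters; step 3: map words to indices.
    normalized = [q2b_dict.get(w, w) for w in words]
    return [int(word_dict.get(w, oov_id)) for w in normalized]
```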
### Code Structure

```text
.
@@ -189,6 +186,7 @@ python inference_model.py \
├── compare.py            # script comparing LAC against other open-source segmenters
├── creator.py            # script that builds the network and the data readers
├── data/                 # dataset directory
├── downloads.py          # script for downloading data and models
├── downloads.sh          # script for downloading data and models
├── eval.py               # lexical analysis evaluation script
├── inference_model.py    # script that saves an inference_model for deployment
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
evaluate wordseg for LAC and other open-source wordseg tools
"""
@@ -20,6 +20,7 @@ from __future__ import division
import sys
import os
import io


def to_unicode(string):
@@ -70,7 +71,7 @@ def load_testdata(datapath="./data/test_data/test_part"):
    """Load the test data: sentences and their gold segmentations."""
    sentences = []
    sent_seg_list = []
    for line in io.open(datapath, 'r', encoding='utf8'):
        sent, label = line.strip().split("\t")
        sentences.append(sent)
@@ -109,7 +110,7 @@ def get_lac_result():
    `sh run.sh | tail -n 100 > result.txt`
    """
    sent_seg_list = []
    for line in io.open("./result.txt", 'r', encoding='utf8'):
        line = line.strip().split(" ")
        words = [pair.split("/")[0] for pair in line]
        labels = [pair.split("/")[1] for pair in line]
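        # Example (assumed file content): a result.txt line of
        # u"百度/ORG 是/v 公司/n" parses into words == [u"百度", u"是", u"公司"]
        # and labels == [u"ORG", u"v", u"n"].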
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define the functions that create the lexical analysis model and its data readers
"""
import sys
import os
@@ -24,42 +24,51 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer

from reader import Dataset
from ernie_reader import SequenceLabelReader

sys.path.append("..")
from models.sequence_labeling import nets
from models.representation.ernie import ernie_encoder, ernie_pyreader
def create_model(args, vocab_size, num_labels, mode='train'):
    """create lac model"""

    # model's input data
    words = fluid.data(name='words', shape=[-1, 1], dtype='int64', lod_level=1)
    targets = fluid.data(
        name='targets', shape=[-1, 1], dtype='int64', lod_level=1)

    # for inference process
    if mode == 'infer':
        crf_decode = nets.lex_net(
            words, args, vocab_size, num_labels, for_infer=True, target=None)
        return {
            "feed_list": [words],
            "words": words,
            "crf_decode": crf_decode,
        }

    # for test or train process
    avg_cost, crf_decode = nets.lex_net(
        words, args, vocab_size, num_labels, for_infer=False, target=targets)

    (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
     num_correct_chunks) = fluid.layers.chunk_eval(
         input=crf_decode,
         label=targets,
         chunk_scheme="IOB",
         num_chunk_types=int(math.ceil((num_labels - 1) / 2.0)))
    chunk_evaluator = fluid.metrics.ChunkEvaluator()
    chunk_evaluator.reset()

    ret = {
        "feed_list": [words, targets],
        "words": words,
        "targets": targets,
        "avg_cost": avg_cost,
        "crf_decode": crf_decode,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "chunk_evaluator": chunk_evaluator,
@@ -67,86 +76,109 @@ def create_model(args, vocab_size, num_labels, mode='train'):
        "num_label_chunks": num_label_chunks,
        "num_correct_chunks": num_correct_chunks
    }
    return ret
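# Illustrative wiring of the two factories below (names assumed; this mirrors
# how eval.py consumes them, typically inside a fluid.program_guard block):
#
#   test_ret = create_model(args, dataset.vocab_size, dataset.num_labels,
#                           mode='test')
#   pyreader = create_pyreader(args, file_name=args.test_data,
#                              feed_list=test_ret['feed_list'], place=place,
#                              model='lac', reader=dataset, mode='test')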
def create_pyreader(args,
                    file_name,
                    feed_list,
                    place,
                    model='lac',
                    reader=None,
                    return_reader=False,
                    mode='train'):
    # init reader
    if model == 'lac':
        pyreader = fluid.io.PyReader(
            feed_list=feed_list,
            capacity=50,
            use_double_buffer=True,
            iterable=True)

        if reader is None:
            reader = Dataset(args)

        # create lac pyreader
        if mode == 'train':
            pyreader.decorate_sample_list_generator(
                fluid.io.batch(
                    fluid.io.shuffle(
                        reader.file_reader(file_name),
                        buf_size=args.traindata_shuffle_buffer),
                    batch_size=args.batch_size),
                places=place)
        else:
            pyreader.decorate_sample_list_generator(
                fluid.io.batch(
                    reader.file_reader(
                        file_name, mode=mode),
                    batch_size=args.batch_size),
                places=place)

    elif model == 'ernie':
        # create ernie pyreader
        pyreader = fluid.io.DataLoader.from_generator(
            feed_list=feed_list,
            capacity=50,
            use_double_buffer=True,
            iterable=True)

        if reader is None:
            reader = SequenceLabelReader(
                vocab_path=args.vocab_path,
                label_map_config=args.label_map_config,
                max_seq_len=args.max_seq_len,
                do_lower_case=args.do_lower_case,
                in_tokens=False,
                random_seed=args.random_seed)

        if mode == 'train':
            pyreader.set_batch_generator(
                reader.data_generator(
                    file_name,
                    args.batch_size,
                    args.epoch,
                    shuffle=True,
                    phase="train"),
                places=place)
        else:
            pyreader.set_batch_generator(
                reader.data_generator(
                    file_name,
                    args.batch_size,
                    epoch=1,
                    shuffle=False,
                    phase=mode),
                places=place)

    if return_reader:
        return pyreader, reader
    else:
        return pyreader
def create_ernie_model(args, ernie_config):
    """
    Create Model for LAC based on ERNIE encoder
    """

    # ERNIE's input data
    src_ids = fluid.data(
        name='src_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    sent_ids = fluid.data(
        name='sent_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    pos_ids = fluid.data(
        name='pos_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    input_mask = fluid.data(
        name='input_mask', shape=[-1, args.max_seq_len, 1], dtype='float32')
    padded_labels = fluid.data(
        name='padded_labels', shape=[-1, args.max_seq_len, 1], dtype='int64')
    seq_lens = fluid.data(
        name='seq_lens', shape=[-1], dtype='int64', lod_level=0)

    squeeze_labels = fluid.layers.squeeze(padded_labels, axes=[-1])

    # ernie_pyreader
    ernie_inputs = {
        "src_ids": src_ids,
        "sent_ids": sent_ids,
@@ -156,53 +188,56 @@ def create_ernie_model(args, ernie_config):
    }
    embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)

    padded_token_embeddings = embeddings["padded_token_embeddings"]

    emission = fluid.layers.fc(
        size=args.num_labels,
        input=padded_token_embeddings,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
                low=-args.init_bound, high=args.init_bound),
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=1e-4)),
        num_flatten_dims=2)

    crf_cost = fluid.layers.linear_chain_crf(
        input=emission,
        label=padded_labels,
        param_attr=fluid.ParamAttr(
            name='crfw', learning_rate=args.crf_learning_rate),
        length=seq_lens)
    avg_cost = fluid.layers.mean(x=crf_cost)
    crf_decode = fluid.layers.crf_decoding(
        input=emission,
        param_attr=fluid.ParamAttr(name='crfw'),
        length=seq_lens)

    (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
     num_correct_chunks) = fluid.layers.chunk_eval(
         input=crf_decode,
         label=squeeze_labels,
         chunk_scheme="IOB",
         num_chunk_types=int(math.ceil((args.num_labels - 1) / 2.0)),
         seq_length=seq_lens)
    chunk_evaluator = fluid.metrics.ChunkEvaluator()
    chunk_evaluator.reset()

    ret = {
        "feed_list":
        [src_ids, sent_ids, pos_ids, input_mask, padded_labels, seq_lens],
        "words": src_ids,
        "labels": padded_labels,
        "seq_lens": seq_lens,
        "avg_cost": avg_cost,
        "crf_decode": crf_decode,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "chunk_evaluator": chunk_evaluator,
        "num_infer_chunks": num_infer_chunks,
        "num_label_chunks": num_label_chunks,
        "num_correct_chunks": num_correct_chunks
    }
    return ret
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
FILE_INFO = {
'BASE_URL': 'https://baidu-nlp.bj.bcebos.com/',
'DATA': {
'name': 'lexical_analysis-dataset-2.0.0.tar.gz',
'md5': '71e4a9a36d0f0177929a1bccedca7dba'
},
'LAC_MODEL': {
'name': 'lexical_analysis-2.0.0.tar.gz',
'md5': "fc1daef00de9564083c7dc7b600504ca"
},
'ERNIE_MODEL': {
'name': 'ERNIE_stable-1.0.1.tar.gz',
'md5': "bab876a874b5374a78d7af93384d3bfa"
},
'FINETUNED_MODEL': {
'name': 'lexical_analysis_finetuned-1.0.0.tar.gz',
'md5': "ee2c7614b06dcfd89561fbbdaac34342"
}
}
def usage():
desc = ("\nDownload datasets and pretrained models for LAC.\n"
"Usage:\n"
" 1. python download.py all\n"
" 2. python download.py dataset\n"
" 3. python download.py lac\n"
" 4. python download.py finetuned\n"
" 5. python download.py ernie\n")
print(desc)
def md5file(fname):
hash_md5 = hashlib.md5()
with io.open(fname, "rb") as fin:
for chunk in iter(lambda: fin.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def _download(url, filename, md5sum):
"""
Download file and check md5
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError(
"Cannot download dataset ({0}) with retry {1} times.".format(
url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print(
'\r[Process ]: %s%.2f%%' %
(pr, float(size / filesize * 100)),
end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download(name, dir_path):
url = FILE_INFO['BASE_URL'] + FILE_INFO[name]['name']
file_path = os.path.join(dir_path, FILE_INFO[name]['name'])
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download data
print("Downloading : %s" % name)
_download(url, file_path, FILE_INFO[name]['md5'])
# extract data
print("Extracting : %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
sys.exit(1)
pwd = os.path.join(os.path.dirname(__file__), './')
ernie_dir = os.path.join(os.path.dirname(__file__), './pretrained')
if sys.argv[1] == 'all':
download('DATA', pwd)
download('LAC_MODEL', pwd)
download('FINETUNED_MODEL', pwd)
download('ERNIE_MODEL', ernie_dir)
elif sys.argv[1] == "dataset":
download('DATA', pwd)
elif sys.argv[1] == "lac":
download('LAC_MODEL', pwd)
elif sys.argv[1] == "finetuned":
download('FINETUNED_MODEL', pwd)
elif sys.argv[1] == "ernie":
download('ERNIE_MODEL', ernie_dir)
else:
usage()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides reader for ernie model
"""
import sys
from collections import namedtuple
import numpy as np
sys.path.append("..")
from preprocess.ernie.task_reader import BaseReader, tokenization
def pad_batch_data(insts,
pad_idx=0,
max_len=128,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False,
return_seq_lens=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
# Pad to the fixed `max_len` argument rather than the batch-wise maximum,
# i.e. not: max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
if return_seq_lens:
seq_lens = np.array([len(inst) for inst in insts])
return_list += [seq_lens.astype("int64").reshape([-1])]
return return_list if len(return_list) > 1 else return_list[0]
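# Illustrative usage (values assumed): pad two token-id lists to max_len=4
# and also request the input mask and the sequence lengths.
#
#   padded, mask, lens = pad_batch_data(
#       [[5, 6, 7], [8, 9]], pad_idx=0, max_len=4,
#       return_input_mask=True, return_seq_lens=True)
#   # padded.shape == (2, 4, 1); mask.shape == (2, 4, 1); lens.shape == (2,)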
class SequenceLabelReader(BaseReader):
"""SequenceLabelReader"""
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
# padding
padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
batch_token_ids,
max_len=self.max_seq_len,
pad_idx=self.pad_id,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids, max_len=self.max_seq_len, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids, max_len=self.max_seq_len, pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids,
max_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, padded_label_ids, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, labels, tokenizer):
assert len(tokens) == len(labels)
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
assert len(ret_tokens) == len(ret_labels)
return ret_tokens, ret_labels
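    # Example (assumed tokenizer behavior): tokens [u"北京", u"欢迎"] with labels
    # ["B-LOC", "O"], where the tokenizer splits u"北京" into [u"北", u"京"],
    # are resegmented into ([u"北", u"京", u"欢迎"], ["B-LOC", "I-LOC", "O"]).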
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
        tokens = tokenization.convert_to_unicode(example.text_a).split(u"\002")
        labels = tokenization.convert_to_unicode(example.label).split(u"\002")
tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
labels = [
label if label in self.label_map else u"O" for label in labels
]
label_ids = [no_entity_id] + [
self.label_map[label] for label in labels
] + [no_entity_id]
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
return record
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
@@ -25,24 +26,36 @@ import reader
import creator
sys.path.append('../models/')
from model_check import check_cuda
from model_check import check_version

parser = argparse.ArgumentParser(__doc__)
# 1. model parameters
model_g = utils.ArgumentGroup(parser, "model", "model configuration")
model_g.add_arg("word_emb_dim", int, 128,
                "The dimension in which a word is embedded.")
model_g.add_arg("grnn_hidden_dim", int, 128,
                "The number of hidden nodes in the GRNN layer.")
model_g.add_arg("bigru_num", int, 2,
                "The number of bi_gru layers in the network.")
model_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")

# 2. data parameters
data_g = utils.ArgumentGroup(parser, "data", "data paths")
data_g.add_arg("word_dict_path", str, "./conf/word.dic",
               "The path of the word dictionary.")
data_g.add_arg("label_dict_path", str, "./conf/tag.dic",
               "The path of the label dictionary.")
data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic",
               "The path of the word replacement dictionary.")
data_g.add_arg("test_data", str, "./data/test.tsv",
               "The path of the test data file.")
data_g.add_arg("init_checkpoint", str, "./model_baseline", "Path to init model")
data_g.add_arg(
    "batch_size", int, 200,
    "The number of sequences contained in a mini-batch, "
    "or the maximum number of tokens (including paddings) contained in a mini-batch."
)
def do_eval(args):
    dataset = reader.Dataset(args)
@@ -60,23 +73,23 @@ def do_eval(args):
    else:
        place = fluid.CPUPlace()

    pyreader = creator.create_pyreader(
        args,
        file_name=args.test_data,
        feed_list=test_ret['feed_list'],
        place=place,
        model='lac',
        reader=dataset,
        mode='test')

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load model
    utils.init_checkpoint(exe, args.init_checkpoint, test_program)
    test_process(
        exe=exe, program=test_program, reader=pyreader, test_ret=test_ret)
def test_process(exe, program, reader, test_ret):
    """
@@ -91,22 +104,24 @@ def test_process(exe, program, reader, test_ret):
    start_time = time.time()
    for data in reader():
        nums_infer, nums_label, nums_correct = exe.run(
            program,
            fetch_list=[
                test_ret["num_infer_chunks"],
                test_ret["num_label_chunks"],
                test_ret["num_correct_chunks"],
            ],
            feed=data)

        test_ret["chunk_evaluator"].update(nums_infer, nums_label, nums_correct)

    precision, recall, f1 = test_ret["chunk_evaluator"].eval()
    end_time = time.time()
    print("[test] P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s" %
          (precision, recall, f1, end_time - start_time))
if __name__ == '__main__':
    args = parser.parse_args()
    check_cuda(args.use_cuda)
    check_version()
    do_eval(args)
@@ -12,6 +12,8 @@ import reader
import utils
sys.path.append('../models/')
from model_check import check_cuda
from model_check import check_version


def save_inference_model(args):
@@ -29,20 +31,19 @@ def save_inference_model(args):
        args, dataset.vocab_size, dataset.num_labels, mode='infer')
    infer_program = infer_program.clone(for_test=True)

    # load the pretrained checkpoint
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    utils.init_checkpoint(exe, args.init_checkpoint, infer_program)

    fluid.io.save_inference_model(
        args.inference_save_dir,
        ['words'],
        infer_ret['crf_decode'],
        exe,
        main_program=infer_program,
        model_filename='model.pdmodel',
        params_filename='params.pdparams')
def test_inference_model(model_dir, text_list, dataset):
@@ -67,44 +68,46 @@ def test_inference_model(model_dir, text_list, dataset):
    tensor_words = fluid.create_lod_tensor(lod, base_shape, place)

    # for empty input, output the same empty result
    if (sum(base_shape[0]) == 0):
        crf_decode = [tensor_words]
    else:
        # load inference model
        inference_scope = fluid.core.Scope()
        with fluid.scope_guard(inference_scope):
            [inferencer, feed_target_names,
             fetch_targets] = fluid.io.load_inference_model(
                 model_dir,
                 exe,
                 model_filename='model.pdmodel',
                 params_filename='params.pdparams')
            assert feed_target_names[0] == "words"
            print("Load inference model from %s" % (model_dir))

            # get lac result
            crf_decode = exe.run(
                inferencer,
                feed={feed_target_names[0]: tensor_words},
                fetch_list=fetch_targets,
                return_numpy=False,
                use_program_cache=True)

    # parse the crf_decode result
    result = utils.parse_result(tensor_words, crf_decode[0], dataset)
    for i, (sent, tags) in enumerate(result):
        result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
        print(''.join(result_list))
if __name__=="__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
utils.load_yaml(parser,'conf/args.yaml') utils.load_yaml(parser, 'conf/args.yaml')
args = parser.parse_args() args = parser.parse_args()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
check_version()
print("save inference model") print("save inference model")
save_inference_model(args) save_inference_model(args)
print("inference model save in %s"%args.inference_save_dir) print("inference model save in %s" % args.inference_save_dir)
print("test inference model") print("test inference model")
dataset = reader.Dataset(args) dataset = reader.Dataset(args)
test_data = [u'百度是一家高科技公司', u'中山大学是岭南第一学府'] test_data = [u'百度是一家高科技公司', u'中山大学是岭南第一学府']
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
@@ -25,24 +26,36 @@ import reader
import creator
sys.path.append('../models/')
from model_check import check_cuda
from model_check import check_version

parser = argparse.ArgumentParser(__doc__)
# 1. model parameters
model_g = utils.ArgumentGroup(parser, "model", "model configuration")
model_g.add_arg("word_emb_dim", int, 128,
                "The dimension in which a word is embedded.")
model_g.add_arg("grnn_hidden_dim", int, 256,
                "The number of hidden nodes in the GRNN layer.")
model_g.add_arg("bigru_num", int, 2,
                "The number of bi_gru layers in the network.")
model_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")

# 2. data parameters
data_g = utils.ArgumentGroup(parser, "data", "data paths")
data_g.add_arg("word_dict_path", str, "./conf/word.dic",
               "The path of the word dictionary.")
data_g.add_arg("label_dict_path", str, "./conf/tag.dic",
               "The path of the label dictionary.")
data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic",
               "The path of the word replacement dictionary.")
data_g.add_arg("infer_data", str, "./data/infer.tsv",
               "The path of the inference data file.")
data_g.add_arg("init_checkpoint", str, "./model_baseline", "Path to init model")
data_g.add_arg(
    "batch_size", int, 200,
    "The number of sequences contained in a mini-batch, "
    "or the maximum number of tokens (including paddings) contained in a mini-batch."
)
def do_infer(args):
    dataset = reader.Dataset(args)
@@ -60,14 +73,14 @@ def do_infer(args):
    else:
        place = fluid.CPUPlace()

    pyreader = creator.create_pyreader(
        args,
        file_name=args.infer_data,
        feed_list=infer_ret['feed_list'],
        place=place,
        model='lac',
        reader=dataset,
        mode='infer')
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
@@ -80,8 +93,7 @@ def do_infer(args):
        program=infer_program,
        reader=pyreader,
        fetch_vars=[infer_ret['words'], infer_ret['crf_decode']],
        dataset=dataset)

    for sent, tags in result:
        result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
        print(''.join(result_list))
@@ -95,8 +107,9 @@ def infer_process(exe, program, reader, fetch_vars, dataset):
    :param reader: data reader
    :return: the list of prediction result
    """

    def input_check(data):
        if data[0]['words'].lod()[0][-1] == 0:
            return data[0]['words']
        return None
@@ -107,17 +120,18 @@ def infer_process(exe, program, reader, fetch_vars, dataset):
            results += utils.parse_result(crf_decode, crf_decode, dataset)
            continue

        words, crf_decode = exe.run(
            program,
            fetch_list=fetch_vars,
            feed=data,
            return_numpy=False,
            use_program_cache=True)
        results += utils.parse_result(words, crf_decode, dataset)
    return results
if __name__=="__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
check_version()
do_infer(args) do_infer(args)
@@ -14,6 +14,7 @@
"""
The file_reader converts raw corpus to input.
"""

import os
import argparse
import __future__
@@ -73,7 +74,7 @@ class Dataset(object):
    def get_num_examples(self, filename):
        """Return the number of lines in the file."""
        return sum(1 for line in io.open(filename, "r", encoding='utf8'))

    def word_to_ids(self, words):
        """convert word to word index"""
@@ -107,16 +108,17 @@ class Dataset(object):
        fread = io.open(filename, "r", encoding="utf-8")
        if mode == "infer":
            for line in fread:
                words = line.strip()
                word_ids = self.word_to_ids(words)
                yield (word_ids[0:max_seq_len], )
        else:
            headline = next(fread)
            headline = headline.strip().split('\t')
            assert len(headline) == 2 and headline[
                0] == "text_a" and headline[1] == "label"
            for line in fread:
                words, labels = line.strip("\n").split("\t")
                if len(words) < 1:
                    continue
                word_ids = self.word_to_ids(words.split("\002"))
                label_ids = self.label_to_ids(labels.split("\002"))
......
@@ -37,6 +37,8 @@ import utils
sys.path.append("..")
from models.representation.ernie import ErnieConfig
from models.model_check import check_cuda
from models.model_check import check_version


def evaluate(exe, test_program, test_pyreader, test_ret):
    """
@@ -54,8 +56,7 @@ def evaluate(exe, test_program, test_pyreader, test_ret):
                test_ret["num_label_chunks"],
                test_ret["num_correct_chunks"],
            ],
            feed=data[0])
        total_loss.append(loss)

        test_ret["chunk_evaluator"].update(nums_infer, nums_label, nums_correct)
@@ -63,9 +64,11 @@ def evaluate(exe, test_program, test_pyreader, test_ret):
    precision, recall, f1 = test_ret["chunk_evaluator"].eval()
    end_time = time.time()
    print(
        "\t[test] loss: %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s"
        % (np.mean(total_loss), precision, recall, f1, end_time - start_time))
def do_train(args):
    """
    Main Function
@@ -79,14 +82,15 @@ def do_train(args):
    else:
        dev_count = min(multiprocessing.cpu_count(), args.cpu_num)
        if (dev_count < args.cpu_num):
            print(
                "WARNING: The total CPU NUM in this machine is %d, which is less than the cpu_num parameter you set. "
                "Change the cpu_num from %d to %d" %
                (dev_count, args.cpu_num, dev_count))
        os.environ['CPU_NUM'] = str(dev_count)
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)

    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed
@@ -98,49 +102,56 @@ def do_train(args):
            train_ret = creator.create_ernie_model(args, ernie_config)

    # ernie pyreader
    train_pyreader = creator.create_pyreader(
        args,
        file_name=args.train_data,
        feed_list=train_ret['feed_list'],
        model="ernie",
        place=place)

    test_program = train_program.clone(for_test=True)
    test_pyreader = creator.create_pyreader(
        args,
        file_name=args.test_data,
        feed_list=train_ret['feed_list'],
        model="ernie",
        place=place)

    optimizer = fluid.optimizer.Adam(
        learning_rate=args.base_learning_rate)
    fluid.clip.set_gradient_clip(
        clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
    optimizer.minimize(train_ret["avg_cost"])

    lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
        program=train_program, batch_size=args.batch_size)
    print("Theoretical memory usage in training: %.3f - %.3f %s" %
          (lower_mem, upper_mem, unit))
    print("Device count: %d" % dev_count)

    exe.run(startup_prog)

    # load checkpoints
    if args.init_checkpoint and args.init_pretraining_params:
        print("WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
              "both are set! Only arg 'init_checkpoint' is made valid.")
    if args.init_checkpoint:
        utils.init_checkpoint(exe, args.init_checkpoint, startup_prog)
    elif args.init_pretraining_params:
        utils.init_pretraining_params(exe, args.init_pretraining_params,
                                      startup_prog)

    if dev_count > 1 and not args.use_cuda:
        device = "GPU" if args.use_cuda else "CPU"
        print("%d %s are used to train model" % (dev_count, device))

        # multi cpu/gpu config
        exec_strategy = fluid.ExecutionStrategy()
        build_strategy = fluid.BuildStrategy()
        compiled_prog = fluid.compiler.CompiledProgram(
            train_program).with_data_parallel(
                loss_name=train_ret['avg_cost'].name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
    else:
        compiled_prog = fluid.compiler.CompiledProgram(train_program)
@@ -160,16 +171,24 @@ def do_train(args):
            fetch_list = []

        start_time = time.time()
        outputs = exe.run(program=compiled_prog,
                          feed=data[0],
                          fetch_list=fetch_list)
        end_time = time.time()

        if steps % args.print_steps == 0:
            loss, precision, recall, f1_score = [
                np.mean(x) for x in outputs
            ]
            print(
                "[train] batch_id = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f, "
                "pyreader queue_size: %d " %
                (steps, loss, precision, recall, f1_score,
                 end_time - start_time, train_pyreader.queue.size()))

        if steps % args.save_steps == 0:
            save_path = os.path.join(args.model_save_dir,
                                     "step_" + str(steps))
            print("\tsaving model as %s" % (save_path))
            fluid.io.save_persistables(exe, save_path, train_program)
@@ -180,7 +199,6 @@ def do_train(args):
    fluid.io.save_persistables(exe, save_path, train_program)


def do_eval(args):
    # init executor
    if args.use_cuda:
@@ -196,11 +214,13 @@ def do_eval(args):
        test_ret = creator.create_ernie_model(args, ernie_config)
    test_program = test_program.clone(for_test=True)

    pyreader = creator.create_pyreader(
        args,
        file_name=args.test_data,
        feed_list=test_ret['feed_list'],
        model="ernie",
        place=place,
        mode='test')

    print('program startup')
@@ -210,11 +230,13 @@ def do_eval(args):
    print('program loading')
    # load model
    if not args.init_checkpoint:
        raise ValueError(
            "args 'init_checkpoint' should be set if only doing test or infer!")
    utils.init_checkpoint(exe, args.init_checkpoint, test_program)

    evaluate(exe, test_program, pyreader, test_ret)
def do_infer(args):
    # init executor
    if args.use_cuda:
@@ -231,46 +253,58 @@ def do_infer(args):
            infer_ret = creator.create_ernie_model(args, ernie_config)
    infer_program = infer_program.clone(for_test=True)
    print(args.test_data)
    pyreader, reader = creator.create_pyreader(
        args,
        file_name=args.test_data,
        feed_list=infer_ret['feed_list'],
        model="ernie",
        place=place,
        return_reader=True,
        mode='test')

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load model
    if not args.init_checkpoint:
        raise ValueError(
            "args 'init_checkpoint' should be set if only doing test or infer!")
    utils.init_checkpoint(exe, args.init_checkpoint, infer_program)

    # create dict
    id2word_dict = dict(
        [(str(word_id), word) for word, word_id in reader.vocab.items()])
    id2label_dict = dict([(str(label_id), label)
                          for label, label_id in reader.label_map.items()])
    Dataset = namedtuple("Dataset", ["id2word_dict", "id2label_dict"])
    dataset = Dataset(id2word_dict, id2label_dict)

    # make prediction
    for data in pyreader():
        (words, crf_decode, seq_lens) = exe.run(infer_program,
                                                fetch_list=[
                                                    infer_ret["words"],
                                                    infer_ret["crf_decode"],
                                                    infer_ret["seq_lens"]
                                                ],
                                                feed=data[0],
                                                return_numpy=True)
        # Note that words have been clipped if longer than args.max_seq_len.
        results = utils.parse_padding_result(words, crf_decode, seq_lens,
                                             dataset)
        for sent, tags in results:
            result_list = [
                '(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)
            ]
            print(''.join(result_list))
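Because the ERNIE path now pads every batch to a fixed length, the fetched tensors are dense arrays and can come back as numpy directly (return_numpy=True); the previous LoD-based code fetched with return_numpy=False and walked offsets via utils.parse_result instead. If a flatter output format is wanted, a small hypothetical helper (not part of the repo) could render each result as word/TAG pairs:

def format_result(sent, tags):
    # sent: list of segmented words; tags: coarse tags (prefix before '-')
    return " ".join("%s/%s" % (word, tag) for word, tag in zip(sent, tags))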
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
utils.load_yaml(parser, './conf/ernie_args.yaml') utils.load_yaml(parser, './conf/ernie_args.yaml')
args = parser.parse_args() args = parser.parse_args()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
check_version()
utils.print_arguments(args) utils.print_arguments(args)
if args.mode == 'train': if args.mode == 'train':
@@ -281,4 +315,3 @@ if __name__ == "__main__":
        do_infer(args)
    else:
        print("Usage: %s --mode train|eval|infer " % sys.argv[0])
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
@@ -28,9 +28,11 @@ import paddle.fluid as fluid
import reader
import utils
import creator
from eval import test_process

sys.path.append('../models/')
from model_check import check_cuda
from model_check import check_version
# the function to train model
def do_train(args):
@@ -47,10 +49,10 @@ def do_train(args):
                args, dataset.vocab_size, dataset.num_labels, mode='train')
            test_program = train_program.clone(for_test=True)

            optimizer = fluid.optimizer.Adam(
                learning_rate=args.base_learning_rate)
            optimizer.minimize(train_ret["avg_cost"])
    # init executor
    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
@@ -58,43 +60,48 @@ def do_train(args):
    else:
        dev_count = min(multiprocessing.cpu_count(), args.cpu_num)
        if (dev_count < args.cpu_num):
            print(
                "WARNING: The total CPU NUM in this machine is %d, which is less than cpu_num parameter you set. "
                "Change the cpu_num from %d to %d" %
                (dev_count, args.cpu_num, dev_count))
        os.environ['CPU_NUM'] = str(dev_count)
        place = fluid.CPUPlace()

    train_reader = creator.create_pyreader(
        args,
        file_name=args.train_data,
        feed_list=train_ret['feed_list'],
        place=place,
        model='lac',
        reader=dataset)

    test_reader = creator.create_pyreader(
        args,
        file_name=args.test_data,
        feed_list=train_ret['feed_list'],
        place=place,
        model='lac',
        reader=dataset,
        mode='test')

    exe = fluid.Executor(place)
    exe.run(startup_program)

    if args.init_checkpoint:
        utils.init_checkpoint(exe, args.init_checkpoint, train_program)
    if dev_count > 1:
        device = "GPU" if args.use_cuda else "CPU"
        print("%d %s are used to train model" % (dev_count, device))

        # multi cpu/gpu config
        exec_strategy = fluid.ExecutionStrategy()
        build_strategy = fluid.compiler.BuildStrategy()
        compiled_prog = fluid.compiler.CompiledProgram(
            train_program).with_data_parallel(
                loss_name=train_ret['avg_cost'].name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
    else:
        compiled_prog = fluid.compiler.CompiledProgram(train_program)
@@ -112,10 +119,8 @@ def do_train(args):
            # this is for minimizing the fetching op, saving the training speed.
            if step % args.print_steps == 0:
                fetch_list = [
                    train_ret["avg_cost"], train_ret["precision"],
                    train_ret["recall"], train_ret["f1_score"]
                ]
            else:
                fetch_list = []
@@ -124,15 +129,18 @@ def do_train(args):
            outputs = exe.run(
                compiled_prog,
                fetch_list=fetch_list,
                feed=data[0], )
            end_time = time.time()
            if step % args.print_steps == 0:
                avg_cost, precision, recall, f1_score = [
                    np.mean(x) for x in outputs
                ]
                print(
                    "[train] step = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f"
                    % (step, avg_cost, precision, recall, f1_score,
                       end_time - start_time))
            if step % args.validation_steps == 0:
                test_process(exe, test_program, test_reader, train_ret)
@@ -142,11 +150,10 @@ def do_train(args):
            # save checkpoints
            if step % args.save_steps == 0 and step != 0:
                save_path = os.path.join(args.model_save_dir,
                                         "step_" + str(step))
                fluid.io.save_persistables(exe, save_path, train_program)
            step += 1

    if args.enable_ce:
        card_num = get_cards()
@@ -163,16 +170,11 @@ def do_train(args):
            ce_f1 = ce_info[-2][4]
        except:
            print("ce info error")
        print("kpis\teach_step_duration_card%s\t%s" % (card_num, ce_time))
        print("kpis\ttrain_cost_card%s\t%f" % (card_num, ce_cost))
        print("kpis\ttrain_precision_card%s\t%f" % (card_num, ce_p))
        print("kpis\ttrain_recall_card%s\t%f" % (card_num, ce_r))
        print("kpis\ttrain_f1_card%s\t%f" % (card_num, ce_f1))
def get_cards():
@@ -182,17 +184,18 @@ def get_cards():
num = len(cards.split(",")) num = len(cards.split(","))
return num return num
if __name__ == "__main__": if __name__ == "__main__":
# 参数控制可以根据需求使用argparse,yaml或者json # 参数控制可以根据需求使用argparse,yaml或者json
# 对NLP任务推荐使用PALM下定义的configure,可以统一argparse,yaml或者json格式的配置文件。 # 对NLP任务推荐使用PALM下定义的configure,可以统一argparse,yaml或者json格式的配置文件。
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
utils.load_yaml(parser,'conf/args.yaml') utils.load_yaml(parser, 'conf/args.yaml')
args = parser.parse_args() args = parser.parse_args()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
check_version()
print(args) print(args)
do_train(args) do_train(args)
@@ -20,6 +20,7 @@ import sys
import numpy as np
import paddle.fluid as fluid
import yaml
import io
def str2bool(v):
@@ -48,19 +49,21 @@ class ArgumentGroup(object):
            help=help + ' Default: %(default)s.',
            **kwargs)


def load_yaml(parser, file_name, **kwargs):
    with io.open(file_name, 'r', encoding='utf8') as f:
        args = yaml.load(f)
        for title in args:
            group = parser.add_argument_group(title=title, description='')
            for name in args[title]:
                _type = type(args[title][name]['val'])
                _type = str2bool if _type == bool else _type
                group.add_argument(
                    "--" + name,
                    default=args[title][name]['val'],
                    type=_type,
                    help=args[title][name]['meaning'] +
                    ' Default: %(default)s.',
                    **kwargs)
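For reference, a self-contained sketch of the YAML layout load_yaml expects: each top-level key becomes an argument group, and each entry needs a 'val' (the default, which also fixes the argparse type) and a 'meaning' (the help text). The group and option names below are illustrative, not the repo's actual conf/args.yaml:

import argparse
import io
from utils import load_yaml  # assumes this file is importable as utils

demo_yaml = u"""
TRAIN:
    use_cuda:
        val: False
        meaning: "Whether to run on GPU."
    base_learning_rate:
        val: 0.001
        meaning: "Initial learning rate."
"""
with io.open('demo_args.yaml', 'w', encoding='utf8') as f:
    f.write(demo_yaml)

parser = argparse.ArgumentParser()
load_yaml(parser, 'demo_args.yaml')
args = parser.parse_args([])
print(args.use_cuda, args.base_learning_rate)  # False 0.001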
@@ -115,7 +118,53 @@ def parse_result(words, crf_decode, dataset):
    for sent_index in range(batch_size):
        begin, end = offset_list[sent_index], offset_list[sent_index + 1]
        sent = [dataset.id2word_dict[str(id[0])] for id in words[begin:end]]
        tags = [
            dataset.id2label_dict[str(id[0])] for id in crf_decode[begin:end]
        ]

        sent_out = []
        tags_out = []
        parital_word = ""
        for ind, tag in enumerate(tags):
            # for the first word
            if parital_word == "":
                parital_word = sent[ind]
                tags_out.append(tag.split('-')[0])
                continue

            # for the beginning of word
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                sent_out.append(parital_word)
                tags_out.append(tag.split('-')[0])
                parital_word = sent[ind]
                continue

            parital_word += sent[ind]

        # append the last word, except for len(tags)=0
        if len(sent_out) < len(tags_out):
            sent_out.append(parital_word)

        batch_out.append([sent_out, tags_out])
    return batch_out
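The merge loop implements one simple rule: the partial word is flushed whenever a "-B" tag opens a new word, or an "O" tag follows a non-"O" tag; every other tag extends the current word. A toy trace (the tags are hypothetical, not tied to any real label map):

sent = [u"张", u"三", u"说"]
tags = ["PER-B", "PER-I", "O"]
# walking the loop above: "张" opens a PER word, "三" extends it,
# "O" flushes it and opens a new word, which the final length check appends
# expected: sent_out == [u"张三", u"说"], tags_out == ["PER", "O"]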
def parse_padding_result(words, crf_decode, seq_lens, dataset):
    """ parse padding result """
    words = np.squeeze(words)
    batch_size = len(seq_lens)

    batch_out = []
    for sent_index in range(batch_size):

        sent = [
            dataset.id2word_dict[str(id)]
            for id in words[sent_index][1:seq_lens[sent_index] - 1]
        ]
        tags = [
            dataset.id2label_dict[str(id)]
            for id in crf_decode[sent_index][1:seq_lens[sent_index] - 1]
        ]
        sent_out = []
        tags_out = []
@@ -128,7 +177,7 @@ def parse_result(words, crf_decode, dataset):
                continue

            # for the beginning of word
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                sent_out.append(parital_word)
                tags_out.append(tag.split('-')[0])
                parital_word = sent[ind]
@@ -137,12 +186,13 @@ def parse_result(words, crf_decode, dataset):
            parital_word += sent[ind]

        # append the last word, except for len(tags)=0
        if len(sent_out) < len(tags_out):
            sent_out.append(parital_word)

        batch_out.append([sent_out, tags_out])
    return batch_out
def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    Init CheckPoint
@@ -165,6 +215,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
        predicate=existed_persitables)
    print("Load model from {}".format(init_checkpoint_path))


def init_pretraining_params(exe,
                            pretraining_params_path,
                            main_program,
...
@@ -21,7 +21,8 @@ import math
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer


def lex_net(word, args, vocab_size, num_labels, for_infer=True, target=None):
""" """
define the lexical analysis network structure define the lexical analysis network structure
word: stores the input of the model word: stores the input of the model
@@ -85,7 +86,7 @@ def lex_net(word, args, vocab_size, num_labels, for_infer=True, target=None):
""" """
Configure the network Configure the network
""" """
word_embedding = fluid.layers.embedding( word_embedding = fluid.embedding(
input=word, input=word,
size=[vocab_size, word_emb_dim], size=[vocab_size, word_emb_dim],
dtype='float32', dtype='float32',
@@ -115,18 +116,16 @@ def lex_net(word, args, vocab_size, num_labels, for_infer=True, target=None):
            input=emission,
            label=target,
            param_attr=fluid.ParamAttr(
                name='crfw', learning_rate=crf_lr))
        avg_cost = fluid.layers.mean(x=crf_cost)
        crf_decode = fluid.layers.crf_decoding(
            input=emission, param_attr=fluid.ParamAttr(name='crfw'))
        return avg_cost, crf_decode
    else:
        size = emission.shape[1]
        fluid.layers.create_parameter(
            shape=[size + 2, size], dtype=emission.dtype, name='crfw')
        crf_decode = fluid.layers.crf_decoding(
            input=emission, param_attr=fluid.ParamAttr(name='crfw'))
...
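A note on the inference branch above: linear_chain_crf's transition parameter has shape [size + 2, size] (rows 0 and 1 hold the start and stop weights, the remaining rows the label-to-label transitions), so creating 'crfw' with that shape up front gives crf_decoding the same named variable that training would have created, and a trained checkpoint loads into it by name. A hedged usage sketch, assuming this file is importable as nets, that the elided tail returns crf_decode in the inference branch, and that args carries the hyper-parameter names guessed below (the vocab and label sizes are placeholders):

import argparse
import paddle.fluid as fluid
from nets import lex_net  # assumption: this file is nets.py on the path

# hypothetical hyper-parameters; real values come from conf/args.yaml
parser = argparse.ArgumentParser()
parser.add_argument('--word_emb_dim', type=int, default=128)
parser.add_argument('--grnn_hidden_dim', type=int, default=256)
parser.add_argument('--bigru_num', type=int, default=2)
args = parser.parse_args([])

# lod_level=1 because the non-padded LAC model consumes LoDTensor input
word = fluid.data(name='word', shape=[-1, 1], dtype='int64', lod_level=1)
crf_decode = lex_net(word, args, vocab_size=20940, num_labels=57,
                     for_infer=True)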