Unverified commit 2bd8028a, authored by gmcather and committed by GitHub

Merge pull request #886 from jshower/develop

Add a Chinese sequence labeling task that runs on a cloud platform
# Chinese Named Entity Recognition Example Using ParallelExecutor
Below is a brief overview of this example's directory structure:
```text
.
├── data        # data the example depends on; obtained externally
├── reader.py   # data-reading interface; obtained externally
├── README.md   # this document
├── train.py    # training script
└── infer.py    # inference script
```
## Data
The `data` directory contains two folders: `train_files` holds the training data and `test_files` holds the test data. As an example, two files are placed in each. For real training, put your data in the corresponding directories and, if its format differs, adapt the data-reading function in `reader.py` (see the illustrative line below).
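Judging from the parsing logic in `reader.py`, each line of a data file holds four `;`-separated fields: a leading field the reader ignores, then the word ids, the raw label ids, and the mention ids, each as space-separated integers. A purely illustrative line (all values made up):
```text
0;152804 130048 38862;0 0 21;3 3 7
```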
## Training
Edit the `main` function of [train.py](./train.py) to point at your data, then run `python train.py` to start training.
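For reference, the invocation at the bottom of [train.py](./train.py) uses these defaults:
```python
main(
    train_data_file="./data/train_files",
    test_data_file="./data/test_files",
    model_save_dir="./output",
    num_passes=1000)
```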
The training log looks like this:
```txt
pass_id:0, time_cost:4.92960214615s
[Train] precision:0.000862136531076, recall:0.0059880239521, f1:0.00150726226363
[Test] precision:0.000796178343949, recall:0.00335758254057, f1:0.00128713933283
pass_id:1, time_cost:0.715255975723s
[Train] precision:0.00474094141551, recall:0.00762112139358, f1:0.00584551148225
[Test] precision:0.0228873239437, recall:0.00727476217124, f1:0.0110403397028
pass_id:2, time_cost:0.740842103958s
[Train] precision:0.0120967741935, recall:0.00163309744148, f1:0.00287769784173
[Test] precision:0, recall:0.0, f1:0
```
## Inference
Edit the `infer` function of [infer.py](./infer.py) to specify the path of the model to test, the test data, and the label dictionary file, then run `python infer.py` to start inference.
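For reference, the invocation at the bottom of [infer.py](./infer.py) uses these defaults:
```python
infer(
    model_path="output/params_pass_0",
    batch_size=6,
    test_data_file="data/test_files",
    target_file="data/label_dict")
```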
Sample predictions:
```txt
152804 O O
130048 O O
38862 10-B O
784 O O
1540 O O
4145 O O
2255 O O
0 O O
1279 O O
7793 O O
373 O O
1621 O O
815 O O
2 O O
247 24-B O
401 24-I O
```
The output has three columns separated by `\t`: the first is the index of the input word, the second the gold label, and the third the predicted label. Input sequences are separated by blank lines.
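As a quick sanity check, you can score this output yourself. A minimal sketch, assuming the predictions were saved to a hypothetical `pred.txt`:
```python
# Token-level accuracy over the three-column, tab-separated output.
total, correct = 0, 0
with open("pred.txt") as f:
    for line in f:
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 3:  # blank separator lines between sequences
            continue
        total += 1
        correct += int(parts[1] == parts[2])  # gold tag == predicted tag
print("token accuracy: %.4f" % (float(correct) / max(total, 1)))
```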
The tag set (this appears to be the `data/label_dict` file that `infer.py` loads; a tag's 0-based line number is its integer id):
```text
24-B
24-I
27-B
27-I
20-B
20-I
21-B
21-I
22-B
22-I
23-B
23-I
28-B
28-I
29-B
29-I
12-B
12-I
11-B
11-I
10-B
10-I
13-B
13-I
38-B
38-I
14-B
14-I
16-B
16-I
33-B
33-I
18-B
18-I
31-B
31-I
30-B
30-I
37-B
37-I
36-B
36-I
35-B
35-I
19-B
19-I
32-B
32-I
O
```
That is 24 entity types, each with a `B` (begin) and an `I` (inside) tag, plus `O` (outside): 49 labels in total, matching `label_dict_len = 49` and `num_chunk_types = ceil((49 - 1) / 2) = 24` in `train.py`.
The inference script, [infer.py](./infer.py):

```python
import numpy as np
import paddle.fluid as fluid
import paddle
import reader


def load_reverse_dict(dict_path):
    # Map integer id -> tag string; the dict file holds one tab-separated
    # entry per line with the tag in the first column.
    return dict((idx, line.strip().split("\t")[0])
                for idx, line in enumerate(open(dict_path, "r").readlines()))


def infer(model_path, batch_size, test_data_file, target_file):
    # Re-declare the input layers so DataFeeder can convert batches
    # into properly named LoDTensors.
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
mention = fluid.layers.data(
name='mention', shape=[1], dtype='int64', lod_level=1)
target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1)
label_reverse_dict = load_reverse_dict(target_file)
test_data = paddle.batch(
reader.file_reader(test_data_file), batch_size=batch_size)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[word, mention, target], place=place)
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
for data in test_data():
crf_decode = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_targets,
return_numpy=False)
            # crf_decode returns a LoDTensor; the LoD records where each
            # sequence starts and ends in the flattened tag array.
            lod_info = (crf_decode[0].lod())[0]
            np_data = np.array(crf_decode[0])
            assert len(data) == len(lod_info) - 1
            for sen_index in xrange(len(data)):
                assert len(data[sen_index][0]) == lod_info[
                    sen_index + 1] - lod_info[sen_index]
                word_index = 0
                for tag_index in xrange(lod_info[sen_index],
                                        lod_info[sen_index + 1]):
                    word = str(data[sen_index][0][word_index])
                    gold_tag = label_reverse_dict[data[sen_index][2][
                        word_index]]
                    tag = label_reverse_dict[np_data[tag_index][0]]
                    # word id, gold tag, predicted tag (Python 2 print)
                    print word + "\t" + gold_tag + "\t" + tag
                    word_index += 1
                # a blank line separates sequences in the output
                print ""


if __name__ == "__main__":
    infer(
        model_path="output/params_pass_0",
        batch_size=6,
        test_data_file="data/test_files",
        target_file="data/label_dict")
```
The data-reading interface, [reader.py](./reader.py):

```python
import os


def file_reader(file_dir):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            for line in open(file_dir + '/' + fi, 'r'):
                line = line.strip()
                # Each line holds ";"-separated fields; fields 1-3 carry
                # the space-separated word, label, and mention ids.
                features = line.split(";")
                word_idx = []
                for item in features[1].strip().split(" "):
                    word_idx.append(int(item))
                target_idx = []
                for item in features[2].strip().split(" "):
                    label_index = int(item)
                    # Remap raw label ids: 0 becomes 48 (the "O" tag);
                    # every other id shifts down by one.
                    if label_index == 0:
                        label_index = 48
                    else:
                        label_index -= 1
                    target_idx.append(label_index)
                mention_idx = []
                for item in features[3].strip().split(" "):
                    mention_idx.append(int(item))
                yield word_idx, mention_idx, target_idx

    return reader
```
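To sanity-check a new data format, you can iterate the reader directly. A minimal sketch, assuming the example's directory layout:
```python
import reader

# Grab the first example from the training reader and inspect the three
# id sequences it yields: word ids, mention ids, label ids.
word_idx, mention_idx, target_idx = next(reader.file_reader("data/train_files")())
print(word_idx[:5])
print(mention_idx[:5])
print(target_idx[:5])
```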
The training script, [train.py](./train.py):

```python
import os
import math
import time

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer

import reader


def load_reverse_dict(dict_path):
    # Same helper as in infer.py: map integer id -> tag string.
    return dict((idx, line.strip().split("\t")[0])
                for idx, line in enumerate(open(dict_path, "r").readlines()))


def to_lodtensor(data, place):
    # Flatten a batch of variable-length sequences into one LoDTensor
    # whose level-0 LoD holds the cumulative sequence offsets.
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
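
# Illustration (not part of the original script): two sequences of
# lengths 3 and 2 flatten into a (5, 1) int64 array with LoD [[0, 3, 5]]:
#   to_lodtensor([[1, 2, 3], [4, 5]], fluid.CPUPlace())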


def ner_net(word_dict_len, label_dict_len):
    # A bidirectional GRU over concatenated word and mention embeddings,
    # topped with a linear-chain CRF.
    IS_SPARSE = False
    word_dim = 32
    mention_dict_len = 57
    mention_dim = 20
    grnn_hidden = 36
    emb_lr = 5
    init_bound = 0.1

    def _net_conf(word, mention, target):
word_embedding = fluid.layers.embedding(
input=word,
size=[word_dict_len, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="word_emb",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
mention_embedding = fluid.layers.embedding(
input=mention,
size=[mention_dict_len, mention_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="mention_emb",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
word_embedding_r = fluid.layers.embedding(
input=word,
size=[word_dict_len, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="word_emb_r",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
mention_embedding_r = fluid.layers.embedding(
input=mention,
size=[mention_dict_len, mention_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="mention_emb_r",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
word_mention_vector = fluid.layers.concat(
input=[word_embedding, mention_embedding], axis=1)
word_mention_vector_r = fluid.layers.concat(
input=[word_embedding_r, mention_embedding_r], axis=1)
        # dynamic_gru expects its input pre-projected to 3 * hidden size
        # (update gate, reset gate, and candidate state).
        pre_gru = fluid.layers.fc(
input=word_mention_vector,
size=grnn_hidden * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru = fluid.layers.dynamic_gru(
input=pre_gru,
size=grnn_hidden,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
pre_gru_r = fluid.layers.fc(
input=word_mention_vector_r,
size=grnn_hidden * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru_r = fluid.layers.dynamic_gru(
input=pre_gru_r,
size=grnn_hidden,
is_reverse=True,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru_merged = fluid.layers.concat(input=[gru, gru_r], axis=1)
emission = fluid.layers.fc(
size=label_dict_len,
input=gru_merged,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
        # Linear-chain CRF cost over the per-tag emission scores.
        crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=target,
param_attr=fluid.ParamAttr(
name='crfw',
learning_rate=0.2, ))
avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, emission
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
mention = fluid.layers.data(
name='mention', shape=[1], dtype='int64', lod_level=1)
target = fluid.layers.data(
name="target", shape=[1], dtype='int64', lod_level=1)
avg_cost, emission = _net_conf(word, mention, target)
return avg_cost, emission, word, mention, target


def test2(exe, chunk_evaluator, inference_program, test_data, place,
          cur_fetch_list):
    # Evaluate with the plain Executor; the fetched chunk counts are
    # per-batch scalars.
    chunk_evaluator.reset()
for data in test_data():
word = to_lodtensor(map(lambda x: x[0], data), place)
mention = to_lodtensor(map(lambda x: x[1], data), place)
target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = exe.run(
inference_program,
feed={"word": word,
"mention": mention,
"target": target},
fetch_list=cur_fetch_list)
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
number_correct = np.array(result_list[2])
chunk_evaluator.update(number_infer[0], number_label[0],
number_correct[0])
return chunk_evaluator.eval()


def test(test_exe, chunk_evaluator, inference_program, test_data, place,
         cur_fetch_list):
    # Evaluate with a ParallelExecutor; counts come back one per device,
    # so they are summed before updating the evaluator.
    chunk_evaluator.reset()
for data in test_data():
word = to_lodtensor(map(lambda x: x[0], data), place)
mention = to_lodtensor(map(lambda x: x[1], data), place)
target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = test_exe.run(
fetch_list=cur_fetch_list,
feed={"word": word,
"mention": mention,
"target": target})
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
number_correct = np.array(result_list[2])
chunk_evaluator.update(number_infer.sum(),
number_label.sum(), number_correct.sum())
return chunk_evaluator.eval()


def main(train_data_file, test_data_file, model_save_dir, num_passes):
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)

    BATCH_SIZE = 256
    word_dict_len = 1942563
    label_dict_len = 49

    main_program = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main_program, startup):
avg_cost, feature_out, word, mention, target = ner_net(word_dict_len,
label_dict_len)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(
name='crfw', ))
        # 48 B/I tags plus "O": ceil((49 - 1) / 2) = 24 chunk types.
        (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
         num_correct_chunks) = fluid.layers.chunk_eval(
             input=crf_decode,
             label=target,
             chunk_scheme="IOB",
             num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
chunk_evaluator = fluid.metrics.ChunkEvaluator()
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
[num_infer_chunks, num_label_chunks, num_correct_chunks])
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.file_reader(train_data_file), buf_size=2000000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.file_reader(test_data_file), buf_size=2000000),
batch_size=BATCH_SIZE)
place = fluid.CUDAPlace(0)
feeder = fluid.DataFeeder(
feed_list=[word, mention, target], place=place)
exe = fluid.Executor(place)
exe.run(startup)
        # One ParallelExecutor for training, one for evaluation; the test
        # executor reuses the training executor's parameters.
        train_exe = fluid.ParallelExecutor(
            loss_name=avg_cost.name, use_cuda=True)
        test_exe = fluid.ParallelExecutor(
            use_cuda=True,
            main_program=inference_program,
            share_vars_from=train_exe)
batch_id = 0
for pass_id in xrange(num_passes):
chunk_evaluator.reset()
train_reader_iter = train_reader()
start_time = time.time()
while True:
try:
cur_batch = next(train_reader_iter)
cost, nums_infer, nums_label, nums_correct = train_exe.run(
fetch_list=[
avg_cost.name, num_infer_chunks.name,
num_label_chunks.name, num_correct_chunks.name
],
feed=feeder.feed(cur_batch))
chunk_evaluator.update(
np.array(nums_infer).sum(),
np.array(nums_label).sum(),
np.array(nums_correct).sum())
cost_list = np.array(cost)
batch_id += 1
except StopIteration:
break
end_time = time.time()
print("pass_id:" + str(pass_id) + ", time_cost:" + str(
end_time - start_time) + "s")
precision, recall, f1_score = chunk_evaluator.eval()
print("[Train] precision:" + str(precision) + ", recall:" + str(
recall) + ", f1:" + str(f1_score))
p, r, f1 = test2(
exe, chunk_evaluator, inference_program, test_reader, place,
[num_infer_chunks, num_label_chunks, num_correct_chunks])
print("[Test] precision:" + str(p) + ", recall:" + str(r) + ", f1:"
+ str(f1))
save_dirname = os.path.join(model_save_dir,
"params_pass_%d" % pass_id)
fluid.io.save_inference_model(
save_dirname, ['word', 'mention', 'target'], [crf_decode], exe)


if __name__ == "__main__":
    main(
        train_data_file="./data/train_files",
        test_data_file="./data/test_files",
        model_save_dir="./output",
        num_passes=1000)
```