提交 196a75cd 编写于 作者: B Bruce 提交者: pkpk

add check_cuda & change infer output format (#2717)

上级 668567e0
......@@ -18,6 +18,8 @@ Lexical Analysis of Chinese,简称 LAC,是一个联合的词法分析模型
本项目依赖 PaddlePaddle 1.3.2 及以上版本,安装请参考官网 [快速安装](http://www.paddlepaddle.org/paddle#quick-start)
> Warning: GPU 和 CPU 版本的 PaddlePaddle 分别是 paddlepaddle-gpu 和 paddlepaddle,请安装时注意区别。
#### 2. 克隆代码
克隆工具集代码库到本地
```bash
......
......@@ -93,7 +93,7 @@ class Dataset(object):
for line in fread:
words = line.strip("\n").split("\002")
word_ids = self.word_to_ids(words)
yield word_ids[0:max_seq_len]
yield word_ids[0:max_seq_len], [0 for _ in word_ids][0: max_seq_len]
else:
assert len(headline) == 2 and headline[0] == "text_a" and headline[1] == "label"
for line in fread:
......
......@@ -75,6 +75,10 @@ run_type_g.add_arg("do_infer", bool, True, "Whether to perform inference.")
args = parser.parse_args()
# yapf: enable.
sys.path.append('../models/')
from model_check import check_cuda
check_cuda(args.use_cuda)
def ernie_pyreader(args, pyreader_name):
"""define standard ernie pyreader"""
pyreader = fluid.layers.py_reader(
......
......@@ -25,7 +25,6 @@ import utils
sys.path.append("../")
from models.sequence_labeling import nets
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
......@@ -71,6 +70,10 @@ parser.add_argument('--enable_ce', action='store_true', help='If set, run the ta
args = parser.parse_args()
# yapf: enable.
sys.path.append('../models/')
from model_check import check_cuda
check_cuda(args.use_cuda)
print(args)
......
......@@ -81,13 +81,27 @@ def parse_result(words, crf_decode, dataset):
for sent_index in range(batch_size):
sent_out_str = ""
sent_len = offset_list[sent_index + 1] - offset_list[sent_index]
last_word = ""
last_tag = ""
for tag_index in range(sent_len): # iterate every word in sent
index = tag_index + offset_list[sent_index]
cur_word_id = str(words[index][0])
cur_tag_id = str(crf_decode[index][0])
cur_word = dataset.id2word_dict[cur_word_id]
cur_tag = dataset.id2label_dict[cur_tag_id]
sent_out_str += cur_word + u"/" + cur_tag + u" "
if last_word == "":
last_word = cur_word
last_tag = cur_tag[:-2]
elif cur_tag.endswith("-B") or cur_tag == "O":
sent_out_str += last_word + u"/" + last_tag + u" "
last_word = cur_word
last_tag = cur_tag[:-2]
elif cur_tag.endswith("-I"):
last_word += cur_word
else:
raise ValueError("invalid tag: %s" % (cur_tag))
if cur_word != "":
sent_out_str += last_word + u"/" + last_tag + u" "
sent_out_str = to_str(sent_out_str.strip())
batch_out_str.append(sent_out_str)
return batch_out_str
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册