提交 b6f9f5f4 编写于 作者: M MRXLT

fix senta reader && lac reader

上级 8f820382
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import io
class SentaReader():
    """Convert tokenized text into fixed-length lists of word ids.

    The vocabulary is read once at construction time; ``process`` then maps
    tokens to ids and pads or truncates the result to ``max_seq_len``.
    """

    def __init__(self, vocab_path, max_seq_len=20):
        """Load the vocabulary and remember the target sequence length.

        Args:
            vocab_path: path to a UTF-8 text file with one word per line.
            max_seq_len: exact length of every list returned by ``process``.
        """
        self.max_seq_len = max_seq_len
        self.word_dict = self.load_vocab(vocab_path)

    def load_vocab(self, vocab_path):
        """
        load the given vocabulary

        Each distinct stripped line gets the next consecutive id, starting
        at 0, in order of first appearance. A final ``"<unk>"`` entry is then
        stored with the dictionary size at that moment as its id.

        Args:
            vocab_path: path to a UTF-8 text file with one word per line.

        Returns:
            dict mapping word -> integer id.
        """
        vocab = {}
        with io.open(vocab_path, 'r', encoding='utf8') as f:
            for line in f:
                word = line.strip()
                if word not in vocab:
                    # Ids are handed out consecutively, so the next free id
                    # always equals the current dictionary size.
                    vocab[word] = len(vocab)
        vocab["<unk>"] = len(vocab)
        return vocab

    def process(self, cols):
        """Map tokens to word ids and pad/truncate to ``max_seq_len``.

        Args:
            cols: iterable of tokens (the keys used in the vocabulary).

        Returns:
            list of ids; shorter inputs are right-padded with 0 and longer
            inputs are cut off after ``max_seq_len`` entries.
        """
        # NOTE(review): this is the dictionary size *including* "<unk>", i.e.
        # one larger than the id load_vocab stores under "<unk>". Kept as-is
        # because changing it would alter the ids fed to the model -- confirm
        # against the id convention the model was trained with.
        unk_id = len(self.word_dict)
        pad_id = 0
        wids = [self.word_dict.get(x, unk_id) for x in cols]
        # Truncate first, then pad; the repeat count is <= 0 (a no-op) when
        # the sequence already has the target length.
        wids = wids[:self.max_seq_len]
        wids.extend([pad_id] * (self.max_seq_len - len(wids)))
        return wids
......@@ -24,7 +24,7 @@ from multiprocessing import Process, Queue
class SentaService(WebService):
def __init__(
def set_config(
self,
lac_model_path,
lac_dict_path,
......@@ -33,14 +33,17 @@ class SentaService(WebService):
self.lac_client_config_path = lac_model_path + "/serving_server_conf.prototxt"
self.lac_dict_path = lac_dict_path
self.senta_dict_path = senta_dict_path
self.show = False
def show_detail(self, show=False):
self.show = show
def start_lac_service(self):
print(" ---- start lac service ---- ")
os.chdir('./lac_serving')
self.lac_port = self.port + 100
r = os.popen(
"python -m paddle_serving_server_gpu.serve --model {} --port {} &".
format(self.lac_model_path, self.lac_port))
"python -m paddle_serving_server.serve --model {} --port {} &".
format("../" + self.lac_model_path, self.lac_port))
os.chdir('..')
def init_lac_service(self):
......@@ -67,41 +70,42 @@ class SentaService(WebService):
self.senta_reader = SentaReader(vocab_path=self.senta_dict_path)
def preprocess(self, feed={}, fetch={}):
print("---- preprocess ----")
print(feed)
if "words" not in feed:
raise ("feed data error!")
feed_data = self.lac_reader.process(feed["words"])
fetch = ["crf_decode"]
print("---- lac reader ----")
print(feed_data)
if self.show:
print("---- lac reader ----")
print(feed_data)
lac_result = self.lac_predict(feed_data)
print("---- lac out ----")
print(lac_result)
if self.show:
print("---- lac out ----")
print(lac_result)
segs = self.lac_reader.parse_result(feed["words"],
lac_result["crf_decode"])
print("---- lac parse ----")
if self.show:
print("---- lac parse ----")
print(segs)
feed_data = self.senta_reader.process(segs)
print("---- senta reader ----")
print("feed_data", feed_data)
fetch = ["sentence_feature"]
if self.show:
print("---- senta reader ----")
print("feed_data", feed_data)
fetch = ["class_probs"]
return {"words": feed_data}, fetch
senta_service = SentaService(
name="senta",
lac_model_path="../../lac/jieba_server_model/",
lac_client_config_path="../lac/jieba_client_conf/serving_client_conf.prototxt",
lac_dict="../lac/lac_dict",
senta_dict="./senta_data/word_dict.txt")
senta_service = SentaService(name="senta")
#senta_service.show_detail(True)
senta_service.set_config(
lac_model_path="./infer_model",
lac_dict_path="../lac/lac_dict",
senta_dict_path="./vocab.txt")
senta_service.load_model_config(sys.argv[1])
senta_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
senta_service.init_lac_reader()
senta_service.init_senta_reader()
print("Init senta done")
senta_service.init_lac_service()
print("init lac service done")
senta_service.run_server()
#senta_service.run_flask()
......
......@@ -53,12 +53,14 @@ class LACReader(object):
#folder = os.path.dirname(basepath)
word_dict_path = os.path.join(dict_folder, "word.dic")
label_dict_path = os.path.join(dict_folder, "tag.dic")
replace_dict_path = os.path.join(dict_folder, "q2b.dic")
self.word2id_dict = load_kv_dict(
word_dict_path, reverse=True, value_func=int)
self.id2word_dict = load_kv_dict(word_dict_path)
self.label2id_dict = load_kv_dict(
label_dict_path, reverse=True, value_func=int)
self.id2label_dict = load_kv_dict(label_dict_path)
self.word_replace_dict = load_kv_dict(replace_dict_path)
@property
def vocab_size(self):
......@@ -79,6 +81,7 @@ class LACReader(object):
except:
pass
for word in words:
word = self.word_replace_dict.get(word, word)
if word not in self.word2id_dict:
word = "OOV"
word_id = self.word2id_dict[word]
......
......@@ -27,11 +27,16 @@ class SentaReader():
"""
vocab = {}
with io.open(vocab_path, 'r', encoding='utf8') as f:
wid = 0
for line in f:
if line.strip() not in vocab:
vocab[line.strip()] = wid
wid += 1
data = line.strip().split("\t")
if len(data) < 2:
word = ""
wid = data[0]
else:
word = data[0]
wid = data[1]
vocab[word] = int(wid)
vocab["<unk>"] = len(vocab)
return vocab
......@@ -41,6 +46,7 @@ class SentaReader():
wids = [
self.word_dict[x] if x in self.word_dict else unk_id for x in cols
]
'''
seq_len = len(wids)
if seq_len < self.max_seq_len:
for i in range(self.max_seq_len - seq_len):
......@@ -48,5 +54,5 @@ class SentaReader():
else:
wids = wids[:self.max_seq_len]
seq_len = self.max_seq_len
'''
return wids
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册