提交 b6f9f5f4 编写于 作者: M MRXLT

fix senta reader && lac reader

上级 8f820382
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import io
class SentaReader():
def __init__(self, vocab_path, max_seq_len=20):
self.max_seq_len = max_seq_len
self.word_dict = self.load_vocab(vocab_path)
def load_vocab(self, vocab_path):
"""
load the given vocabulary
"""
vocab = {}
with io.open(vocab_path, 'r', encoding='utf8') as f:
wid = 0
for line in f:
if line.strip() not in vocab:
vocab[line.strip()] = wid
wid += 1
vocab["<unk>"] = len(vocab)
return vocab
def process(self, cols):
unk_id = len(self.word_dict)
pad_id = 0
wids = [
self.word_dict[x] if x in self.word_dict else unk_id for x in cols
]
seq_len = len(wids)
if seq_len < self.max_seq_len:
for i in range(self.max_seq_len - seq_len):
wids.append(pad_id)
else:
wids = wids[:self.max_seq_len]
seq_len = self.max_seq_len
return wids
...@@ -24,7 +24,7 @@ from multiprocessing import Process, Queue ...@@ -24,7 +24,7 @@ from multiprocessing import Process, Queue
class SentaService(WebService): class SentaService(WebService):
def __init__( def set_config(
self, self,
lac_model_path, lac_model_path,
lac_dict_path, lac_dict_path,
...@@ -33,14 +33,17 @@ class SentaService(WebService): ...@@ -33,14 +33,17 @@ class SentaService(WebService):
self.lac_client_config_path = lac_model_path + "/serving_server_conf.prototxt" self.lac_client_config_path = lac_model_path + "/serving_server_conf.prototxt"
self.lac_dict_path = lac_dict_path self.lac_dict_path = lac_dict_path
self.senta_dict_path = senta_dict_path self.senta_dict_path = senta_dict_path
self.show = False
def show_detail(self, show=False):
self.show = show
def start_lac_service(self): def start_lac_service(self):
print(" ---- start lac service ---- ")
os.chdir('./lac_serving') os.chdir('./lac_serving')
self.lac_port = self.port + 100 self.lac_port = self.port + 100
r = os.popen( r = os.popen(
"python -m paddle_serving_server_gpu.serve --model {} --port {} &". "python -m paddle_serving_server.serve --model {} --port {} &".
format(self.lac_model_path, self.lac_port)) format("../" + self.lac_model_path, self.lac_port))
os.chdir('..') os.chdir('..')
def init_lac_service(self): def init_lac_service(self):
...@@ -67,41 +70,42 @@ class SentaService(WebService): ...@@ -67,41 +70,42 @@ class SentaService(WebService):
self.senta_reader = SentaReader(vocab_path=self.senta_dict_path) self.senta_reader = SentaReader(vocab_path=self.senta_dict_path)
def preprocess(self, feed={}, fetch={}): def preprocess(self, feed={}, fetch={}):
print("---- preprocess ----")
print(feed)
if "words" not in feed: if "words" not in feed:
raise ("feed data error!") raise ("feed data error!")
feed_data = self.lac_reader.process(feed["words"]) feed_data = self.lac_reader.process(feed["words"])
fetch = ["crf_decode"] fetch = ["crf_decode"]
print("---- lac reader ----") if self.show:
print(feed_data) print("---- lac reader ----")
print(feed_data)
lac_result = self.lac_predict(feed_data) lac_result = self.lac_predict(feed_data)
print("---- lac out ----") if self.show:
print(lac_result) print("---- lac out ----")
print(lac_result)
segs = self.lac_reader.parse_result(feed["words"], segs = self.lac_reader.parse_result(feed["words"],
lac_result["crf_decode"]) lac_result["crf_decode"])
print("---- lac parse ----") if self.show:
print("---- lac parse ----")
print(segs)
feed_data = self.senta_reader.process(segs) feed_data = self.senta_reader.process(segs)
print("---- senta reader ----") if self.show:
print("feed_data", feed_data) print("---- senta reader ----")
fetch = ["sentence_feature"] print("feed_data", feed_data)
fetch = ["class_probs"]
return {"words": feed_data}, fetch return {"words": feed_data}, fetch
senta_service = SentaService( senta_service = SentaService(name="senta")
name="senta", #senta_service.show_detail(True)
lac_model_path="../../lac/jieba_server_model/", senta_service.set_config(
lac_client_config_path="../lac/jieba_client_conf/serving_client_conf.prototxt", lac_model_path="./infer_model",
lac_dict="../lac/lac_dict", lac_dict_path="../lac/lac_dict",
senta_dict="./senta_data/word_dict.txt") senta_dict_path="./vocab.txt")
senta_service.load_model_config(sys.argv[1]) senta_service.load_model_config(sys.argv[1])
senta_service.prepare_server( senta_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu") workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
senta_service.init_lac_reader() senta_service.init_lac_reader()
senta_service.init_senta_reader() senta_service.init_senta_reader()
print("Init senta done")
senta_service.init_lac_service() senta_service.init_lac_service()
print("init lac service done")
senta_service.run_server() senta_service.run_server()
#senta_service.run_flask() #senta_service.run_flask()
......
...@@ -53,12 +53,14 @@ class LACReader(object): ...@@ -53,12 +53,14 @@ class LACReader(object):
#folder = os.path.dirname(basepath) #folder = os.path.dirname(basepath)
word_dict_path = os.path.join(dict_folder, "word.dic") word_dict_path = os.path.join(dict_folder, "word.dic")
label_dict_path = os.path.join(dict_folder, "tag.dic") label_dict_path = os.path.join(dict_folder, "tag.dic")
replace_dict_path = os.path.join(dict_folder, "q2b.dic")
self.word2id_dict = load_kv_dict( self.word2id_dict = load_kv_dict(
word_dict_path, reverse=True, value_func=int) word_dict_path, reverse=True, value_func=int)
self.id2word_dict = load_kv_dict(word_dict_path) self.id2word_dict = load_kv_dict(word_dict_path)
self.label2id_dict = load_kv_dict( self.label2id_dict = load_kv_dict(
label_dict_path, reverse=True, value_func=int) label_dict_path, reverse=True, value_func=int)
self.id2label_dict = load_kv_dict(label_dict_path) self.id2label_dict = load_kv_dict(label_dict_path)
self.word_replace_dict = load_kv_dict(replace_dict_path)
@property @property
def vocab_size(self): def vocab_size(self):
...@@ -79,6 +81,7 @@ class LACReader(object): ...@@ -79,6 +81,7 @@ class LACReader(object):
except: except:
pass pass
for word in words: for word in words:
word = self.word_replace_dict.get(word, word)
if word not in self.word2id_dict: if word not in self.word2id_dict:
word = "OOV" word = "OOV"
word_id = self.word2id_dict[word] word_id = self.word2id_dict[word]
......
...@@ -27,11 +27,16 @@ class SentaReader(): ...@@ -27,11 +27,16 @@ class SentaReader():
""" """
vocab = {} vocab = {}
with io.open(vocab_path, 'r', encoding='utf8') as f: with io.open(vocab_path, 'r', encoding='utf8') as f:
wid = 0
for line in f: for line in f:
if line.strip() not in vocab: if line.strip() not in vocab:
vocab[line.strip()] = wid data = line.strip().split("\t")
wid += 1 if len(data) < 2:
word = ""
wid = data[0]
else:
word = data[0]
wid = data[1]
vocab[word] = int(wid)
vocab["<unk>"] = len(vocab) vocab["<unk>"] = len(vocab)
return vocab return vocab
...@@ -41,6 +46,7 @@ class SentaReader(): ...@@ -41,6 +46,7 @@ class SentaReader():
wids = [ wids = [
self.word_dict[x] if x in self.word_dict else unk_id for x in cols self.word_dict[x] if x in self.word_dict else unk_id for x in cols
] ]
'''
seq_len = len(wids) seq_len = len(wids)
if seq_len < self.max_seq_len: if seq_len < self.max_seq_len:
for i in range(self.max_seq_len - seq_len): for i in range(self.max_seq_len - seq_len):
...@@ -48,5 +54,5 @@ class SentaReader(): ...@@ -48,5 +54,5 @@ class SentaReader():
else: else:
wids = wids[:self.max_seq_len] wids = wids[:self.max_seq_len]
seq_len = self.max_seq_len seq_len = self.max_seq_len
'''
return wids return wids
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册