# pylint: disable=doc-string-missing
"""Latency benchmark for the LAC serving example, over RPC or HTTP."""

import sys
import time
import requests
from lac_reader import LACReader
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args

args = benchmark_args()


def single_func(idx, resource):
    """Send every line of ``jieba_test.txt`` to the LAC service and time it.

    The transport is chosen by ``args.request``: ``"rpc"`` goes through the
    serving client, ``"http"`` posts to ``/lac/prediction``. Any other value
    sends nothing, so only the reader construction is timed.

    Args:
        idx: First positional argument passed in by the runner; unused here.
        resource: Second positional argument passed in by the runner
            (the driver below hands ``{}`` to ``run``); unused here.

    Returns:
        ``[[elapsed_seconds]]``: wall-clock time spent by this worker.
    """
    reader = LACReader("lac_dict")
    start = time.time()
    if args.request == "rpc":
        client = Client()
        client.load_client_config(args.model)
        client.connect([args.endpoint])
        # `with` releases the corpus handle even when a request raises.
        with open("jieba_test.txt") as fin:
            for line in fin:
                feed_data = reader.process(line)
                client.predict(
                    feed={"words": feed_data}, fetch=["crf_decode"])
    elif args.request == "http":
        url = "http://{}/lac/prediction".format(args.endpoint)
        with open("jieba_test.txt") as fin:
            for line in fin:
                req_data = {"words": line.strip(), "fetch": ["crf_decode"]}
                # Send the payload as a JSON body, the same way
                # lac_http_client.py talks to this endpoint; the original
                # built req_data and then posted a form-encoded copy.
                requests.post(url, json=req_data)
    end = time.time()
    return [[end - start]]


multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
+print(result) diff --git a/python/examples/lac/get_data.sh b/python/examples/lac/get_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b72850d35b7a7b5e43b34d31c7a903e05f07440 --- /dev/null +++ b/python/examples/lac/get_data.sh @@ -0,0 +1,2 @@ +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/lac/lac_model.tar.gz +tar -zxvf lac_model.tar.gz diff --git a/python/examples/lac/lac_client.py b/python/examples/lac/lac_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a8e858ed72ac4043a2bb3162a39a2aff233043 --- /dev/null +++ b/python/examples/lac/lac_client.py @@ -0,0 +1,35 @@ +# encoding=utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# pylint: disable=doc-string-missing
"""Command-line RPC client: reads sentences from stdin, prints predictions."""

from paddle_serving_client import Client
from lac_reader import LACReader
import sys
import os
import io

# Address of the LAC serving endpoint this example connects to.
SERVER_ENDPOINT = "127.0.0.1:9280"


def main():
    """Run the client.

    ``sys.argv[1]`` is the client config passed to ``load_client_config``;
    ``sys.argv[2]`` is the dictionary folder passed to ``LACReader``.
    Each stdin line is converted to word ids and sent to the server; lines
    that yield no ids (e.g. blank lines) are skipped.
    """
    if len(sys.argv) < 3:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: python lac_client.py <client_config> <dict_folder>")

    client = Client()
    client.load_client_config(sys.argv[1])
    client.connect([SERVER_ENDPOINT])

    reader = LACReader(sys.argv[2])
    for line in sys.stdin:
        # Lines read from stdin always contain at least the newline, so
        # emptiness is checked on the processed ids rather than the raw line.
        feed_data = reader.process(line)
        if not feed_data:
            continue
        fetch_map = client.predict(
            feed={"words": feed_data}, fetch=["crf_decode"])
        print(fetch_map)


if __name__ == "__main__":
    main()
#coding=utf-8
"""HTTP client example: posts each line of jieba_test.txt and prints the total time."""
import requests
import json
import time

if __name__ == "__main__":
    server = "http://127.0.0.1:9280/lac/prediction"
    # `with` closes the corpus file; the original left the handle open.
    with open("jieba_test.txt", "r") as fin:
        start = time.time()
        for line in fin:
            req_data = {"words": line.strip(), "fetch": ["crf_decode"]}
            # The response is not inspected: only round-trip time is measured.
            r = requests.post(server, json=req_data)
        end = time.time()
    print(end - start)
import io
import os
import sys

if sys.version_info[0] == 2:
    # Python 2 only: make UTF-8 the implicit str<->unicode codec. `reload`
    # is not a builtin on Python 3, where running this unconditionally made
    # the module fail at import time with NameError.
    reload(sys)  # noqa: F821  pylint: disable=undefined-variable
    sys.setdefaultencoding('utf-8')

try:
    # Not used by anything in this module; the import is retained in case
    # other code pulls `Client` from here (TODO confirm), but it no longer
    # prevents the reader from being imported without the serving client.
    from paddle_serving_client import Client  # noqa: F401
except ImportError:
    Client = None


def load_kv_dict(dict_path,
                 reverse=False,
                 delimiter="\t",
                 key_func=None,
                 value_func=None):
    """Load a two-column text file into a dict.

    Lines that do not split into exactly two fields on ``delimiter`` are
    skipped.

    Args:
        dict_path: Path of the UTF-8 encoded dictionary file.
        reverse: If true the line is read as ``value<delimiter>key``,
            otherwise as ``key<delimiter>value``.
        delimiter: Field separator.
        key_func: Optional callable applied to every key.
        value_func: Optional callable applied to every value.

    Returns:
        Dict mapping (converted) keys to (converted) values.

    Raises:
        KeyError: If two lines produce the same key.
    """
    result_dict = {}
    # `with` closes the file; the original iterated an unclosed io.open().
    with io.open(dict_path, "r", encoding="utf8") as dict_file:
        for line in dict_file:
            terms = line.strip("\n").split(delimiter)
            if len(terms) != 2:
                continue
            if reverse:
                value, key = terms
            else:
                key, value = terms
            # Convert first so the duplicate check compares the keys that
            # are actually stored (previously the raw key was compared
            # against already-converted ones).
            if key_func:
                key = key_func(key)
            if key in result_dict:
                raise KeyError("key duplicated with [%s]" % (key))
            if value_func:
                value = value_func(value)
            result_dict[key] = value
    return result_dict


class LACReader(object):
    """Maps characters and labels to integer ids using on-disk dictionaries."""

    def __init__(self, dict_folder):
        """Load ``word.dic`` and ``tag.dic`` from ``dict_folder``.

        Both files are read as ``id<TAB>token`` lines: the forward maps
        (``id2word_dict``, ``id2label_dict``) keep the id as a string key,
        the reverse maps (``word2id_dict``, ``label2id_dict``) hold int ids.
        """
        word_dict_path = os.path.join(dict_folder, "word.dic")
        label_dict_path = os.path.join(dict_folder, "tag.dic")
        self.word2id_dict = load_kv_dict(
            word_dict_path, reverse=True, value_func=int)
        self.id2word_dict = load_kv_dict(word_dict_path)
        self.label2id_dict = load_kv_dict(
            label_dict_path, reverse=True, value_func=int)
        self.id2label_dict = load_kv_dict(label_dict_path)

    @property
    def vocab_size(self):
        """Largest word id plus one."""
        return max(self.word2id_dict.values()) + 1

    @property
    def num_labels(self):
        """Largest label id plus one."""
        return max(self.label2id_dict.values()) + 1

    def word_to_ids(self, words):
        """Convert a sentence to a list of word ids, one per character.

        Args:
            words: Text (or UTF-8 encoded bytes) to convert. Any other
                iterable is walked element by element.

        Returns:
            List of ids; elements missing from the dictionary map to the id
            of ``"OOV"``.
        """
        if isinstance(words, bytes):
            # Covers Python 2 `str` and Python 3 `bytes`. The original called
            # the Python-2-only `unicode()` inside a bare `except`, so on
            # Python 3 bytes input was iterated as integers (all OOV).
            try:
                words = words.decode("utf-8")
            except UnicodeDecodeError:
                # Best effort, as before: fall back to the undecoded input.
                pass
        return [
            self.word2id_dict[word if word in self.word2id_dict else "OOV"]
            for word in words
        ]

    def label_to_ids(self, labels):
        """Convert labels to ids; unknown labels map to the id of ``"O"``."""
        return [
            self.label2id_dict[label if label in self.label2id_dict else "O"]
            for label in labels
        ]

    def process(self, sent):
        """Strip surrounding whitespace from ``sent`` and return its word ids."""
        return self.word_to_ids(sent.strip())
from paddle_serving_server.web_service import WebService
import sys
from lac_reader import LACReader


class LACService(WebService):
    """Web service that converts raw text to word ids before prediction."""

    def load_reader(self):
        """Create the dictionary-backed reader from the ``lac_dict`` folder."""
        self.reader = LACReader("lac_dict")

    def preprocess(self, feed=None, fetch=None):
        """Replace the raw ``"words"`` text in ``feed`` with word ids.

        Args:
            feed: Request payload; must contain a ``"words"`` entry.
            fetch: Names of the outputs to fetch; returned unchanged.

        Returns:
            Tuple ``({"words": word_ids}, fetch)``.

        Raises:
            ValueError: If ``feed`` has no ``"words"`` entry.
        """
        # None sentinels replace the shared mutable defaults ({} and []).
        feed = {} if feed is None else feed
        fetch = [] if fetch is None else fetch
        if "words" not in feed:
            # Raising a plain string is itself a TypeError; raise a real
            # exception carrying the intended message.
            raise ValueError("feed data error!")
        feed_data = self.reader.process(feed["words"])
        return {"words": feed_data}, fetch


# Command line: <model_config> <workdir> <port>
lac_service = LACService(name="lac")
lac_service.load_model_config(sys.argv[1])
lac_service.load_reader()
lac_service.prepare_server(
    workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
lac_service.run_server()
def str2bool(v):
    """Return True when ``v`` spells a truthy flag ("true", "t" or "1").

    The comparison is case-insensitive; every other string is False.
    Intended as an argparse ``type=`` helper, since argparse has no native
    boolean parsing.
    """
    return v.lower() in ("true", "t", "1")


def parse_result(words, crf_decode, dataset):
    """Merge per-character tags into words for the decoded batch.

    Args:
        words: Word ids, one single-element row per character, laid out
            back to back for the whole batch.
        crf_decode: Decoded tag ids in the same layout. Must expose
            ``lod()`` whose first level holds the sentence offsets
            (presumably a LoD tensor -- TODO confirm).
        dataset: Object with ``id2word_dict`` and ``id2label_dict`` keyed
            by the string form of the id.

    Returns:
        Tuple ``(sent_out, tags_out)``: the merged words and their tag types
        (the part of the tag before ``-``). Both are empty lists when the
        batch has no sentence.

    NOTE(review): the flattened source does not show whether the ``return``
    sits inside or after the sentence loop. It is placed after the loop
    here, so for a multi-sentence batch the last sentence is returned --
    confirm against the intended behavior.
    """
    offset_list = (crf_decode.lod())[0]
    words = np.array(words)
    crf_decode = np.array(crf_decode)
    batch_size = len(offset_list) - 1

    # Bound up front so an empty batch returns empty lists instead of
    # failing on unbound locals.
    sent_out, tags_out = [], []
    for sent_index in range(batch_size):
        begin, end = offset_list[sent_index], offset_list[sent_index + 1]

        sent = []
        for word_id in words[begin:end]:
            word = dataset.id2word_dict[str(word_id[0])]
            # Out-of-vocabulary characters are rendered as a blank.
            sent.append(' ' if word == 'OOV' else word)
        tags = [
            dataset.id2label_dict[str(tag_id[0])]
            for tag_id in crf_decode[begin:end]
        ]

        sent_out = []
        tags_out = []
        partial_word = ""
        for ind, tag in enumerate(tags):
            # for the first word
            if partial_word == "":
                partial_word = sent[ind]
                tags_out.append(tag.split('-')[0])
                continue

            # A "-B" tag, or the first "O" after a non-"O" tag, starts a
            # new word: flush the one collected so far.
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                sent_out.append(partial_word)
                tags_out.append(tag.split('-')[0])
                partial_word = sent[ind]
                continue

            partial_word += sent[ind]

        # append the last word, except for len(tags)=0
        if len(sent_out) < len(tags_out):
            sent_out.append(partial_word)
    return sent_out, tags_out
def parse_padding_result(words, crf_decode, seq_lens, dataset):
    """Merge per-character tags into words for every sentence of a batch.

    For each sentence only positions ``1 .. seq_len - 2`` are used, i.e. the
    first and last position inside the sentence length are dropped
    (presumably start/end markers -- TODO confirm).

    Args:
        words: Word ids with the batch on the first axis; singleton axes
            are squeezed away (assumed ``(batch, seq, 1)`` or
            ``(batch, seq)`` -- TODO confirm with the caller).
        crf_decode: Decoded tag ids, indexed as ``crf_decode[sentence][pos]``.
        seq_lens: Length of each sentence; its length defines the batch size.
        dataset: Object with ``id2word_dict`` and ``id2label_dict`` keyed
            by the string form of the id.

    Returns:
        List with one ``[sent_out, tags_out]`` pair per sentence: the merged
        words and their tag types (the part of the tag before ``-``).
    """
    batch_size = len(seq_lens)
    if batch_size == 0:
        return []

    words = np.squeeze(words)
    if words.ndim < 2:
        # squeeze() also drops the batch axis when the batch (or the
        # sequence) has length one; restore a (batch, seq) layout.
        words = words.reshape(batch_size, -1)

    batch_out = []
    for sent_index in range(batch_size):
        stop = seq_lens[sent_index] - 1

        # Slice the words exactly like the tags below. The original sliced
        # with undefined `begin`/`end` and indexed each id with [0], which
        # failed on the first sentence.
        sent = []
        for word_id in words[sent_index][1:stop]:
            word = dataset.id2word_dict[str(word_id)]
            # Out-of-vocabulary characters are rendered as a blank.
            sent.append(' ' if word == 'OOV' else word)
        tags = [
            dataset.id2label_dict[str(tag_id)]
            for tag_id in crf_decode[sent_index][1:stop]
        ]

        sent_out = []
        tags_out = []
        partial_word = ""
        for ind, tag in enumerate(tags):
            # for the first word
            if partial_word == "":
                partial_word = sent[ind]
                tags_out.append(tag.split('-')[0])
                continue

            # A "-B" tag, or the first "O" after a non-"O" tag, starts a
            # new word: flush the one collected so far.
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                sent_out.append(partial_word)
                tags_out.append(tag.split('-')[0])
                partial_word = sent[ind]
                continue

            partial_word += sent[ind]

        # append the last word, except for len(tags)=0
        if len(sent_out) < len(tags_out):
            sent_out.append(partial_word)

        batch_out.append([sent_out, tags_out])
    return batch_out


def init_checkpoint(exe, init_checkpoint_path, main_program):
    """Load the persistable variables found under ``init_checkpoint_path``.

    Args:
        exe: Executor handed to ``fluid.io.load_vars``.
        init_checkpoint_path: Directory holding one file per saved variable.
        main_program: Program whose variables are loaded.

    Raises:
        AssertionError: If ``init_checkpoint_path`` does not exist.
    """
    assert os.path.exists(
        init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path

    def existed_persitables(var):
        """Select persistable variables that have a file in the checkpoint."""
        if not fluid.io.is_persistable(var):
            return False
        return os.path.exists(os.path.join(init_checkpoint_path, var.name))

    fluid.io.load_vars(
        exe,
        init_checkpoint_path,
        main_program=main_program,
        predicate=existed_persitables)