diff --git a/python/examples/imagenet/image_classification_service.py b/python/examples/imagenet/image_classification_service.py
index ee3ae6dd1c64bda154bbadabe8d1e91da734fb5a..a041d9d9447184512f010251c85c919985f02c7c 100644
--- a/python/examples/imagenet/image_classification_service.py
+++ b/python/examples/imagenet/image_classification_service.py
@@ -17,7 +17,7 @@ import sys
 import cv2
 import base64
 import numpy as np
-from image_reader import ImageReader
+from paddle_serving_app import ImageReader
 
 
 class ImageService(WebService):
diff --git a/python/examples/imagenet/image_classification_service_gpu.py b/python/examples/imagenet/image_classification_service_gpu.py
index d8ba4ed8cda9f600fb6d33441b90accdf5ecc532..ccb2adac26bb0fd16165dd2dfb3c9507b1bf193f 100644
--- a/python/examples/imagenet/image_classification_service_gpu.py
+++ b/python/examples/imagenet/image_classification_service_gpu.py
@@ -16,7 +16,7 @@ import sys
 import cv2
 import base64
 import numpy as np
-from image_reader import ImageReader
+from paddle_serving_app import ImageReader
 from paddle_serving_server_gpu.web_service import WebService
 
 
diff --git a/python/examples/lac/lac_reader.py b/python/examples/lac/lac_reader.py
index 0c44177c2d56e5de94a18ce3514d0439a33361c5..3895277dbbf98a9b4ff2d6592d82fe02fe9a4c12 100644
--- a/python/examples/lac/lac_reader.py
+++ b/python/examples/lac/lac_reader.py
@@ -101,7 +101,7 @@ class LACReader(object):
         return word_ids
 
     def parse_result(self, words, crf_decode):
-        tags = [self.id2label_dict[str(x)] for x in crf_decode]
+        tags = [self.id2label_dict[str(x[0])] for x in crf_decode]
 
         sent_out = []
         tags_out = []
diff --git a/python/examples/senta/README.md b/python/examples/senta/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..307f4829407b2fb03b64035c94ac00c3d55c27f5
--- /dev/null
+++ b/python/examples/senta/README.md
@@ -0,0 +1,16 @@
+# Chinese sentence sentiment classification
+
+## Get model files and sample data
+```
+sh get_data.sh
+```
+## Start http service
+```
+python senta_web_service.py senta_bilstm_model/ workdir 9292
+```
+The Chinese sentiment classification task first requires Chinese word segmentation, which is performed by the [LAC task](../lac). Set the LAC model path with `lac_model_path` and the dictionary path with `lac_dict_path`.
+In this demo, the LAC task runs as the preprocessing step of the sentiment classification HTTP prediction service. The LAC prediction service is deployed on CPU and the sentiment classification service on GPU; this can be changed to match your environment.
+## Client prediction
+```
+curl -H "Content-Type:application/json" -X POST -d '{"words": "天气不错 | 0", "fetch":["sentence_feature"]}' http://127.0.0.1:9292/senta/prediction
+```
diff --git a/python/examples/senta/README_CN.md b/python/examples/senta/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd7cd8564242db8c30525978a4b806c866cd0d0f
--- /dev/null
+++ b/python/examples/senta/README_CN.md
@@ -0,0 +1,17 @@
+# 中文语句情感分类
+
+## 获取模型文件和样例数据
+```
+sh get_data.sh
+```
+## 启动HTTP服务
+```
+python senta_web_service.py senta_bilstm_model/ workdir 9292
+```
+中文情感分类任务中需要先通过[LAC任务](../lac)进行中文分词,在脚本中通过```lac_model_path```参数配置LAC任务的模型文件路径,```lac_dict_path```参数配置LAC任务词典路径。
+示例中将LAC任务放在情感分类任务的HTTP预测服务的预处理部分,LAC预测服务部署在CPU上,情感分类任务部署在GPU上,可以根据实际情况进行更改。
+
+## 客户端预测
+```
+curl -H "Content-Type:application/json" -X POST -d '{"words": "天气不错 | 0", "fetch":["sentence_feature"]}' http://127.0.0.1:9292/senta/prediction
+```
diff --git a/python/examples/senta/get_data.sh b/python/examples/senta/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..75bc8539721268aa212d5d6d726e1e9d600188b1
--- /dev/null
+++ b/python/examples/senta/get_data.sh
@@ -0,0 +1,7 @@
+#wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/senta_bilstm.tar.gz --no-check-certificate
+#tar -xzvf senta_bilstm.tar.gz
+wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/LexicalAnalysis/lac_model.tar.gz --no-check-certificate
+tar -xzvf lac_model.tar.gz
+wget https://paddle-serving.bj.bcebos.com/reader/lac/lac_dict.tar.gz --no-check-certificate
+tar -xzvf lac_dict.tar.gz
+wget https://paddle-serving.bj.bcebos.com/reader/senta/vocab.txt --no-check-certificate
diff --git a/python/examples/senta/senta_web_service.py b/python/examples/senta/senta_web_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..7077b84b7a97cac6387b8cb2e88e31c0b0e5d70e
--- /dev/null
+++ b/python/examples/senta/senta_web_service.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving_server_gpu.web_service import WebService
+from paddle_serving_client import Client
+from paddle_serving_app import LACReader, SentaReader
+import numpy as np
+import os
+import io
+import sys
+import subprocess
+from multiprocessing import Process, Queue
+
+
+class SentaService(WebService):
+    def set_config(
+            self,
+            lac_model_path,
+            lac_dict_path,
+            senta_dict_path, ):
+        self.lac_model_path = lac_model_path
+        self.lac_client_config_path = lac_model_path + "/serving_server_conf.prototxt"
+        self.lac_dict_path = lac_dict_path
+        self.senta_dict_path = senta_dict_path
+        self.show = False
+
+    def show_detail(self, show=False):
+        self.show = show
+
+    def start_lac_service(self):
+        os.chdir('./lac_serving')
+        self.lac_port = self.port + 100
+        r = os.popen(
+            "python -m paddle_serving_server.serve --model {} --port {} &".
+ format("../" + self.lac_model_path, self.lac_port)) + os.chdir('..') + + def init_lac_service(self): + ps = Process(target=self.start_lac_service()) + ps.start() + #self.init_lac_client() + + def lac_predict(self, feed_data): + self.init_lac_client() + lac_result = self.lac_client.predict( + feed={"words": feed_data}, fetch=["crf_decode"]) + self.lac_client.release() + return lac_result + + def init_lac_client(self): + self.lac_client = Client() + self.lac_client.load_client_config(self.lac_client_config_path) + self.lac_client.connect(["127.0.0.1:{}".format(self.lac_port)]) + + def init_lac_reader(self): + self.lac_reader = LACReader(self.lac_dict_path) + + def init_senta_reader(self): + self.senta_reader = SentaReader(vocab_path=self.senta_dict_path) + + def preprocess(self, feed={}, fetch={}): + if "words" not in feed: + raise ("feed data error!") + feed_data = self.lac_reader.process(feed["words"]) + fetch = ["crf_decode"] + if self.show: + print("---- lac reader ----") + print(feed_data) + lac_result = self.lac_predict(feed_data) + if self.show: + print("---- lac out ----") + print(lac_result) + segs = self.lac_reader.parse_result(feed["words"], + lac_result["crf_decode"]) + if self.show: + print("---- lac parse ----") + print(segs) + feed_data = self.senta_reader.process(segs) + if self.show: + print("---- senta reader ----") + print("feed_data", feed_data) + fetch = ["class_probs"] + return {"words": feed_data}, fetch + + +senta_service = SentaService(name="senta") +#senta_service.show_detail(True) +senta_service.set_config( + lac_model_path="./lac_model", + lac_dict_path="./lac_dict", + senta_dict_path="./vocab.txt") +senta_service.load_model_config(sys.argv[1]) +senta_service.prepare_server( + workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu") +senta_service.init_lac_reader() +senta_service.init_senta_reader() +senta_service.init_lac_service() +senta_service.run_server() +#senta_service.run_flask() + +from flask import Flask, request + +app_instance = Flask(__name__) + + +@app_instance.before_first_request +def init(): + global uci_service + senta_service._launch_web_service() + + +service_name = "/" + senta_service.name + "/prediction" + + +@app_instance.route(service_name, methods=["POST"]) +def run(): + print("---- run ----") + print(request.json) + return senta_service.get_prediction(request) + + +if __name__ == "__main__": + app_instance.run(host="0.0.0.0", + port=senta_service.port, + threaded=False, + processes=4) diff --git a/python/paddle_serving_app/__init__.py b/python/paddle_serving_app/__init__.py index 968e5582cc286455d5200e154033087b71ac86de..0b3f60e28d3094ebb695c2caa0cc9f333030cf3f 100644 --- a/python/paddle_serving_app/__init__.py +++ b/python/paddle_serving_app/__init__.py @@ -12,3 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from .reader.chinese_bert_reader import ChineseBertReader +from .reader.image_reader import ImageReader +from .reader.lac_reader import LACReader +from .reader.senta_reader import SentaReader diff --git a/python/paddle_serving_app/models/model_list.py b/python/paddle_serving_app/models/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb109c1fc5f36a5d39b31233dff648ea0f29c44 --- /dev/null +++ b/python/paddle_serving_app/models/model_list.py @@ -0,0 +1,98 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+from collections import OrderedDict
+
+
+class ServingModels(object):
+    def __init__(self):
+        self.model_dict = OrderedDict()
+        #senta
+        for key in [
+                "senta_bilstm", "senta_bow", "senta_cnn", "senta_gru",
+                "senta_lstm"
+        ]:
+            self.model_dict[
+                key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/" + key + ".tar.gz"
+        #image classification
+        for key in [
+                "alexnet_imagenet",
+                "darknet53-imagenet",
+                "densenet121_imagenet",
+                "densenet161_imagenet",
+                "densenet169_imagenet",
+                "densenet201_imagenet",
+                "densenet264_imagenet",
+                "dpn107_imagenet",
+                "dpn131_imagenet",
+                "dpn68_imagenet",
+                "dpn92_imagenet",
+                "dpn98_imagenet",
+                "efficientnetb0_imagenet",
+                "efficientnetb1_imagenet",
+                "efficientnetb2_imagenet",
+                "efficientnetb3_imagenet",
+                "efficientnetb4_imagenet",
+                "efficientnetb5_imagenet",
+                "efficientnetb6_imagenet",
+                "googlenet_imagenet",
+                "inception_v4_imagenet",
+                "inception_v2_imagenet",
+                "nasnet_imagenet",
+                "pnasnet_imagenet",
+                "resnet_v2_101_imagenet",
+                "resnet_v2_151_imagenet",
+                "resnet_v2_18_imagenet",
+                "resnet_v2_34_imagenet",
+                "resnet_v2_50_imagenet",
+                "resnext101_32x16d_wsl",
+                "resnext101_32x32d_wsl",
+                "resnext101_32x48d_wsl",
+                "resnext101_32x8d_wsl",
+                "resnext101_32x4d_imagenet",
+                "resnext101_64x4d_imagenet",
+                "resnext101_vd_32x4d_imagenet",
+                "resnext101_vd_64x4d_imagenet",
+                "resnext152_64x4d_imagenet",
+                "resnext152_vd_64x4d_imagenet",
+                "resnext50_64x4d_imagenet",
+                "resnext50_vd_32x4d_imagenet",
+                "resnext50_vd_64x4d_imagenet",
+                "se_resnext101_32x4d_imagenet",
+                "se_resnext50_32x4d_imagenet",
+                "shufflenet_v2_imagenet",
+                "vgg11_imagenet",
+                "vgg13_imagenet",
+                "vgg16_imagenet",
+                "vgg19_imagenet",
+                "xception65_imagenet",
+                "xception71_imagenet",
+        ]:
+            self.model_dict[
+                key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/" + key + ".tar.gz"
+
+    def get_model_list(self):
+        return (self.model_dict.keys())
+
+    def download(self, model_name):
+        if model_name in self.model_dict:
+            url = self.model_dict[model_name]
+            r = os.system('wget ' + url + ' --no-check-certificate')
+
+
+if __name__ == "__main__":
+    models = ServingModels()
+    print(models.get_model_list())
diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2647eb6fdf3ca0f1682ca794051b9d0dd95a9a07
--- /dev/null
+++ b/python/paddle_serving_app/reader/image_reader.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+class ImageReader():
+    def __init__(self,
+                 image_shape=[3, 224, 224],
+                 image_mean=[0.485, 0.456, 0.406],
+                 image_std=[0.229, 0.224, 0.225],
+                 resize_short_size=256,
+                 interpolation=None,
+                 crop_center=True):
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.image_shape = image_shape
+        self.resize_short_size = resize_short_size
+        self.interpolation = interpolation
+        self.crop_center = crop_center
+
+    def resize_short(self, img, target_size, interpolation=None):
+        """resize image
+
+        Args:
+            img: image data
+            target_size: resize short target size
+            interpolation: interpolation mode
+
+        Returns:
+            resized image data
+        """
+        percent = float(target_size) / min(img.shape[0], img.shape[1])
+        resized_width = int(round(img.shape[1] * percent))
+        resized_height = int(round(img.shape[0] * percent))
+        if interpolation:
+            resized = cv2.resize(
+                img, (resized_width, resized_height),
+                interpolation=interpolation)
+        else:
+            resized = cv2.resize(img, (resized_width, resized_height))
+        return resized
+
+    def crop_image(self, img, target_size, center):
+        """crop image
+
+        Args:
+            img: images data
+            target_size: crop target size
+            center: crop mode
+
+        Returns:
+            img: cropped image data
+        """
+        height, width = img.shape[:2]
+        size = target_size
+        if center == True:
+            w_start = (width - size) // 2
+            h_start = (height - size) // 2
+        else:
+            w_start = np.random.randint(0, width - size + 1)
+            h_start = np.random.randint(0, height - size + 1)
+        w_end = w_start + size
+        h_end = h_start + size
+        img = img[h_start:h_end, w_start:w_end, :]
+        return img
+
+    def process_image(self, sample):
+        """ process_image """
+        mean = self.image_mean
+        std = self.image_std
+        crop_size = self.image_shape[1]
+
+        data = np.fromstring(sample, np.uint8)
+        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
+
+        if img is None:
+            print("img is None, pass it.")
+            return None
+
+        if crop_size > 0:
+            target_size = self.resize_short_size
+            img = self.resize_short(
+                img, target_size, interpolation=self.interpolation)
+            img = self.crop_image(
+                img, target_size=crop_size, center=self.crop_center)
+
+        img = img[:, :, ::-1]
+
+        img = img.astype('float32').transpose((2, 0, 1)) / 255
+        img_mean = np.array(mean).reshape((3, 1, 1))
+        img_std = np.array(std).reshape((3, 1, 1))
+        img -= img_mean
+        img /= img_std
+        return img
diff --git a/python/paddle_serving_app/reader/lac_reader.py b/python/paddle_serving_app/reader/lac_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ed0bbe44460993649675f627310e1a7b53c344
--- /dev/null
+++ b/python/paddle_serving_app/reader/lac_reader.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving_client import Client
+import sys
+reload(sys)
+sys.setdefaultencoding('utf-8')
+import os
+import io
+
+
+def load_kv_dict(dict_path,
+                 reverse=False,
+                 delimiter="\t",
+                 key_func=None,
+                 value_func=None):
+    result_dict = {}
+    for line in io.open(dict_path, "r", encoding="utf8"):
+        terms = line.strip("\n").split(delimiter)
+        if len(terms) != 2:
+            continue
+        if reverse:
+            value, key = terms
+        else:
+            key, value = terms
+        if key in result_dict:
+            raise KeyError("key duplicated with [%s]" % (key))
+        if key_func:
+            key = key_func(key)
+        if value_func:
+            value = value_func(value)
+        result_dict[key] = value
+    return result_dict
+
+
+class LACReader(object):
+    """data reader"""
+
+    def __init__(self, dict_folder):
+        # read dict
+        #basepath = os.path.abspath(__file__)
+        #folder = os.path.dirname(basepath)
+        word_dict_path = os.path.join(dict_folder, "word.dic")
+        label_dict_path = os.path.join(dict_folder, "tag.dic")
+        replace_dict_path = os.path.join(dict_folder, "q2b.dic")
+        self.word2id_dict = load_kv_dict(
+            word_dict_path, reverse=True, value_func=int)
+        self.id2word_dict = load_kv_dict(word_dict_path)
+        self.label2id_dict = load_kv_dict(
+            label_dict_path, reverse=True, value_func=int)
+        self.id2label_dict = load_kv_dict(label_dict_path)
+        self.word_replace_dict = load_kv_dict(replace_dict_path)
+
+    @property
+    def vocab_size(self):
+        """vocabulary size"""
+        return max(self.word2id_dict.values()) + 1
+
+    @property
+    def num_labels(self):
+        """num_labels"""
+        return max(self.label2id_dict.values()) + 1
+
+    def word_to_ids(self, words):
+        """convert word to word index"""
+        word_ids = []
+        idx = 0
+        try:
+            words = unicode(words, 'utf-8')
+        except:
+            pass
+        for word in words:
+            word = self.word_replace_dict.get(word, word)
+            if word not in self.word2id_dict:
+                word = "OOV"
+            word_id = self.word2id_dict[word]
+            word_ids.append(word_id)
+        return word_ids
+
+    def label_to_ids(self, labels):
+        """convert label to label index"""
+        label_ids = []
+        for label in labels:
+            if label not in self.label2id_dict:
+                label = "O"
+            label_id = self.label2id_dict[label]
+            label_ids.append(label_id)
+        return label_ids
+
+    def process(self, sent):
+        words = sent.strip()
+        word_ids = self.word_to_ids(words)
+        return word_ids
+
+    def parse_result(self, words, crf_decode):
+        tags = [self.id2label_dict[str(x[0])] for x in crf_decode]
+
+        sent_out = []
+        tags_out = []
+        partial_word = ""
+        for ind, tag in enumerate(tags):
+            if partial_word == "":
+                partial_word = words[ind]
+                tags_out.append(tag.split('-')[0])
+                continue
+            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
+                sent_out.append(partial_word)
+                tags_out.append(tag.split('-')[0])
+                partial_word = words[ind]
+                continue
+            partial_word += words[ind]
+
+        if len(sent_out) < len(tags_out):
+            sent_out.append(partial_word)
+
+        return sent_out
diff --git a/python/paddle_serving_app/reader/senta_reader.py b/python/paddle_serving_app/reader/senta_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e608b822fbb66f11288ea0080c8e264d8e5c34a
--- /dev/null
+++ b/python/paddle_serving_app/reader/senta_reader.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import io
+
+
+class SentaReader():
+    def __init__(self, vocab_path, max_seq_len=20):
+        self.max_seq_len = max_seq_len
+        self.word_dict = self.load_vocab(vocab_path)
+
+    def load_vocab(self, vocab_path):
+        """
+        load the given vocabulary
+        """
+        vocab = {}
+        with io.open(vocab_path, 'r', encoding='utf8') as f:
+            for line in f:
+                if line.strip() not in vocab:
+                    data = line.strip().split("\t")
+                    if len(data) < 2:
+                        word = ""
+                        wid = data[0]
+                    else:
+                        word = data[0]
+                        wid = data[1]
+                    vocab[word] = int(wid)
+        vocab[""] = len(vocab)
+        return vocab
+
+    def process(self, cols):
+        unk_id = len(self.word_dict)
+        pad_id = 0
+        wids = [
+            self.word_dict[x] if x in self.word_dict else unk_id for x in cols
+        ]
+        '''
+        seq_len = len(wids)
+        if seq_len < self.max_seq_len:
+            for i in range(self.max_seq_len - seq_len):
+                wids.append(pad_id)
+        else:
+            wids = wids[:self.max_seq_len]
+            seq_len = self.max_seq_len
+        '''
+        return wids
diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py
index 801e2acb323ba64f246609684bc33194891a7250..07a9ab6630fa5a907423236d37dd66951b012f72 100644
--- a/python/paddle_serving_client/__init__.py
+++ b/python/paddle_serving_client/__init__.py
@@ -313,8 +313,8 @@ class Client(object):
                     result_map[name] = np.array(result_map[name], dtype='int64')
                     result_map[name].shape = shape
                     if name in self.lod_tensor_set:
-                        result_map["{}.lod".format(
-                            name)] = result_batch.get_lod(mi, name)
+                        result_map["{}.lod".format(name)] = np.array(
+                            result_batch.get_lod(mi, name))
                 elif self.fetch_names_to_type_[name] == float_type:
                     result_map[name] = result_batch.get_float_by_name(mi, name)
                     shape = result_batch.get_shape(mi, name)
@@ -322,10 +322,9 @@
                         result_map[name], dtype='float32')
                     result_map[name].shape = shape
                     if name in self.lod_tensor_set:
-                        result_map["{}.lod".format(
-                            name)] = result_batch.get_lod(mi, name)
+                        result_map["{}.lod".format(name)] = np.array(
+                            result_batch.get_lod(mi, name))
                 multi_result_map.append(result_map)
-            ret = None
 
             if len(model_engine_names) == 1:
                 # If only one model result is returned, the format of ret is result_map
diff --git a/python/paddle_serving_server/serve.py b/python/paddle_serving_server/serve.py
index 396ed17c2074923003515b26352e87fc8309a252..395177a8c77e5c608c2e0364b1d43ac534172d66 100644
--- a/python/paddle_serving_server/serve.py
+++ b/python/paddle_serving_server/serve.py
@@ -79,6 +79,7 @@ def start_standard_model():  # pylint: disable=doc-string-missing
     server.set_num_threads(thread_num)
     server.set_memory_optimize(mem_optim)
     server.set_max_body_size(max_body_size)
+    server.set_port(port)
 
     server.load_model_config(model)
     server.prepare_server(workdir=workdir, port=port, device=device)