提交 c4aa3e00 编写于 作者: W wangjiawei04

fix webservice unittest

上级 fd6d69fd
...@@ -23,7 +23,7 @@ args = benchmark_args() ...@@ -23,7 +23,7 @@ args = benchmark_args()
reader = ChineseBertReader({"max_seq_len": 128}) reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"] fetch = ["pooled_output"]
endpoint_list = ["127.0.0.1:9292"] endpoint_list = [':8861']
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.connect(endpoint_list) client.connect(endpoint_list)
...@@ -35,3 +35,4 @@ for line in sys.stdin: ...@@ -35,3 +35,4 @@ for line in sys.stdin:
#print(feed_dict) #print(feed_dict)
result = client.predict(feed=feed_dict, fetch=fetch) result = client.predict(feed=feed_dict, fetch=fetch)
print(result) print(result)
print(result)
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_server_gpu.web_service import WebService from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import ChineseBertReader from paddle_serving_app.reader import ChineseBertReader
import sys import sys
import os import os
import numpy as np
class BertService(WebService): class BertService(WebService):
...@@ -27,18 +28,20 @@ class BertService(WebService): ...@@ -27,18 +28,20 @@ class BertService(WebService):
}) })
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
feed_res = [ feed_res = []
self.reader.process(ins["words"].encode("utf-8")) for ins in feed for ins in feed:
] feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(1, len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch return feed_res, fetch
bert_service = BertService(name="bert") bert_service = BertService(name="bert")
bert_service.load() bert_service.load()
bert_service.load_model_config(sys.argv[1]) bert_service.load_model_config(sys.argv[1])
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
bert_service.set_gpus(gpu_ids)
bert_service.prepare_server( bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), device="gpu") workdir="workdir", port=int(sys.argv[2]), device="cpu")
bert_service.run_rpc_service() bert_service.run_rpc_service()
bert_service.run_web_service() bert_service.run_web_service()
...@@ -27,7 +27,7 @@ postprocess = RCNNPostprocess("label_list.txt", "output") ...@@ -27,7 +27,7 @@ postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9494']) client.connect([':8870'])
im = preprocess(sys.argv[3]) im = preprocess(sys.argv[3])
fetch_map = client.predict( fetch_map = client.predict(
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import sys import sys
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize
import numpy as np
if len(sys.argv) != 4: if len(sys.argv) != 4:
print("python resnet50_web_service.py model device port") print("python resnet50_web_service.py model device port")
sys.exit(-1) sys.exit(-1)
...@@ -47,7 +47,7 @@ class ImageService(WebService): ...@@ -47,7 +47,7 @@ class ImageService(WebService):
if "image" not in ins: if "image" not in ins:
raise ("feed data error!") raise ("feed data error!")
img = self.seq(ins["image"]) img = self.seq(ins["image"])
feed_batch.append({"image": img}) feed_batch.append({"image": img[np.newaxis, :]})
return feed_batch, fetch return feed_batch, fetch
def postprocess(self, feed=[], fetch=[], fetch_map={}): def postprocess(self, feed=[], fetch=[], fetch_map={}):
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from paddle_serving_server.web_service import WebService from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import IMDBDataset from paddle_serving_app.reader import IMDBDataset
import sys import sys
import numpy as np
class IMDBService(WebService): class IMDBService(WebService):
...@@ -26,15 +27,15 @@ class IMDBService(WebService): ...@@ -26,15 +27,15 @@ class IMDBService(WebService):
self.dataset.load_resource(args["dict_file_path"]) self.dataset.load_resource(args["dict_file_path"])
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
res_feed = [{ feed_batch = []
"words": self.dataset.get_words_only(ins["words"]) words_lod = [0]
} for ins in feed] for ins in feed:
words = self.dataset.get_words_only(ins["words"])
feed = { words = np.array(words).reshape(len(words), 1)
"words": np.array(word_ids).reshape(word_len, 1), words_lod.append(words_lod[-1] + len(words))
"words.lod": [0, word_len] feed_batch.append(words)
} feed = {"words": np.concatenate(feed_batch), "words.lod": words_lod}
return res_feed, fetch return feed, fetch
imdb_service = IMDBService(name="imdb") imdb_service = IMDBService(name="imdb")
......
...@@ -22,8 +22,8 @@ import io ...@@ -22,8 +22,8 @@ import io
import numpy as np import numpy as np
client = Client() client = Client()
client.load_client_config(sys.argv[1]) client.connect([':8868'])
client.connect(["127.0.0.1:9292"]) client.connect([':8868'])
reader = LACReader() reader = LACReader()
for line in sys.stdin: for line in sys.stdin:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from paddle_serving_server.web_service import WebService from paddle_serving_server.web_service import WebService
import sys import sys
from paddle_serving_app.reader import LACReader from paddle_serving_app.reader import LACReader
import numpy as np
class LACService(WebService): class LACService(WebService):
...@@ -23,13 +24,17 @@ class LACService(WebService): ...@@ -23,13 +24,17 @@ class LACService(WebService):
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
feed_batch = [] feed_batch = []
words_lod = [0]
for ins in feed: for ins in feed:
if "words" not in ins: if "words" not in ins:
raise ("feed data error!") raise ("feed data error!")
feed_data = self.reader.process(ins["words"]) feed_data = self.reader.process(ins["words"])
feed_batch.append({"words": feed_data}) words_lod.append(words_lod[-1] + len(feed_data))
feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
words = np.concatenate(feed_batch, axis=0)
fetch = ["crf_decode"] fetch = ["crf_decode"]
return feed_batch, fetch return {"words": words, "words.lod": words_lod}, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}): def postprocess(self, feed={}, fetch=[], fetch_map={}):
batch_ret = [] batch_ret = []
......
...@@ -18,7 +18,7 @@ from paddle_serving_client import Client ...@@ -18,7 +18,7 @@ from paddle_serving_client import Client
from paddle_serving_app.reader import LACReader, SentaReader from paddle_serving_app.reader import LACReader, SentaReader
import os import os
import sys import sys
import numpy as np
#senta_web_service.py #senta_web_service.py
from paddle_serving_server.web_service import WebService from paddle_serving_server.web_service import WebService
from paddle_serving_client import Client from paddle_serving_client import Client
...@@ -36,26 +36,42 @@ class SentaService(WebService): ...@@ -36,26 +36,42 @@ class SentaService(WebService):
#定义senta模型预测服务的预处理,调用顺序:lac reader->lac模型预测->预测结果后处理->senta reader #定义senta模型预测服务的预处理,调用顺序:lac reader->lac模型预测->预测结果后处理->senta reader
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
feed_data = [{
"words": self.lac_reader.process(x["words"])
} for x in feed]
lac_result = self.lac_client.predict(
feed=feed_data, fetch=["crf_decode"])
feed_batch = [] feed_batch = []
words_lod = [0]
for ins in feed:
if "words" not in ins:
raise ("feed data error!")
feed_data = self.lac_reader.process(ins["words"])
words_lod.append(words_lod[-1] + len(feed_data))
feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
words = np.concatenate(feed_batch, axis=0)
lac_result = self.lac_client.predict(
feed={"words": words,
"words.lod": words_lod},
fetch=["crf_decode"],
batch=True)
result_lod = lac_result["crf_decode.lod"] result_lod = lac_result["crf_decode.lod"]
feed_batch = []
words_lod = [0]
for i in range(len(feed)): for i in range(len(feed)):
segs = self.lac_reader.parse_result( segs = self.lac_reader.parse_result(
feed[i]["words"], feed[i]["words"],
lac_result["crf_decode"][result_lod[i]:result_lod[i + 1]]) lac_result["crf_decode"][result_lod[i]:result_lod[i + 1]])
feed_data = self.senta_reader.process(segs) feed_data = self.senta_reader.process(segs)
feed_batch.append({"words": feed_data}) feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
return feed_batch, fetch words_lod.append(words_lod[-1] + len(feed_data))
return {
"words": np.concatenate(feed_batch),
"words.lod": words_lod
}, fetch
senta_service = SentaService(name="senta") senta_service = SentaService(name="senta")
senta_service.load_model_config("senta_bilstm_model") senta_service.load_model_config("senta_bilstm_model")
senta_service.prepare_server(workdir="workdir") senta_service.prepare_server(workdir="workdir")
senta_service.init_lac_client( senta_service.init_lac_client(
lac_port=9300, lac_client_config="lac_model/serving_server_conf.prototxt") lac_port=9300,
lac_client_config="lac/lac_model/serving_server_conf.prototxt")
senta_service.run_rpc_service() senta_service.run_rpc_service()
senta_service.run_web_service() senta_service.run_web_service()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册