提交 c4aa3e00 编写于 作者: W wangjiawei04

fix webservice unittest

上级 fd6d69fd
......@@ -23,7 +23,7 @@ args = benchmark_args()
reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"]
endpoint_list = ["127.0.0.1:9292"]
endpoint_list = [':8861']
client = Client()
client.load_client_config(args.model)
client.connect(endpoint_list)
......@@ -35,3 +35,4 @@ for line in sys.stdin:
#print(feed_dict)
result = client.predict(feed=feed_dict, fetch=fetch)
print(result)
print(result)
......@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server_gpu.web_service import WebService
from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import ChineseBertReader
import sys
import os
import numpy as np
class BertService(WebService):
......@@ -27,18 +28,20 @@ class BertService(WebService):
})
def preprocess(self, feed=[], fetch=[]):
feed_res = [
self.reader.process(ins["words"].encode("utf-8")) for ins in feed
]
feed_res = []
for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(1, len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch
bert_service = BertService(name="bert")
bert_service.load()
bert_service.load_model_config(sys.argv[1])
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
bert_service.set_gpus(gpu_ids)
bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), device="gpu")
workdir="workdir", port=int(sys.argv[2]), device="cpu")
bert_service.run_rpc_service()
bert_service.run_web_service()
......@@ -27,7 +27,7 @@ postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client()
client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9494'])
client.connect([':8870'])
im = preprocess(sys.argv[3])
fetch_map = client.predict(
......
......@@ -14,7 +14,7 @@
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize
import numpy as np
if len(sys.argv) != 4:
print("python resnet50_web_service.py model device port")
sys.exit(-1)
......@@ -47,7 +47,7 @@ class ImageService(WebService):
if "image" not in ins:
raise ("feed data error!")
img = self.seq(ins["image"])
feed_batch.append({"image": img})
feed_batch.append({"image": img[np.newaxis, :]})
return feed_batch, fetch
def postprocess(self, feed=[], fetch=[], fetch_map={}):
......
......@@ -16,6 +16,7 @@
from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import IMDBDataset
import sys
import numpy as np
class IMDBService(WebService):
......@@ -26,15 +27,15 @@ class IMDBService(WebService):
self.dataset.load_resource(args["dict_file_path"])
def preprocess(self, feed={}, fetch=[]):
res_feed = [{
"words": self.dataset.get_words_only(ins["words"])
} for ins in feed]
feed = {
"words": np.array(word_ids).reshape(word_len, 1),
"words.lod": [0, word_len]
}
return res_feed, fetch
feed_batch = []
words_lod = [0]
for ins in feed:
words = self.dataset.get_words_only(ins["words"])
words = np.array(words).reshape(len(words), 1)
words_lod.append(words_lod[-1] + len(words))
feed_batch.append(words)
feed = {"words": np.concatenate(feed_batch), "words.lod": words_lod}
return feed, fetch
imdb_service = IMDBService(name="imdb")
......
......@@ -22,8 +22,8 @@ import io
import numpy as np
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
client.connect([':8868'])
client.connect([':8868'])
reader = LACReader()
for line in sys.stdin:
......
......@@ -15,6 +15,7 @@
from paddle_serving_server.web_service import WebService
import sys
from paddle_serving_app.reader import LACReader
import numpy as np
class LACService(WebService):
......@@ -23,13 +24,17 @@ class LACService(WebService):
def preprocess(self, feed={}, fetch=[]):
feed_batch = []
words_lod = [0]
for ins in feed:
if "words" not in ins:
raise ("feed data error!")
feed_data = self.reader.process(ins["words"])
feed_batch.append({"words": feed_data})
words_lod.append(words_lod[-1] + len(feed_data))
feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
words = np.concatenate(feed_batch, axis=0)
fetch = ["crf_decode"]
return feed_batch, fetch
return {"words": words, "words.lod": words_lod}, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}):
batch_ret = []
......
......@@ -18,7 +18,7 @@ from paddle_serving_client import Client
from paddle_serving_app.reader import LACReader, SentaReader
import os
import sys
import numpy as np
#senta_web_service.py
from paddle_serving_server.web_service import WebService
from paddle_serving_client import Client
......@@ -36,26 +36,42 @@ class SentaService(WebService):
#定义senta模型预测服务的预处理,调用顺序:lac reader->lac模型预测->预测结果后处理->senta reader
def preprocess(self, feed=[], fetch=[]):
feed_data = [{
"words": self.lac_reader.process(x["words"])
} for x in feed]
lac_result = self.lac_client.predict(
feed=feed_data, fetch=["crf_decode"])
feed_batch = []
words_lod = [0]
for ins in feed:
if "words" not in ins:
raise ("feed data error!")
feed_data = self.lac_reader.process(ins["words"])
words_lod.append(words_lod[-1] + len(feed_data))
feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
words = np.concatenate(feed_batch, axis=0)
lac_result = self.lac_client.predict(
feed={"words": words,
"words.lod": words_lod},
fetch=["crf_decode"],
batch=True)
result_lod = lac_result["crf_decode.lod"]
feed_batch = []
words_lod = [0]
for i in range(len(feed)):
segs = self.lac_reader.parse_result(
feed[i]["words"],
lac_result["crf_decode"][result_lod[i]:result_lod[i + 1]])
feed_data = self.senta_reader.process(segs)
feed_batch.append({"words": feed_data})
return feed_batch, fetch
feed_batch.append(np.array(feed_data).reshape(len(feed_data), 1))
words_lod.append(words_lod[-1] + len(feed_data))
return {
"words": np.concatenate(feed_batch),
"words.lod": words_lod
}, fetch
senta_service = SentaService(name="senta")
senta_service.load_model_config("senta_bilstm_model")
senta_service.prepare_server(workdir="workdir")
senta_service.init_lac_client(
lac_port=9300, lac_client_config="lac_model/serving_server_conf.prototxt")
lac_port=9300,
lac_client_config="lac/lac_model/serving_server_conf.prototxt")
senta_service.run_rpc_service()
senta_service.run_web_service()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册