提交 425102fb 编写于 作者: W wangjiawei04

add web service benchmark

上级 b2c61d1c
...@@ -29,20 +29,24 @@ class BertService(WebService): ...@@ -29,20 +29,24 @@ class BertService(WebService):
def preprocess(self, feed=[], fetch=[]):
    """Convert the request instances into a single batched feed dict.

    Each element of ``feed`` is expected to be a mapping with a ``"words"``
    entry. The text is UTF-8 encoded and handed to ``self.reader.process``,
    whose per-key outputs are reshaped to ``(1, length, 1)`` and then
    concatenated along axis 0 so that all instances travel in one batch.

    Args:
        feed: list of request instances (each with a ``"words"`` string).
        fetch: names of the outputs to fetch; returned unchanged.

    Returns:
        A ``(feed_dict, fetch, is_batch)`` tuple. ``feed_dict`` maps every
        key produced by the reader to an array of shape
        ``(len(feed), length, 1)``. ``is_batch`` is always ``True``.
        When ``feed`` is empty an empty dict is returned instead of
        failing on ``feed_res[0]``, so the caller can reject the request
        through its own empty-input check.
    """
    is_batch = True
    feed_res = []
    for ins in feed:
        processed = self.reader.process(ins["words"].encode("utf-8"))
        # Add a leading batch axis of size 1 and a trailing feature axis.
        feed_res.append({
            key: np.array(value).reshape((1, len(value), 1))
            for key, value in processed.items()
        })

    if not feed_res:
        # Nothing to batch: indexing feed_res[0] below would raise IndexError.
        return {}, fetch, is_batch

    feed_dict = {}
    for key in feed_res[0].keys():
        # NOTE(review): concatenation requires every instance to yield the
        # same length per key -- presumably the reader pads; confirm.
        feed_dict[key] = np.concatenate([x[key] for x in feed_res], axis=0)
        print(key, feed_dict[key].shape)
    return feed_dict, fetch, is_batch
bert_service = BertService(name="bert") bert_service = BertService(name="bert")
bert_service.setup_profile(30)
bert_service.load() bert_service.load()
bert_service.load_model_config(sys.argv[1]) bert_service.load_model_config(sys.argv[1])
bert_service.prepare_server( bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), device="cpu") workdir="workdir", port=int(sys.argv[2]), device="cpu")
bert_service.run_rpc_service() bert_service.run_debugger_service()
bert_service.run_web_service() bert_service.run_web_service()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#!flask/bin/python #!flask/bin/python
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from time import time as _time
from flask import Flask, request, abort from flask import Flask, request, abort
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_server import OpMaker, OpSeqMaker, Server
...@@ -23,7 +24,9 @@ import socket ...@@ -23,7 +24,9 @@ import socket
import numpy as np import numpy as np
from paddle_serving_server import pipeline from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op from paddle_serving_server.pipeline import Op
import collections
from .profiler import TimeProfiler, PerformanceTracer
import os
def port_is_available(port): def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
...@@ -44,6 +47,15 @@ class WebService(object): ...@@ -44,6 +47,15 @@ class WebService(object):
def get_pipeline_response(self, read_op): def get_pipeline_response(self, read_op):
return None return None
def setup_profile(self, trace_interval=10, thread_num=1):
    """Enable per-request profiling for this web service.

    Creates a ``PerformanceTracer`` and a ``TimeProfiler``, keeps the
    tracer's data buffer for later trace records, and resets the request
    counter.

    Args:
        trace_interval: forwarded as the second positional argument of
            ``PerformanceTracer`` (presumably the reporting interval in
            seconds -- confirm against the profiler module).
        thread_num: forwarded as the third positional argument of
            ``PerformanceTracer`` (presumably the worker count -- confirm
            against the profiler module).
    """
    # Forward the caller's settings; they were previously ignored in favour
    # of the hard-coded values 10 and 1.
    self._tracer = PerformanceTracer(True, trace_interval, thread_num)
    self.trace_buffer = self._tracer.data_buffer()
    self._profiler = TimeProfiler()
    self._profiler.enable(True)
    self.data_id = 0
    # Set the flag last so it is only True once every profiling attribute
    # above has been created successfully.
    self.is_profile = True
def prepare_pipeline_config(self, yaml_file): def prepare_pipeline_config(self, yaml_file):
# build dag # build dag
read_op = pipeline.RequestOp() read_op = pipeline.RequestOp()
...@@ -135,20 +147,62 @@ class WebService(object): ...@@ -135,20 +147,62 @@ class WebService(object):
abort(400) abort(400)
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
start_call, end_call = None, None
if self.is_profile:
trace_que = collections.deque()
start_call = self._profiler.record("call_{}".format(self.data_id))
try: try:
start = int(round(_time() * 1000000))
feed, fetch, is_batch = self.preprocess(request.json["feed"], feed, fetch, is_batch = self.preprocess(request.json["feed"],
request.json["fetch"]) request.json["fetch"])
if isinstance(feed, dict) and "fetch" in feed: if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"] del feed["fetch"]
if len(feed) == 0: if len(feed) == 0:
raise ValueError("empty input") raise ValueError("empty input")
end = int(round(_time() * 1000000))
prep_time = end - start
start = int(round(_time() * 1000000))
fetch_map = self.client.predict( fetch_map = self.client.predict(
feed=feed, fetch=fetch, batch=is_batch) feed=feed, fetch=fetch, batch=is_batch)
end = int(round(_time() * 1000000))
midp_time = end - start
start = int(round(_time() * 1000000))
result = self.postprocess( result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result} result = {"result": result}
end = int(round(_time() * 1000000))
postp_time = end - start
succ = 1
except ValueError as err: except ValueError as err:
succ = 0
result = {"result": str(err)} result = {"result": str(err)}
if self.is_profile:
end_call = self._profiler.record("call_{}".format(self.data_id))
self.data_id += 1
if self.trace_buffer is not None:
self.trace_buffer.put({
"name": "DAG",
"id": self.data_id,
"succ": succ,
"actions": {
"call_{}".format(self.data_id): end_call - start_call,
},
})
trace_que.append({
"name": "demo",
"actions": {
"prep": prep_time,
"midp": midp_time,
"postp": postp_time
}
})
while trace_que:
info = trace_que[0]
try:
self.trace_buffer.put_nowait(info)
trace_que.popleft()
except Queue.Full:
break
return result return result
def run_rpc_service(self): def run_rpc_service(self):
...@@ -202,6 +256,8 @@ class WebService(object): ...@@ -202,6 +256,8 @@ class WebService(object):
"{}".format(self.model_config), use_gpu=False) "{}".format(self.model_config), use_gpu=False)
def run_web_service(self):
    """Start the tracer (when profiling is enabled) and run the Flask app.

    The app is served on all interfaces at ``self.port`` in threaded mode.
    """
    # The is_profile flag is only created by setup_profile(); default to
    # False so services that never enabled profiling do not fail here.
    if getattr(self, "is_profile", False):
        self._tracer.start()
    print("This API will be deprecated later. Please do not use it")
    self.app_instance.run(host="0.0.0.0", port=self.port, threaded=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册