diff --git a/python/examples/bert/bert_web_service.py b/python/examples/bert/bert_web_service.py index 7cd34fb99e0ecebbf2f6bec47e9c9d163ac3a44c..64091817d84ec2bb7e608acf06a8f550ce5972de 100644 --- a/python/examples/bert/bert_web_service.py +++ b/python/examples/bert/bert_web_service.py @@ -29,20 +29,24 @@ class BertService(WebService): def preprocess(self, feed=[], fetch=[]): feed_res = [] - is_batch = False + is_batch = True for ins in feed: feed_dict = self.reader.process(ins["words"].encode("utf-8")) for key in feed_dict.keys(): feed_dict[key] = np.array(feed_dict[key]).reshape( - (len(feed_dict[key]), 1)) + (1, len(feed_dict[key]), 1)) feed_res.append(feed_dict) - return feed_res, fetch, is_batch - + feed_dict = {} + for key in feed_res[0].keys(): + feed_dict[key] = np.concatenate([x[key] for x in feed_res], axis=0) + print(key, feed_dict[key].shape) + return feed_dict, fetch, is_batch bert_service = BertService(name="bert") +bert_service.setup_profile(30) bert_service.load() bert_service.load_model_config(sys.argv[1]) bert_service.prepare_server( workdir="workdir", port=int(sys.argv[2]), device="cpu") -bert_service.run_rpc_service() +bert_service.run_debugger_service() bert_service.run_web_service() diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 3be818f0ed778bc8e7a1297cae6638fac88ac20c..bec001dfdbb3e453485cd9b7613f3e55c53b6b4a 100644 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -14,6 +14,7 @@ #!flask/bin/python # pylint: disable=doc-string-missing +from time import time as _time from flask import Flask, request, abort from multiprocessing import Pool, Process from paddle_serving_server import OpMaker, OpSeqMaker, Server @@ -23,7 +24,9 @@ import socket import numpy as np from paddle_serving_server import pipeline from paddle_serving_server.pipeline import Op - +import collections +from .profiler import TimeProfiler, PerformanceTracer +import os def port_is_available(port): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: @@ -44,6 +47,15 @@ class WebService(object): def get_pipeline_response(self, read_op): return None + def setup_profile(self, trace_interval=10, thread_num=1): + self.is_profile = True + if self.is_profile: + self._tracer = PerformanceTracer(True, 10 ,1) + self.trace_buffer = self._tracer.data_buffer() + self._profiler = TimeProfiler() + self._profiler.enable(True) + self.data_id = 0 + def prepare_pipeline_config(self, yaml_file): # build dag read_op = pipeline.RequestOp() @@ -135,20 +147,62 @@ class WebService(object): abort(400) if "fetch" not in request.json: abort(400) + start_call, end_call = None, None + if self.is_profile: + trace_que = collections.deque() + start_call = self._profiler.record("call_{}".format(self.data_id)) try: + start = int(round(_time() * 1000000)) feed, fetch, is_batch = self.preprocess(request.json["feed"], request.json["fetch"]) if isinstance(feed, dict) and "fetch" in feed: del feed["fetch"] if len(feed) == 0: raise ValueError("empty input") + end = int(round(_time() * 1000000)) + prep_time = end - start + start = int(round(_time() * 1000000)) fetch_map = self.client.predict( feed=feed, fetch=fetch, batch=is_batch) + end = int(round(_time() * 1000000)) + midp_time = end - start + start = int(round(_time() * 1000000)) result = self.postprocess( feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) result = {"result": result} + end = int(round(_time() * 1000000)) + postp_time = end - start + succ = 1 except ValueError as err: + succ = 0 result = {"result": str(err)} + if self.is_profile: + end_call = self._profiler.record("call_{}".format(self.data_id)) + self.data_id += 1 + if self.trace_buffer is not None: + self.trace_buffer.put({ + "name": "DAG", + "id": self.data_id, + "succ": succ, + "actions": { + "call_{}".format(self.data_id): end_call - start_call, + }, + }) + trace_que.append({ + "name": "demo", + "actions": { + "prep": prep_time, + "midp": midp_time, + "postp": postp_time + } + }) + while trace_que: + info = trace_que[0] + try: + self.trace_buffer.put_nowait(info) + trace_que.popleft() + except Queue.Full: + break return result def run_rpc_service(self): @@ -202,6 +256,8 @@ class WebService(object): "{}".format(self.model_config), use_gpu=False) def run_web_service(self): + if self.is_profile: + self._tracer.start() print("This API will be deprecated later. Please do not use it") self.app_instance.run(host="0.0.0.0", port=self.port, threaded=True)