提交 13534d39 编写于 作者: G guru4elephant

refine web_service.py

上级 17067a47
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from flask import Flask, request, abort from flask import Flask, request, abort
from multiprocessing import Pool, Process from multiprocessing import Pool, Process, Queue
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
import paddle_serving_server_gpu as serving import paddle_serving_server_gpu as serving
from paddle_serving_client import Client from paddle_serving_client import Client
...@@ -29,27 +29,13 @@ class WebService(object): ...@@ -29,27 +29,13 @@ class WebService(object):
self.name = name self.name = name
self.gpus = [] self.gpus = []
self.rpc_service_list = [] self.rpc_service_list = []
self.input_queues = []
def producers(self, input_queue, output_queue, endpoint):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect([endpoint])
while True:
request_json = input_queue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
output_queue.put(fetch_map)
def load_model_config(self, model_config): def load_model_config(self, model_config):
self.model_config = model_config self.model_config = model_config
def set_gpus(self, gpus): def set_gpus(self, gpus):
self.gpus = gpus self.gpus = [int(x) for x in gpus.split(",")]
def default_rpc_service(self, def default_rpc_service(self,
workdir="conf", workdir="conf",
...@@ -93,21 +79,6 @@ class WebService(object): ...@@ -93,21 +79,6 @@ class WebService(object):
self.default_rpc_service( self.default_rpc_service(
self.workdir, self.port + 1, -1, thread_num=10)) self.workdir, self.port + 1, -1, thread_num=10))
else: else:
self.producer_pros = []
for i, gpuid in enumerate(self.gpus):
producer_p = Process(
target=self.producers,
args=(
self,
self.input_queue[i],
self.output_queue,
self.port + 1 + i,
gpuid, ))
self.producer_pros.append(producer_p)
for p in producer_pros:
p.start()
for i, gpuid in enumerate(self.gpus): for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append( self.rpc_service_list.append(
self.default_rpc_service( self.default_rpc_service(
...@@ -122,7 +93,7 @@ class WebService(object): ...@@ -122,7 +93,7 @@ class WebService(object):
self.model_config)) self.model_config))
client.connect([endpoint]) client.connect([endpoint])
while True: while True:
request_json = input_queue.get() request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"]) feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed: if "fetch" in feed:
del feed["fetch"] del feed["fetch"]
...@@ -135,23 +106,29 @@ class WebService(object): ...@@ -135,23 +106,29 @@ class WebService(object):
app_instance = Flask(__name__) app_instance = Flask(__name__)
service_name = "/" + self.name + "/prediction" service_name = "/" + self.name + "/prediction"
input_queues = [] self.input_queues = []
output_queue = Queue() self.output_queue = Queue()
for i in range(gpu_num): for i in range(gpu_num):
input_queues.append(Queue()) self.input_queues.append(Queue())
producer_list = [] producer_list = []
for i, input_q in enumerate(input_queues): for i, input_q in enumerate(self.input_queues):
producer_processes = Process( producer_processes = Process(
target=self.producers, target=self.producers,
input_q, args=(
"0.0.0.0:{}".format(self.port + 1 + i)) input_q,
"0.0.0.0:{}".format(self.port + 1 + i), ))
producer_list.append(producer_processes) producer_list.append(producer_processes)
for p in producer_list: for p in producer_list:
p.start() p.start()
idx = 0 client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect(["0.0.0.0:{}".format(self.port + 1)])
self.idx = 0
@app_instance.route(service_name, methods=['POST']) @app_instance.route(service_name, methods=['POST'])
def get_prediction(): def get_prediction():
...@@ -160,12 +137,23 @@ class WebService(object): ...@@ -160,12 +137,23 @@ class WebService(object):
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
input_queues[idx].put(request.json) self.input_queues[self.idx].put(request.json)
result = output_queue.get()
idx += 1 #self.input_queues[0].put(request.json)
if idx >= len(self.gpus): self.idx += 1
idx = 0 if self.idx >= len(self.gpus):
self.idx = 0
result = self.output_queue.get()
return result return result
'''
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
return fetch_map
'''
app_instance.run(host="0.0.0.0", app_instance.run(host="0.0.0.0",
port=self.port, port=self.port,
...@@ -183,7 +171,7 @@ class WebService(object): ...@@ -183,7 +171,7 @@ class WebService(object):
self.name)) self.name))
server_pros = [] server_pros = []
for i, service in enumerate(self.rpc_service_list): for i, service in enumerate(self.rpc_service_list):
p = Process(target=_launch_rpc_service, args=(i, )) p = Process(target=self._launch_rpc_service, args=(i, ))
server_pros.append(p) server_pros.append(p)
for p in server_pros: for p in server_pros:
p.start() p.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册