提交 96fbc794 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #449 from MRXLT/web-service

refine web service
...@@ -19,6 +19,7 @@ Usage: ...@@ -19,6 +19,7 @@ Usage:
""" """
import argparse import argparse
from .web_service import WebService from .web_service import WebService
from flask import Flask, request
def parse_args(): # pylint: disable=doc-string-missing def parse_args(): # pylint: disable=doc-string-missing
...@@ -88,3 +89,20 @@ if __name__ == "__main__": ...@@ -88,3 +89,20 @@ if __name__ == "__main__":
service.prepare_server( service.prepare_server(
workdir=args.workdir, port=args.port, device=args.device) workdir=args.workdir, port=args.port, device=args.device)
service.run_server() service.run_server()
app_instance = Flask(__name__)
@app_instance.before_first_request
def init():
service._launch_web_service()
service_name = "/" + service.name + "/prediction"
@app_instance.route(service_name, methods=["POST"])
def run():
return service.get_prediction(request)
app_instance.run(host="0.0.0.0",
port=service.port,
threaded=False,
processes=4)
...@@ -50,44 +50,33 @@ class WebService(object): ...@@ -50,44 +50,33 @@ class WebService(object):
self.device = device self.device = device
def _launch_web_service(self): def _launch_web_service(self):
app_instance = Flask(__name__) self.client_service = Client()
client_service = Client() self.client_service.load_client_config(
client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config)) "{}/serving_server_conf.prototxt".format(self.model_config))
client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
service_name = "/" + self.name + "/prediction"
@app_instance.route(service_name, methods=['POST']) def get_prediction(self, request):
def get_prediction(): if not request.json:
if not request.json: abort(400)
abort(400) if "fetch" not in request.json:
if "fetch" not in request.json: abort(400)
abort(400) try:
try: feed, fetch = self.preprocess(request.json, request.json["fetch"])
feed, fetch = self.preprocess(request.json, if isinstance(feed, list):
request.json["fetch"]) fetch_map_batch = self.client_service.predict(
if isinstance(feed, list): feed_batch=feed, fetch=fetch)
fetch_map_batch = client_service.predict( fetch_map_batch = self.postprocess(
feed_batch=feed, fetch=fetch) feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
fetch_map_batch = self.postprocess( result = {"result": fetch_map_batch}
feed=request.json, elif isinstance(feed, dict):
fetch=fetch, if "fetch" in feed:
fetch_map=fetch_map_batch) del feed["fetch"]
result = {"result": fetch_map_batch} fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
elif isinstance(feed, dict): result = self.postprocess(
if "fetch" in feed: feed=request.json, fetch=fetch, fetch_map=fetch_map)
del feed["fetch"] except ValueError:
fetch_map = client_service.predict(feed=feed, fetch=fetch) result = {"result": "Request Value Error"}
result = self.postprocess( return result
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError:
result = {"result": "Request Value Error"}
return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
def run_server(self): def run_server(self):
import socket import socket
...@@ -96,11 +85,7 @@ class WebService(object): ...@@ -96,11 +85,7 @@ class WebService(object):
print("http://{}:{}/{}/prediction".format(localIP, self.port, print("http://{}:{}/{}/prediction".format(localIP, self.port,
self.name)) self.name))
p_rpc = Process(target=self._launch_rpc_service) p_rpc = Process(target=self._launch_rpc_service)
p_web = Process(target=self._launch_web_service)
p_rpc.start() p_rpc.start()
p_web.start()
p_web.join()
p_rpc.join()
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
return feed, fetch return feed, fetch
......
...@@ -21,6 +21,7 @@ import argparse ...@@ -21,6 +21,7 @@ import argparse
import os import os
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args from paddle_serving_server_gpu import serve_args
from flask import Flask, request
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing
...@@ -114,3 +115,20 @@ if __name__ == "__main__": ...@@ -114,3 +115,20 @@ if __name__ == "__main__":
web_service.prepare_server( web_service.prepare_server(
workdir=args.workdir, port=args.port, device=args.device) workdir=args.workdir, port=args.port, device=args.device)
web_service.run_server() web_service.run_server()
app_instance = Flask(__name__)
@app_instance.before_first_request
def init():
web_service._launch_web_service()
service_name = "/" + web_service.name + "/prediction"
@app_instance.route(service_name, methods=["POST"])
def run():
return web_service.get_prediction(request)
app_instance.run(host="0.0.0.0",
port=web_service.port,
threaded=False,
processes=4)
...@@ -11,17 +11,16 @@ ...@@ -11,17 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#!flask/bin/python
# pylint: disable=doc-string-missing
from flask import Flask, request, abort from flask import Flask, request, abort
from multiprocessing import Pool, Process, Queue
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
import paddle_serving_server_gpu as serving import paddle_serving_server_gpu as serving
from multiprocessing import Pool, Process, Queue
from paddle_serving_client import Client from paddle_serving_client import Client
from .serve import start_multi_card from paddle_serving_server_gpu.serve import start_multi_card
import time
import random import sys
import numpy as np
class WebService(object): class WebService(object):
...@@ -29,7 +28,6 @@ class WebService(object): ...@@ -29,7 +28,6 @@ class WebService(object):
self.name = name self.name = name
self.gpus = [] self.gpus = []
self.rpc_service_list = [] self.rpc_service_list = []
self.input_queues = []
def load_model_config(self, model_config): def load_model_config(self, model_config):
self.model_config = model_config self.model_config = model_config
...@@ -66,12 +64,6 @@ class WebService(object): ...@@ -66,12 +64,6 @@ class WebService(object):
return server return server
def _launch_rpc_service(self, service_idx): def _launch_rpc_service(self, service_idx):
if service_idx == 0:
self.rpc_service_list[service_idx].check_local_bin()
if not self.rpc_service_list[service_idx].use_local_bin:
self.rpc_service_list[service_idx].download_bin()
else:
time.sleep(3)
self.rpc_service_list[service_idx].run_server() self.rpc_service_list[service_idx].run_server()
def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0): def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
...@@ -93,87 +85,30 @@ class WebService(object): ...@@ -93,87 +85,30 @@ class WebService(object):
gpuid, gpuid,
thread_num=10)) thread_num=10))
def producers(self, inputqueue, endpoint): def _launch_web_service(self):
client = Client() gpu_num = len(self.gpus)
client.load_client_config("{}/serving_server_conf.prototxt".format( self.client = Client()
self.client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config)) self.model_config))
client.connect([endpoint]) endpoints = ""
while True: if gpu_num > 0:
request_json = inputqueue.get() for i in range(gpu_num):
try: endpoints += "127.0.0.1:{},".format(self.port + i + 1)
feed, fetch = self.preprocess(request_json, else:
request_json["fetch"]) endpoints = "127.0.0.1:{}".format(self.port + 1)
if isinstance(feed, list): self.client.connect([endpoints])
fetch_map_batch = client.predict(
feed_batch=feed, fetch=fetch) def get_prediction(self, request):
fetch_map_batch = self.postprocess( if not request.json:
feed=request_json, abort(400)
fetch=fetch, if "fetch" not in request.json:
fetch_map=fetch_map_batch) abort(400)
result = {"result": fetch_map_batch} feed, fetch = self.preprocess(request.json, request.json["fetch"])
elif isinstance(feed, dict): fetch_map_batch = self.client.predict(feed=feed, fetch=fetch)
if "fetch" in feed: fetch_map_batch = self.postprocess(
del feed["fetch"] feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
fetch_map = client.predict(feed=feed, fetch=fetch) result = {"result": fetch_map_batch}
result = self.postprocess( return result
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
except ValueError:
self.output_queue.put(-1)
def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__)
service_name = "/" + self.name + "/prediction"
self.input_queues = []
self.output_queue = Queue()
for i in range(gpu_num):
self.input_queues.append(Queue())
producer_list = []
for i, input_q in enumerate(self.input_queues):
producer_processes = Process(
target=self.producers,
args=(
input_q,
"0.0.0.0:{}".format(self.port + 1 + i), ))
producer_list.append(producer_processes)
for p in producer_list:
p.start()
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect(["0.0.0.0:{}".format(self.port + 1)])
self.idx = 0
@app_instance.route(service_name, methods=['POST'])
def get_prediction():
if not request.json:
abort(400)
if "fetch" not in request.json:
abort(400)
self.input_queues[self.idx].put(request.json)
#self.input_queues[0].put(request.json)
self.idx += 1
if self.idx >= len(self.gpus):
self.idx = 0
result = self.output_queue.get()
if not isinstance(result, dict) and result == -1:
result = {"result": "Request Value Error"}
return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
for p in producer_list:
p.join()
def run_server(self): def run_server(self):
import socket import socket
...@@ -188,13 +123,6 @@ class WebService(object): ...@@ -188,13 +123,6 @@ class WebService(object):
for p in server_pros: for p in server_pros:
p.start() p.start()
p_web = Process(
target=self._launch_web_service, args=(len(self.gpus), ))
p_web.start()
p_web.join()
for p in server_pros:
p.join()
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
return feed, fetch return feed, fetch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册