Unverified · Commit d48e1cfd · authored by MRXLT, committed by GitHub

Merge pull request #467 from MRXLT/web-service

find available port automatically && add run_flask
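In short: prepare_server no longer binds the internal RPC service to port + 1; it probes for a free port starting at 12000, and a new run_flask() call starts the HTTP frontend on the user-facing port. A minimal launch sketch of the resulting flow (the service name "bert" and the argv layout mirror the demo style below; treat them as illustrative, not verbatim):

# Hypothetical launch script mirroring the updated demos in this PR.
import sys
from paddle_serving_server_gpu.web_service import WebService

bert_service = WebService(name="bert")      # service name is an assumption
bert_service.load_model_config(sys.argv[1])
bert_service.set_gpus("0")
bert_service.prepare_server(
    workdir="workdir", port=int(sys.argv[2]), device="gpu")
bert_service.run_server()   # RPC backend on an automatically found free port
bert_service.run_flask()    # HTTP frontend on the user-facing port (argv[2])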
@@ -36,3 +36,4 @@ bert_service.set_gpus(gpu_ids)
 bert_service.prepare_server(
     workdir="workdir", port=int(sys.argv[2]), device="gpu")
 bert_service.run_server()
+bert_service.run_flask()
@@ -31,14 +31,14 @@ class ImageService(WebService):
                 sample = base64.b64decode(image)
                 img = reader.process_image(sample)
                 res_feed = {}
-                res_feed["image"] = img.reshape(-1)
+                res_feed["image"] = img
                 feed_batch.append(res_feed)
             return feed_batch, fetch
         else:
             sample = base64.b64decode(feed["image"])
             img = reader.process_image(sample)
             res_feed = {}
-            res_feed["image"] = img.reshape(-1)
+            res_feed["image"] = img
             return res_feed, fetch
@@ -47,3 +47,4 @@ image_service.load_model_config(sys.argv[1])
 image_service.prepare_server(
     workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
 image_service.run_server()
+image_service.run_flask()
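Note that preprocess now hands the image array to the client as-is instead of flattening it with reshape(-1); the JSON payload a caller sends is unchanged. A hedged request sketch (the /image/prediction route follows from "/" + self.name + "/prediction" with an assumed service name "image"; the port and file path are illustrative):

import base64
import json
import requests

# Build the base64 payload the updated ImageService.preprocess expects.
with open("daisy.jpg", "rb") as f:          # hypothetical image path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

req = {"image": image_b64, "fetch": ["score"]}
r = requests.post(
    "http://127.0.0.1:9393/image/prediction",   # port is illustrative
    data=json.dumps(req),
    headers={"Content-Type": "application/json"})
print(r.json()["result"]["score"])          # new response shape from this PR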
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle_serving_server_gpu.web_service import WebService
 import sys
 import cv2
 import base64
 import numpy as np
 from image_reader import ImageReader
+from paddle_serving_server_gpu.web_service import WebService
 
 class ImageService(WebService):
@@ -32,14 +32,14 @@ class ImageService(WebService):
                 sample = base64.b64decode(image)
                 img = reader.process_image(sample)
                 res_feed = {}
-                res_feed["image"] = img.reshape(-1)
+                res_feed["image"] = img
                 feed_batch.append(res_feed)
             return feed_batch, fetch
         else:
             sample = base64.b64decode(feed["image"])
             img = reader.process_image(sample)
             res_feed = {}
-            res_feed["image"] = img.reshape(-1)
+            res_feed["image"] = img
             return res_feed, fetch
@@ -49,3 +49,4 @@ image_service.set_gpus("0,1")
 image_service.prepare_server(
     workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
 image_service.run_server()
+image_service.run_flask()
@@ -31,7 +31,7 @@ def predict(image_path, server):
     r = requests.post(
         server, data=req, headers={"Content-Type": "application/json"})
     try:
-        print(r.json()["score"][0])
+        print(r.json()["result"]["score"])
     except ValueError:
         print(r.text)
     return r
...
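With this change the HTTP response nests fetch variables under a "result" key, and values keep their full batch dimension instead of being indexed with [0] on the server. A purely illustrative helper for a client that may face either shape:

def extract_score(resp_json):
    # New servers (this PR) nest fetch variables under "result";
    # older servers returned them at the top level, already indexed.
    if "result" in resp_json:
        return resp_json["result"]["score"]
    return resp_json["score"][0]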
@@ -26,7 +26,7 @@ start = time.time()
 for i in range(1000):
     with open("./data/n01440764_10026.JPEG", "rb") as f:
         img = f.read()
-    img = reader.process_image(img).reshape(-1)
+    img = reader.process_image(img)
     fetch_map = client.predict(feed={"image": img}, fetch=["score"])
 end = time.time()
 print(end - start)
...
@@ -39,3 +39,4 @@ imdb_service.prepare_server(
     workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
 imdb_service.prepare_dict({"dict_file_path": sys.argv[4]})
 imdb_service.run_server()
+imdb_service.run_flask()
@@ -351,6 +351,7 @@ class Server(object):
         self._prepare_resource(workdir)
         self._prepare_engine(self.model_config_paths, device)
         self._prepare_infer_service(port)
+        self.port = port
         self.workdir = workdir
         infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
...
@@ -18,6 +18,8 @@ from flask import Flask, request, abort
 from multiprocessing import Pool, Process
 from paddle_serving_server import OpMaker, OpSeqMaker, Server
 from paddle_serving_client import Client
+from contextlib import closing
+import socket
 
 class WebService(object):
@@ -41,19 +43,34 @@ class WebService(object):
         server.set_num_threads(16)
         server.load_model_config(self.model_config)
         server.prepare_server(
-            workdir=self.workdir, port=self.port + 1, device=self.device)
+            workdir=self.workdir, port=self.port_list[0], device=self.device)
         server.run_server()
 
+    def port_is_available(self, port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            sock.settimeout(2)
+            result = sock.connect_ex(('0.0.0.0', port))
+        if result != 0:
+            return True
+        else:
+            return False
+
     def prepare_server(self, workdir="", port=9393, device="cpu"):
         self.workdir = workdir
         self.port = port
         self.device = device
+        default_port = 12000
+        self.port_list = []
+        for i in range(1000):
+            if self.port_is_available(default_port + i):
+                self.port_list.append(default_port + i)
+                break
 
     def _launch_web_service(self):
-        self.client_service = Client()
-        self.client_service.load_client_config(
-            "{}/serving_server_conf.prototxt".format(self.model_config))
-        self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
+        self.client = Client()
+        self.client.load_client_config("{}/serving_server_conf.prototxt".format(
+            self.model_config))
+        self.client.connect(["0.0.0.0:{}".format(self.port_list[0])])
 
     def get_prediction(self, request):
         if not request.json:
@@ -64,12 +81,12 @@ class WebService(object):
             feed, fetch = self.preprocess(request.json, request.json["fetch"])
             if isinstance(feed, dict) and "fetch" in feed:
                 del feed["fetch"]
-            fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
-            for key in fetch_map:
-                fetch_map[key] = fetch_map[key][0].tolist()
-            result = self.postprocess(
+            fetch_map = self.client.predict(feed=feed, fetch=fetch)
+            fetch_map = self.postprocess(
                 feed=request.json, fetch=fetch, fetch_map=fetch_map)
-            result = {"result": result}
+            for key in fetch_map:
+                fetch_map[key] = fetch_map[key].tolist()
+            result = {"result": fetch_map}
         except ValueError:
             result = {"result": "Request Value Error"}
         return result
@@ -83,6 +100,24 @@ class WebService(object):
         p_rpc = Process(target=self._launch_rpc_service)
         p_rpc.start()
 
+    def run_flask(self):
+        app_instance = Flask(__name__)
+
+        @app_instance.before_first_request
+        def init():
+            self._launch_web_service()
+
+        service_name = "/" + self.name + "/prediction"
+
+        @app_instance.route(service_name, methods=["POST"])
+        def run():
+            return self.get_prediction(request)
+
+        app_instance.run(host="0.0.0.0",
+                         port=self.port,
+                         threaded=False,
+                         processes=4)
+
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch
...
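The port probe above works by attempting a TCP connect on 0.0.0.0: connect_ex returns 0 when a listener already occupies the port, so a nonzero code marks it as available; prepare_server then scans up to 1000 candidates starting at 12000 and keeps the first free one for the internal RPC endpoint. A standalone sketch of the same idea (the function name and structure are mine, not the PR's):

import socket
from contextlib import closing

def find_free_port(start=12000, attempts=1000):
    # Return the first port in [start, start + attempts) with no listener.
    for port in range(start, start + attempts):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            if sock.connect_ex(("0.0.0.0", port)) != 0:  # nonzero: nothing listening
                return port
    raise RuntimeError("no free port in range")

One design caveat worth noting: connect-probing is inherently racy, since another process can claim the port between the probe and the later bind, so a more robust variant would retry on bind failure.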
@@ -14,14 +14,15 @@
 # pylint: disable=doc-string-missing
 
 from flask import Flask, request, abort
-from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
-import paddle_serving_server_gpu as serving
+from contextlib import closing
 from multiprocessing import Pool, Process, Queue
 from paddle_serving_client import Client
+from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
 from paddle_serving_server_gpu.serve import start_multi_card
+import socket
 import sys
 import numpy as np
+import paddle_serving_server_gpu as serving
 
 class WebService(object):
@@ -67,22 +68,39 @@ class WebService(object):
     def _launch_rpc_service(self, service_idx):
         self.rpc_service_list[service_idx].run_server()
 
+    def port_is_available(self, port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            sock.settimeout(2)
+            result = sock.connect_ex(('0.0.0.0', port))
+        if result != 0:
+            return True
+        else:
+            return False
+
     def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
         self.workdir = workdir
         self.port = port
         self.device = device
         self.gpuid = gpuid
+        self.port_list = []
+        default_port = 12000
+        for i in range(1000):
+            if self.port_is_available(default_port + i):
+                self.port_list.append(default_port + i)
+            if len(self.port_list) > len(self.gpus):
+                break
+
         if len(self.gpus) == 0:
             # init cpu service
             self.rpc_service_list.append(
                 self.default_rpc_service(
-                    self.workdir, self.port + 1, -1, thread_num=10))
+                    self.workdir, self.port_list[0], -1, thread_num=10))
         else:
             for i, gpuid in enumerate(self.gpus):
                 self.rpc_service_list.append(
                     self.default_rpc_service(
                         "{}_{}".format(self.workdir, i),
-                        self.port + 1 + i,
+                        self.port_list[i],
                         gpuid,
                         thread_num=10))
@@ -94,9 +112,9 @@ class WebService(object):
         endpoints = ""
         if gpu_num > 0:
             for i in range(gpu_num):
-                endpoints += "127.0.0.1:{},".format(self.port + i + 1)
+                endpoints += "127.0.0.1:{},".format(self.port_list[i])
         else:
-            endpoints = "127.0.0.1:{}".format(self.port + 1)
+            endpoints = "127.0.0.1:{}".format(self.port_list[0])
         self.client.connect([endpoints])
 
     def get_prediction(self, request):
@@ -109,11 +127,11 @@ class WebService(object):
             if isinstance(feed, dict) and "fetch" in feed:
                 del feed["fetch"]
             fetch_map = self.client.predict(feed=feed, fetch=fetch)
-            for key in fetch_map:
-                fetch_map[key] = fetch_map[key][0].tolist()
-            result = self.postprocess(
+            fetch_map = self.postprocess(
                 feed=request.json, fetch=fetch, fetch_map=fetch_map)
-            result = {"result": result}
+            for key in fetch_map:
+                fetch_map[key] = fetch_map[key].tolist()
+            result = {"result": fetch_map}
         except ValueError:
             result = {"result": "Request Value Error"}
         return result
@@ -131,6 +149,24 @@ class WebService(object):
         for p in server_pros:
             p.start()
 
+    def run_flask(self):
+        app_instance = Flask(__name__)
+
+        @app_instance.before_first_request
+        def init():
+            self._launch_web_service()
+
+        service_name = "/" + self.name + "/prediction"
+
+        @app_instance.route(service_name, methods=["POST"])
+        def run():
+            return self.get_prediction(request)
+
+        app_instance.run(host="0.0.0.0",
+                         port=self.port,
+                         threaded=False,
+                         processes=4)
+
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch
...
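run_flask serves the HTTP API with threaded=False and processes=4, so Flask's development server forks four worker processes; the before_first_request hook then lets each worker build its own Client connection lazily on its first request rather than sharing one RPC client across processes (and in the GPU version that client connects to all per-card RPC endpoints at once). A minimal sketch of that pattern in isolation, with no Paddle pieces; note that before_first_request was removed in Flask 2.3, so this mirrors the Flask versions contemporary with this PR:

from flask import Flask, jsonify, request

app = Flask(__name__)

@app.before_first_request
def init():
    # Runs once per worker; with processes=4 each forked worker builds
    # its own backend state lazily on its first request.
    print("initializing worker-local client")

@app.route("/demo/prediction", methods=["POST"])
def run():
    return jsonify({"result": request.json})   # echo, stands in for predict

if __name__ == "__main__":
    # threaded=False with processes=4 makes the dev server fork 4 workers.
    app.run(host="0.0.0.0", port=9393, threaded=False, processes=4)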