提交 d9e564d7 编写于 作者: D dongdaxiang

change web service

上级 a31710f3
...@@ -55,41 +55,32 @@ class WebService(object): ...@@ -55,41 +55,32 @@ class WebService(object):
"{}/serving_server_conf.prototxt".format(self.model_config)) "{}/serving_server_conf.prototxt".format(self.model_config))
self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
@app_instance.route(service_name, methods=['POST']) def get_prediction():
def get_prediction(): if not request.json:
if not request.json: abort(400)
abort(400) if "fetch" not in request.json:
if "fetch" not in request.json: abort(400)
abort(400) try:
try: feed, fetch = self.preprocess(request.json, request.json["fetch"])
feed, fetch = self.preprocess(request.json, if isinstance(feed, list):
request.json["fetch"]) fetch_map_batch = client_service.predict(
if isinstance(feed, list): feed_batch=feed, fetch=fetch)
fetch_map_batch = client_service.predict( fetch_map_batch = self.postprocess(
feed_batch=feed, fetch=fetch) feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
fetch_map_batch = self.postprocess( for key in fetch_map_batch:
feed=request.json, fetch_map_batch[key] = fetch_map_batch[key].tolist()
fetch=fetch, result = {"result": fetch_map_batch}
fetch_map=fetch_map_batch) elif isinstance(feed, dict):
for key in fetch_map_batch: if "fetch" in feed:
fetch_map_batch[key] = fetch_map_batch[key].tolist() del feed["fetch"]
result = {"result": fetch_map_batch} fetch_map = client_service.predict(feed=feed, fetch=fetch)
elif isinstance(feed, dict): for key in fetch_map:
if "fetch" in feed: fetch_map[key] = fetch_map[key][0].tolist()
del feed["fetch"] result = self.postprocess(
fetch_map = client_service.predict(feed=feed, fetch=fetch) feed=request.json, fetch=fetch, fetch_map=fetch_map)
for key in fetch_map: except ValueError:
fetch_map[key] = fetch_map[key][0].tolist() result = {"result": "Request Value Error"}
result = self.postprocess( return result
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError:
result = {"result": "Request Value Error"}
return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
def run_server(self): def run_server(self):
import socket import socket
......
...@@ -147,30 +147,21 @@ class WebService(object): ...@@ -147,30 +147,21 @@ class WebService(object):
self.idx = 0 self.idx = 0
@app_instance.route(service_name, methods=['POST']) def get_prediction():
def get_prediction(): if not request.json:
if not request.json: abort(400)
abort(400) if "fetch" not in request.json:
if "fetch" not in request.json: abort(400)
abort(400)
self.input_queues[self.idx].put(request.json)
self.input_queues[self.idx].put(request.json)
self.idx += 1
self.idx += 1 if self.idx >= len(self.gpus):
if self.idx >= len(self.gpus): self.idx = 0
self.idx = 0 result = self.output_queue.get()
result = self.output_queue.get() if not isinstance(result, dict) and result == -1:
if not isinstance(result, dict) and result == -1: result = {"result": "Request Value Error"}
result = {"result": "Request Value Error"} return result
return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
for p in producer_list:
p.join()
def run_server(self): def run_server(self):
import socket import socket
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册