提交 686afbbf 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #374 from MRXLT/0.2.0-fix-gpu

fix web service
......@@ -63,19 +63,25 @@ class WebService(object):
abort(400)
if "fetch" not in request.json:
abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client_service.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
try:
feed, fetch = self.preprocess(request.json,
request.json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client_service.predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json,
fetch=fetch,
fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError:
result = {"result": "Request Value Error"}
return result
app_instance.run(host="0.0.0.0",
......
......@@ -94,21 +94,26 @@ class WebService(object):
client.connect([endpoint])
while True:
request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
try:
feed, fetch = self.preprocess(request_json,
request_json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client.predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request_json,
fetch=fetch,
fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
except ValueError:
self.output_queue.put(-1)
def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__)
......@@ -152,6 +157,8 @@ class WebService(object):
if self.idx >= len(self.gpus):
self.idx = 0
result = self.output_queue.get()
if not isinstance(result, dict) and result == -1:
result = {"result": "Request Value Error"}
return result
'''
feed, fetch = self.preprocess(request.json, request.json["fetch"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册