提交 d26a1c9b 编写于 作者: M MRXLT

fix bug && unify http service

上级 55b17479
...@@ -67,10 +67,10 @@ class WebService(object): ...@@ -67,10 +67,10 @@ class WebService(object):
break break
def _launch_web_service(self): def _launch_web_service(self):
self.client_service = Client() self.client = Client()
self.client_service.load_client_config( self.client.load_client_config("{}/serving_server_conf.prototxt".format(
"{}/serving_server_conf.prototxt".format(self.model_config)) self.model_config))
self.client_service.connect(["0.0.0.0:{}".format(self.port_list[0])]) self.client.connect(["0.0.0.0:{}".format(self.port_list[0])])
def get_prediction(self, request): def get_prediction(self, request):
if not request.json: if not request.json:
...@@ -79,22 +79,14 @@ class WebService(object): ...@@ -79,22 +79,14 @@ class WebService(object):
abort(400) abort(400)
try: try:
feed, fetch = self.preprocess(request.json, request.json["fetch"]) feed, fetch = self.preprocess(request.json, request.json["fetch"])
if isinstance(feed, list): if isinstance(feed, dict) and "fetch" in feed:
fetch_map_batch = self.client_service.predict( del feed["fetch"]
feed_batch=feed, fetch=fetch) fetch_map = self.client.predict(feed=feed, fetch=fetch)
fetch_map_batch = self.postprocess( fetch_map = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) feed=request.json, fetch=fetch, fetch_map=fetch_map)
for key in fetch_map_batch: for key in fetch_map:
fetch_map_batch[key] = fetch_map_batch[key].tolist() fetch_map[key] = fetch_map[key].tolist()
result = {"result": fetch_map_batch} result = {"result": fetch_map}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
for key in fetch_map:
fetch_map[key] = fetch_map[key][0].tolist()
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError: except ValueError:
result = {"result": "Request Value Error"} result = {"result": "Request Value Error"}
return result return result
......
...@@ -121,13 +121,18 @@ class WebService(object): ...@@ -121,13 +121,18 @@ class WebService(object):
abort(400) abort(400)
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"]) try:
fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) feed, fetch = self.preprocess(request.json, request.json["fetch"])
fetch_map_batch = self.postprocess( if isinstance(feed, dict) and "fetch" in feed:
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) del feed["fetch"]
for key in fetch_map_batch: fetch_map = self.client.predict(feed=feed, fetch=fetch)
fetch_map_batch[key] = fetch_map_batch[key].tolist() fetch_map = self.postprocess(
result = {"result": fetch_map_batch} feed=request.json, fetch=fetch, fetch_map=fetch_map)
for key in fetch_map:
fetch_map[key] = fetch_map[key].tolist()
result = {"result": fetch_map}
except ValueError:
result = {"result": "Request Value Error"}
return result return result
def run_server(self): def run_server(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册