提交 221d977c 编写于 作者: M MRXLT

bug fix

上级 b3ccae27
...@@ -54,6 +54,7 @@ class ImageService(WebService): ...@@ -54,6 +54,7 @@ class ImageService(WebService):
score_list = fetch_map["score"] score_list = fetch_map["score"]
result = {"label": [], "prob": []} result = {"label": [], "prob": []}
for score in score_list: for score in score_list:
score = score.tolist()
max_score = max(score) max_score = max(score)
result["label"].append(self.label_dict[score.index(max_score)] result["label"].append(self.label_dict[score.index(max_score)]
.strip().replace(",", "")) .strip().replace(",", ""))
......
...@@ -92,8 +92,6 @@ class WebService(object): ...@@ -92,8 +92,6 @@ class WebService(object):
if isinstance(feed, dict) and "fetch" in feed: if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"] del feed["fetch"]
fetch_map = self.client.predict(feed=feed, fetch=fetch) fetch_map = self.client.predict(feed=feed, fetch=fetch)
for key in fetch_map:
fetch_map[key] = fetch_map[key].tolist()
result = self.postprocess( result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result} result = {"result": result}
...@@ -137,4 +135,6 @@ class WebService(object): ...@@ -137,4 +135,6 @@ class WebService(object):
return feed, fetch return feed, fetch
def postprocess(self, feed=[], fetch=[], fetch_map=None): def postprocess(self, feed=[], fetch=[], fetch_map=None):
for key in fetch_map:
fetch_map[key] = fetch_map[key].tolist()
return fetch_map return fetch_map
...@@ -60,8 +60,8 @@ class WebService(object): ...@@ -60,8 +60,8 @@ class WebService(object):
server = Server() server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num) server.set_num_threads(thread_num)
server.set_mem_optim(mem_optim) server.set_memory_optimize(mem_optim)
server.set_ir_optim(ir_optim) server.set_ir_optimize(ir_optim)
server.load_model_config(self.model_config) server.load_model_config(self.model_config)
if gpuid >= 0: if gpuid >= 0:
...@@ -108,8 +108,8 @@ class WebService(object): ...@@ -108,8 +108,8 @@ class WebService(object):
self.port_list[0], self.port_list[0],
-1, -1,
thread_num=2, thread_num=2,
mem_optim, mem_optim=mem_optim,
ir_optim)) ir_optim=ir_optim))
else: else:
for i, gpuid in enumerate(self.gpus): for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append( self.rpc_service_list.append(
...@@ -118,8 +118,8 @@ class WebService(object): ...@@ -118,8 +118,8 @@ class WebService(object):
self.port_list[i], self.port_list[i],
gpuid, gpuid,
thread_num=2, thread_num=2,
mem_optim, mem_optim=mem_optim,
ir_optim)) ir_optim=ir_optim))
def _launch_web_service(self): def _launch_web_service(self):
gpu_num = len(self.gpus) gpu_num = len(self.gpus)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册