提交 6eafd55c 编写于 作者: D dongdaxiang 提交者: dongdaxiang

update web service

上级 50cdcc94
...@@ -235,6 +235,8 @@ class Client(object): ...@@ -235,6 +235,8 @@ class Client(object):
int_feed_names.append(key) int_feed_names.append(key)
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
int_shape.append(list(feed_i[key].shape)) int_shape.append(list(feed_i[key].shape))
else:
int_shape.append(self.feed_shapes_[key])
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
int_slot.append(np.reshape(feed_i[key], (-1)).tolist()) int_slot.append(np.reshape(feed_i[key], (-1)).tolist())
else: else:
...@@ -244,6 +246,8 @@ class Client(object): ...@@ -244,6 +246,8 @@ class Client(object):
float_feed_names.append(key) float_feed_names.append(key)
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
float_shape.append(list(feed_i[key].shape)) float_shape.append(list(feed_i[key].shape))
else:
float_shape.append(self.feed_shapes_[key])
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
float_slot.append( float_slot.append(
np.reshape(feed_i[key], (-1)).tolist()) np.reshape(feed_i[key], (-1)).tolist())
......
...@@ -55,29 +55,42 @@ class WebService(object): ...@@ -55,29 +55,42 @@ class WebService(object):
"{}/serving_server_conf.prototxt".format(self.model_config)) "{}/serving_server_conf.prototxt".format(self.model_config))
self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
def get_prediction(self, request): @app_instance.route(service_name, methods=['POST'])
def get_prediction():
if not request.json: if not request.json:
abort(400) abort(400)
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
try: try:
feed, fetch = self.preprocess(request.json, request.json["fetch"]) feed, fetch = self.preprocess(request.json,
request.json["fetch"])
if isinstance(feed, list): if isinstance(feed, list):
fetch_map_batch = self.client_service.predict( fetch_map_batch = client_service.predict(
feed_batch=feed, fetch=fetch) feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess( fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) feed=request.json,
fetch=fetch,
fetch_map=fetch_map_batch)
for key in fetch_map_batch:
fetch_map_batch[key] = fetch_map_batch[key].tolist()
result = {"result": fetch_map_batch} result = {"result": fetch_map_batch}
elif isinstance(feed, dict): elif isinstance(feed, dict):
if "fetch" in feed: if "fetch" in feed:
del feed["fetch"] del feed["fetch"]
fetch_map = self.client_service.predict(feed=feed, fetch=fetch) fetch_map = client_service.predict(feed=feed, fetch=fetch)
for key in fetch_map:
fetch_map[key] = fetch_map[key][0].tolist()
result = self.postprocess( result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map) feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError: except ValueError:
result = {"result": "Request Value Error"} result = {"result": "Request Value Error"}
return result return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
def run_server(self): def run_server(self):
import socket import socket
localIP = socket.gethostbyname(socket.gethostname()) localIP = socket.gethostbyname(socket.gethostname())
......
...@@ -90,26 +90,89 @@ class WebService(object): ...@@ -90,26 +90,89 @@ class WebService(object):
self.client = Client() self.client = Client()
self.client.load_client_config("{}/serving_server_conf.prototxt".format( self.client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config)) self.model_config))
endpoints = ""
if gpu_num > 0: client.connect([endpoint])
while True:
request_json = inputqueue.get()
try:
feed, fetch = self.preprocess(request_json,
request_json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client.predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request_json,
fetch=fetch,
fetch_map=fetch_map_batch)
for key in fetch_map_batch:
fetch_map_batch[key] = fetch_map_batch[key].tolist()
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
for key in fetch_map:
fetch_map[key] = fetch_map[key][0].tolist()
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
except ValueError:
self.output_queue.put(-1)
def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__)
service_name = "/" + self.name + "/prediction"
self.input_queues = []
self.output_queue = Queue()
for i in range(gpu_num): for i in range(gpu_num):
endpoints += "127.0.0.1:{},".format(self.port + i + 1) self.input_queues.append(Queue())
else:
endpoints = "127.0.0.1:{}".format(self.port + 1) producer_list = []
self.client.connect([endpoints]) for i, input_q in enumerate(self.input_queues):
producer_processes = Process(
target=self.producers,
args=(
input_q,
"0.0.0.0:{}".format(self.port + 1 + i), ))
producer_list.append(producer_processes)
for p in producer_list:
p.start()
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect(["0.0.0.0:{}".format(self.port + 1)])
def get_prediction(self, request): self.idx = 0
@app_instance.route(service_name, methods=['POST'])
def get_prediction():
if not request.json: if not request.json:
abort(400) abort(400)
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"])
fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) self.input_queues[self.idx].put(request.json)
fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) #self.input_queues[0].put(request.json)
result = {"result": fetch_map_batch} self.idx += 1
if self.idx >= len(self.gpus):
self.idx = 0
result = self.output_queue.get()
if not isinstance(result, dict) and result == -1:
result = {"result": "Request Value Error"}
return result return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
for p in producer_list:
p.join()
def run_server(self): def run_server(self):
import socket import socket
localIP = socket.gethostbyname(socket.gethostname()) localIP = socket.gethostbyname(socket.gethostname())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册