Commit 8781ca33 authored by: MRXLT

add batch predict for HTTP service

Parent: b0bda428
@@ -25,11 +25,21 @@ class ImageService(WebService):
         reader = ImageReader()
         if "image" not in feed:
             raise ("feed data error!")
-        sample = base64.b64decode(feed["image"])
-        img = reader.process_image(sample)
-        res_feed = {}
-        res_feed["image"] = img.reshape(-1)
-        return res_feed, fetch
+        if isinstance(feed["image"], list):
+            feed_batch = []
+            for image in feed["image"]:
+                sample = base64.b64decode(image)
+                img = reader.process_image(sample)
+                res_feed = {}
+                res_feed["image"] = img.reshape(-1)
+                feed_batch.append(res_feed)
+            return feed_batch, fetch
+        else:
+            sample = base64.b64decode(feed["image"])
+            img = reader.process_image(sample)
+            res_feed = {}
+            res_feed["image"] = img.reshape(-1)
+            return res_feed, fetch
 
 
 image_service = ImageService(name="image")
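For context, the batch branch above is driven entirely by the request body: when the client sends a list of base64-encoded images under "image", preprocess returns a list of feed dicts. A minimal client-side sketch of such a request is below; the file names and port are placeholders, and the bytes are decoded to str so the example runs on Python 3, unlike the Python-2 style used in the repo's scripts.

import base64
import json
import requests

# Hypothetical image files; any JPEGs readable by ImageReader would do.
paths = ["img_0.jpg", "img_1.jpg"]
images = [base64.b64encode(open(p, "rb").read()).decode() for p in paths]

# A list under "image" makes the service take the feed_batch path above.
req = json.dumps({"image": images, "fetch": ["score"]})
r = requests.post(
    "http://127.0.0.1:9393/image/prediction",
    data=req,
    headers={"Content-Type": "application/json"})
print(r.json())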

@@ -25,16 +25,27 @@ class ImageService(WebService):
         reader = ImageReader()
         if "image" not in feed:
             raise ("feed data error!")
-        sample = base64.b64decode(feed["image"])
-        img = reader.process_image(sample)
-        res_feed = {}
-        res_feed["image"] = img.reshape(-1)
-        return res_feed, fetch
+        print(type(feed["image"]), isinstance(feed["image"], list))
+        if isinstance(feed["image"], list):
+            feed_batch = []
+            for image in feed["image"]:
+                sample = base64.b64decode(image)
+                img = reader.process_image(sample)
+                res_feed = {}
+                res_feed["image"] = img.reshape(-1)
+                feed_batch.append(res_feed)
+            return feed_batch, fetch
+        else:
+            sample = base64.b64decode(feed["image"])
+            img = reader.process_image(sample)
+            res_feed = {}
+            res_feed["image"] = img.reshape(-1)
+            return res_feed, fetch
 
 
 image_service = ImageService(name="image")
 image_service.load_model_config(sys.argv[1])
-image_service.set_gpus("0,1,2,3")
+image_service.set_gpus("0,1")
 image_service.prepare_server(
     workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
 image_service.run_server()

@@ -24,17 +24,26 @@ def predict(image_path, server):
     req = json.dumps({"image": image, "fetch": ["score"]})
     r = requests.post(
         server, data=req, headers={"Content-Type": "application/json"})
+    print(r.json()["score"][0])
+    return r
+
+
+def batch_predict(image_path, server):
+    image = base64.b64encode(open(image_path).read())
+    req = json.dumps({"image": [image, image], "fetch": ["score"]})
+    r = requests.post(
+        server, data=req, headers={"Content-Type": "application/json"})
+    print(r.json()["result"][1]["score"][0])
     return r
 
 
 if __name__ == "__main__":
-    server = "http://127.0.0.1:9295/image/prediction"
+    server = "http://127.0.0.1:9393/image/prediction"
     #image_path = "./data/n01440764_10026.JPEG"
-    image_list = os.listdir("./data/image_data/n01440764/")
+    image_list = os.listdir("./image_data/n01440764/")
     start = time.time()
     for img in image_list:
-        image_file = "./data/image_data/n01440764/" + img
+        image_file = "./image_data/n01440764/" + img
         res = predict(image_file, server)
-        print(res.json()["score"][0])
     end = time.time()
     print(end - start)
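Note that the two client helpers above see different response shapes: predict reads the fetch map at the top level of the JSON, while batch_predict finds a list of fetch maps wrapped under "result". A hedged usage sketch of both helpers, assuming the server above is running; the image filename is illustrative.

server = "http://127.0.0.1:9393/image/prediction"
image_file = "./image_data/n01440764/example.JPEG"  # placeholder filename

r = predict(image_file, server)
print(r.json()["score"][0])             # single request: fetch map at top level

r = batch_predict(image_file, server)
for fetch_map in r.json()["result"]:    # batch request: list of fetch maps under "result"
    print(fetch_map["score"][0])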

@@ -64,12 +64,19 @@ class WebService(object):
             if "fetch" not in request.json:
                 abort(400)
             feed, fetch = self.preprocess(request.json, request.json["fetch"])
-            if "fetch" in feed:
-                del feed["fetch"]
-            fetch_map = client_service.predict(feed=feed, fetch=fetch)
-            fetch_map = self.postprocess(
-                feed=request.json, fetch=fetch, fetch_map=fetch_map)
-            return fetch_map
+            if isinstance(feed, list):
+                fetch_map_batch = client_service.batch_predict(
+                    feed_batch=feed, fetch=fetch)
+                fetch_map_batch = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
+                result = {"result": fetch_map_batch}
+            elif isinstance(feed, dict):
+                if "fetch" in feed:
+                    del feed["fetch"]
+                fetch_map = client_service.predict(feed=feed, fetch=fetch)
+                result = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map)
+            return result
 
         app_instance.run(host="0.0.0.0",
                          port=self.port,

@@ -92,5 +99,5 @@ class WebService(object):
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch
 
-    def postprocess(self, feed={}, fetch=[], fetch_map={}):
+    def postprocess(self, feed={}, fetch=[], fetch_map=None):
         return fetch_map
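With this change, whether a request goes through predict or batch_predict is decided solely by what the user's preprocess returns: a single feed dict keeps the old path, a list of feed dicts selects the batch path. A minimal sketch of a service override under that assumption; the class name and the "x" feed key are illustrative.

class EchoService(WebService):
    def preprocess(self, feed={}, fetch=[]):
        # A list of feed dicts routes the request to client_service.batch_predict;
        # a single dict keeps the original client_service.predict path.
        if isinstance(feed["x"], list):
            return [{"x": x} for x in feed["x"]], fetch
        return {"x": feed["x"]}, fetch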

@@ -95,12 +95,20 @@ class WebService(object):
         while True:
             request_json = inputqueue.get()
             feed, fetch = self.preprocess(request_json, request_json["fetch"])
-            if "fetch" in feed:
-                del feed["fetch"]
-            fetch_map = client.predict(feed=feed, fetch=fetch)
-            fetch_map = self.postprocess(
-                feed=request_json, fetch=fetch, fetch_map=fetch_map)
-            self.output_queue.put(fetch_map)
+            if isinstance(feed, list):
+                fetch_map_batch = client.batch_predict(
+                    feed_batch=feed, fetch=fetch)
+                fetch_map_batch = self.postprocess(
+                    feed=request_json, fetch=fetch, fetch_map=fetch_map_batch)
+                result = {"result": fetch_map_batch}
+            elif isinstance(feed, dict):
+                if "fetch" in feed:
+                    del feed["fetch"]
+                fetch_map = client.predict(feed=feed, fetch=fetch)
+                result = self.postprocess(
+                    feed=request_json, fetch=fetch, fetch_map=fetch_map)
+            self.output_queue.put(result)
 
     def _launch_web_service(self, gpu_num):
         app_instance = Flask(__name__)

@@ -186,5 +194,5 @@ class WebService(object):
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch
 
-    def postprocess(self, feed={}, fetch=[], fetch_map={}):
+    def postprocess(self, feed={}, fetch=[], fetch_map=None):
         return fetch_map
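Since postprocess now receives either a single fetch map or, on the batch path, a list of fetch maps (hence the fetch_map=None default), a custom override should handle both shapes. A hedged sketch, assuming the fetched values are numpy arrays that need converting before Flask can serialize them; the subclass name is illustrative.

class MyService(WebService):
    def postprocess(self, feed={}, fetch=[], fetch_map=None):
        def jsonable(m):
            # Convert numpy arrays to plain lists for JSON serialization.
            return {k: v.tolist() if hasattr(v, "tolist") else v
                    for k, v in m.items()}

        if isinstance(fetch_map, list):      # batch_predict path
            return [jsonable(m) for m in fetch_map]
        return jsonable(fetch_map)           # predict path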