提交 7bd372c7 编写于 作者: M MRXLT

add batch predict for HTTP service

上级 6858647b
......@@ -25,6 +25,16 @@ class ImageService(WebService):
reader = ImageReader()
if "image" not in feed:
raise ("feed data error!")
if isinstance(feed["image"], list):
feed_batch = []
for image in feed["image"]:
sample = base64.b64decode(image)
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
......
......@@ -25,6 +25,17 @@ class ImageService(WebService):
reader = ImageReader()
if "image" not in feed:
raise ("feed data error!")
print(type(feed["image"]), isinstance(feed["image"], list))
if isinstance(feed["image"], list):
feed_batch = []
for image in feed["image"]:
sample = base64.b64decode(image)
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
......@@ -34,7 +45,7 @@ class ImageService(WebService):
image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1])
image_service.set_gpus("0,1,2,3")
image_service.set_gpus("0,1")
image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
image_service.run_server()
......@@ -24,17 +24,26 @@ def predict(image_path, server):
req = json.dumps({"image": image, "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["score"][0])
return r
def batch_predict(image_path, server):
image = base64.b64encode(open(image_path).read())
req = json.dumps({"image": [image, image], "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["result"][1]["score"][0])
return r
if __name__ == "__main__":
server = "http://127.0.0.1:9295/image/prediction"
server = "http://127.0.0.1:9393/image/prediction"
#image_path = "./data/n01440764_10026.JPEG"
image_list = os.listdir("./data/image_data/n01440764/")
image_list = os.listdir("./image_data/n01440764/")
start = time.time()
for img in image_list:
image_file = "./data/image_data/n01440764/" + img
image_file = "./image_data/n01440764/" + img
res = predict(image_file, server)
print(res.json()["score"][0])
end = time.time()
print(end - start)
......@@ -64,12 +64,19 @@ class WebService(object):
if "fetch" not in request.json:
abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client_service.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
return fetch_map
return result
app_instance.run(host="0.0.0.0",
port=self.port,
......@@ -92,5 +99,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]):
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}):
def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map
......@@ -95,12 +95,20 @@ class WebService(object):
while True:
request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(fetch_map)
self.output_queue.put(result)
def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__)
......@@ -186,5 +194,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]):
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}):
def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册