提交 9db05796 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #299 from MRXLT/general-server-py3

add batch predict for web service
...@@ -25,11 +25,21 @@ class ImageService(WebService): ...@@ -25,11 +25,21 @@ class ImageService(WebService):
reader = ImageReader() reader = ImageReader()
if "image" not in feed: if "image" not in feed:
raise ("feed data error!") raise ("feed data error!")
sample = base64.b64decode(feed["image"]) if isinstance(feed["image"], list):
img = reader.process_image(sample) feed_batch = []
res_feed = {} for image in feed["image"]:
res_feed["image"] = img.reshape(-1) sample = base64.b64decode(image)
return res_feed, fetch img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
image_service = ImageService(name="image") image_service = ImageService(name="image")
......
...@@ -25,16 +25,27 @@ class ImageService(WebService): ...@@ -25,16 +25,27 @@ class ImageService(WebService):
reader = ImageReader() reader = ImageReader()
if "image" not in feed: if "image" not in feed:
raise ("feed data error!") raise ("feed data error!")
sample = base64.b64decode(feed["image"]) print(type(feed["image"]), isinstance(feed["image"], list))
img = reader.process_image(sample) if isinstance(feed["image"], list):
res_feed = {} feed_batch = []
res_feed["image"] = img.reshape(-1) for image in feed["image"]:
return res_feed, fetch sample = base64.b64decode(image)
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
image_service = ImageService(name="image") image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1]) image_service.load_model_config(sys.argv[1])
image_service.set_gpus("0,1,2,3") image_service.set_gpus("0,1")
image_service.prepare_server( image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu") workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
image_service.run_server() image_service.run_server()
...@@ -24,17 +24,26 @@ def predict(image_path, server): ...@@ -24,17 +24,26 @@ def predict(image_path, server):
req = json.dumps({"image": image, "fetch": ["score"]}) req = json.dumps({"image": image, "fetch": ["score"]})
r = requests.post( r = requests.post(
server, data=req, headers={"Content-Type": "application/json"}) server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["score"][0])
return r
def batch_predict(image_path, server):
image = base64.b64encode(open(image_path).read())
req = json.dumps({"image": [image, image], "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["result"][1]["score"][0])
return r return r
if __name__ == "__main__": if __name__ == "__main__":
server = "http://127.0.0.1:9295/image/prediction" server = "http://127.0.0.1:9393/image/prediction"
#image_path = "./data/n01440764_10026.JPEG" #image_path = "./data/n01440764_10026.JPEG"
image_list = os.listdir("./data/image_data/n01440764/") image_list = os.listdir("./image_data/n01440764/")
start = time.time() start = time.time()
for img in image_list: for img in image_list:
image_file = "./data/image_data/n01440764/" + img image_file = "./image_data/n01440764/" + img
res = predict(image_file, server) res = predict(image_file, server)
print(res.json()["score"][0])
end = time.time() end = time.time()
print(end - start) print(end - start)
...@@ -64,12 +64,19 @@ class WebService(object): ...@@ -64,12 +64,19 @@ class WebService(object):
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"]) feed, fetch = self.preprocess(request.json, request.json["fetch"])
if "fetch" in feed: if isinstance(feed, list):
del feed["fetch"] fetch_map_batch = client_service.batch_predict(
fetch_map = client_service.predict(feed=feed, fetch=fetch) feed_batch=feed, fetch=fetch)
fetch_map = self.postprocess( fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map) feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
return fetch_map result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
return result
app_instance.run(host="0.0.0.0", app_instance.run(host="0.0.0.0",
port=self.port, port=self.port,
...@@ -92,5 +99,5 @@ class WebService(object): ...@@ -92,5 +99,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
return feed, fetch return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}): def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map return fetch_map
...@@ -23,14 +23,14 @@ from multiprocessing import Pool, Process ...@@ -23,14 +23,14 @@ from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args from paddle_serving_server_gpu import serve_args
def start_gpu_card_model(gpuid, args): # pylint: disable=doc-string-missing def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid) gpuid = int(gpuid)
device = "gpu" device = "gpu"
port = args.port port = args.port
if gpuid == -1: if gpuid == -1:
device = "cpu" device = "cpu"
elif gpuid >= 0: elif gpuid >= 0:
port = args.port + gpuid port = args.port + index
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
workdir = "{}_{}".format(args.workdir, gpuid) workdir = "{}_{}".format(args.workdir, gpuid)
...@@ -78,6 +78,7 @@ def start_multi_card(args): # pylint: disable=doc-string-missing ...@@ -78,6 +78,7 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
p = Process( p = Process(
target=start_gpu_card_model, args=( target=start_gpu_card_model, args=(
i, i,
gpu_id,
args, )) args, ))
gpu_processes.append(p) gpu_processes.append(p)
for p in gpu_processes: for p in gpu_processes:
......
...@@ -95,12 +95,20 @@ class WebService(object): ...@@ -95,12 +95,20 @@ class WebService(object):
while True: while True:
request_json = inputqueue.get() request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"]) feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed: if isinstance(feed, list):
del feed["fetch"] fetch_map_batch = client.batch_predict(
fetch_map = client.predict(feed=feed, fetch=fetch) feed_batch=feed, fetch=fetch)
fetch_map = self.postprocess( fetch_map_batch = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map) feed=request_json, fetch=fetch, fetch_map=fetch_map_batch)
self.output_queue.put(fetch_map) result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
def _launch_web_service(self, gpu_num): def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__) app_instance = Flask(__name__)
...@@ -186,5 +194,5 @@ class WebService(object): ...@@ -186,5 +194,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
return feed, fetch return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}): def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map return fetch_map
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册