提交 9db05796 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #299 from MRXLT/general-server-py3

add batch predict for web service
......@@ -25,11 +25,21 @@ class ImageService(WebService):
reader = ImageReader()
if "image" not in feed:
raise ("feed data error!")
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
if isinstance(feed["image"], list):
feed_batch = []
for image in feed["image"]:
sample = base64.b64decode(image)
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
image_service = ImageService(name="image")
......
......@@ -25,16 +25,27 @@ class ImageService(WebService):
reader = ImageReader()
if "image" not in feed:
raise ("feed data error!")
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
print(type(feed["image"]), isinstance(feed["image"], list))
if isinstance(feed["image"], list):
feed_batch = []
for image in feed["image"]:
sample = base64.b64decode(image)
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
feed_batch.append(res_feed)
return feed_batch, fetch
else:
sample = base64.b64decode(feed["image"])
img = reader.process_image(sample)
res_feed = {}
res_feed["image"] = img.reshape(-1)
return res_feed, fetch
image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1])
image_service.set_gpus("0,1,2,3")
image_service.set_gpus("0,1")
image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
image_service.run_server()
......@@ -24,17 +24,26 @@ def predict(image_path, server):
req = json.dumps({"image": image, "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["score"][0])
return r
def batch_predict(image_path, server):
image = base64.b64encode(open(image_path).read())
req = json.dumps({"image": [image, image], "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
print(r.json()["result"][1]["score"][0])
return r
if __name__ == "__main__":
server = "http://127.0.0.1:9295/image/prediction"
server = "http://127.0.0.1:9393/image/prediction"
#image_path = "./data/n01440764_10026.JPEG"
image_list = os.listdir("./data/image_data/n01440764/")
image_list = os.listdir("./image_data/n01440764/")
start = time.time()
for img in image_list:
image_file = "./data/image_data/n01440764/" + img
image_file = "./image_data/n01440764/" + img
res = predict(image_file, server)
print(res.json()["score"][0])
end = time.time()
print(end - start)
......@@ -64,12 +64,19 @@ class WebService(object):
if "fetch" not in request.json:
abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
return fetch_map
if isinstance(feed, list):
fetch_map_batch = client_service.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
return result
app_instance.run(host="0.0.0.0",
port=self.port,
......@@ -92,5 +99,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]):
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}):
def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map
......@@ -23,14 +23,14 @@ from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args
def start_gpu_card_model(gpuid, args): # pylint: disable=doc-string-missing
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid)
device = "gpu"
port = args.port
if gpuid == -1:
device = "cpu"
elif gpuid >= 0:
port = args.port + gpuid
port = args.port + index
thread_num = args.thread
model = args.model
workdir = "{}_{}".format(args.workdir, gpuid)
......@@ -78,6 +78,7 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
p = Process(
target=start_gpu_card_model, args=(
i,
gpu_id,
args, ))
gpu_processes.append(p)
for p in gpu_processes:
......
......@@ -95,12 +95,20 @@ class WebService(object):
while True:
request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(fetch_map)
if isinstance(feed, list):
fetch_map_batch = client.batch_predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(result)
def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__)
......@@ -186,5 +194,5 @@ class WebService(object):
def preprocess(self, feed={}, fetch=[]):
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map={}):
def postprocess(self, feed={}, fetch=[], fetch_map=None):
return fetch_map
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册