提交 3e732180 编写于 作者: W wangjiawei04

fix imagenet rpc and web service

上级 966aecc1
...@@ -38,7 +38,8 @@ start = time.time() ...@@ -38,7 +38,8 @@ start = time.time()
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg" image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
for i in range(10): for i in range(10):
img = seq(image_file) img = seq(image_file)
fetch_map = client.predict(feed={"image": img}, fetch=["score"]) fetch_map = client.predict(
feed={"image": img}, fetch=["score"], batch=False)
prob = max(fetch_map["score"][0]) prob = max(fetch_map["score"][0])
label = label_dict[fetch_map["score"][0].tolist().index(prob)].strip( label = label_dict[fetch_map["score"][0].tolist().index(prob)].strip(
).replace(",", "") ).replace(",", "")
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
from paddle_serving_client import Client from paddle_serving_client import Client
import numpy as np
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize, Base64ToImage from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize, Base64ToImage
if len(sys.argv) != 4: if len(sys.argv) != 4:
...@@ -44,12 +44,13 @@ class ImageService(WebService): ...@@ -44,12 +44,13 @@ class ImageService(WebService):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
feed_batch = [] feed_batch = []
is_batch = True
for ins in feed: for ins in feed:
if "image" not in ins: if "image" not in ins:
raise ("feed data error!") raise ("feed data error!")
img = self.seq(ins["image"]) img = self.seq(ins["image"])
feed_batch.append({"image": img[np.newaxis, :]}) feed_batch.append({"image": img[np.newaxis, :]})
return feed_batch, fetch return feed_batch, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map={}): def postprocess(self, feed=[], fetch=[], fetch_map={}):
score_list = fetch_map["score"] score_list = fetch_map["score"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册