提交 3bdcadaf 编写于 作者: W wangjiawei04

add imagenet to pipeline

上级 9b20e199
...@@ -21,12 +21,13 @@ import logging ...@@ -21,12 +21,13 @@ import logging
import numpy as np import numpy as np
import base64, cv2 import base64, cv2
class ImagenetOp(Op): class ImagenetOp(Op):
def init_op(self): def init_op(self):
self.seq = Sequential([ self.seq = Sequential([
Resize(256), CenterCrop(224), RGB2BGR(), Transpose( Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
(2, 0, 1)), Div(255), Normalize([0.485, 0.456, 0.406], Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
[0.229, 0.224, 0.225], True) True)
]) ])
self.label_dict = {} self.label_dict = {}
label_idx = 0 label_idx = 0
...@@ -42,8 +43,7 @@ class ImagenetOp(Op): ...@@ -42,8 +43,7 @@ class ImagenetOp(Op):
# Note: class variables(self.var) can only be used in process op mode # Note: class variables(self.var) can only be used in process op mode
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img = self.seq(im) img = self.seq(im)
return {"image": img[np.newaxis,:].copy()}, False, None, "" return {"image": img[np.newaxis, :].copy()}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id): def postprocess(self, input_dicts, fetch_dict, log_id):
print(fetch_dict) print(fetch_dict)
...@@ -53,13 +53,12 @@ class ImagenetOp(Op): ...@@ -53,13 +53,12 @@ class ImagenetOp(Op):
score = score.tolist() score = score.tolist()
max_score = max(score) max_score = max(score)
#result["label"].append(self.label_dict[score.index(max_score)] #result["label"].append(self.label_dict[score.index(max_score)]
#.strip().replace(",", "")) #.strip().replace(",", ""))
#result["prob"].append(max_score) #result["prob"].append(max_score)
#print(result) #print(result)
return result, None, "" return result, None, ""
class ImageService(WebService): class ImageService(WebService):
def get_pipeline_response(self, read_op): def get_pipeline_response(self, read_op):
image_op = ImagenetOp(name="imagenet", input_ops=[read_op]) image_op = ImagenetOp(name="imagenet", input_ops=[read_op])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册