提交 6ca25e25 编写于 作者: S stephon

fix serving bug of picodet

上级 67d982c3
...@@ -23,6 +23,7 @@ import faiss ...@@ -23,6 +23,7 @@ import faiss
import pickle import pickle
import json import json
class DetOp(Op): class DetOp(Op):
def init_op(self): def init_op(self):
self.img_preprocess = Sequential([ self.img_preprocess = Sequential([
...@@ -62,39 +63,46 @@ class DetOp(Op): ...@@ -62,39 +63,46 @@ class DetOp(Op):
im_scale_y, im_scale_x = self.generate_scale(raw_im) im_scale_y, im_scale_x = self.generate_scale(raw_im)
im = self.img_preprocess(raw_im) im = self.img_preprocess(raw_im)
imgs.append({ imgs.append({
"image": im[np.newaxis, :], "image": im[np.newaxis, :],
"im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis,:], "im_shape":
"scale_factor": np.array([im_scale_y, im_scale_x]).astype('float32'), np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
"scale_factor":
np.array([im_scale_y, im_scale_x]).reshape(-1)[np.newaxis, :],
}) })
self.raw_img = raw_imgs self.raw_img = raw_imgs
feed_dict = { feed_dict = {
"image": np.concatenate([x["image"] for x in imgs], axis=0), "image": np.concatenate(
"im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0), [x["image"] for x in imgs], axis=0),
"scale_factor": np.concatenate([x["scale_factor"] for x in imgs], axis=0) "im_shape": np.concatenate(
[x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate(
[x["scale_factor"] for x in imgs], axis=0)
} }
return feed_dict, False, None, "" return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id): def postprocess(self, input_dicts, fetch_dict, log_id):
boxes = self.img_postprocess(fetch_dict, visualize=False) boxes = self.img_postprocess(fetch_dict, visualize=False)
boxes.sort(key = lambda x: x["score"], reverse = True) boxes.sort(key=lambda x: x["score"], reverse=True)
boxes = filter(lambda x: x["score"] >= self.threshold, boxes[:self.max_det_results]) boxes = filter(lambda x: x["score"] >= self.threshold,
boxes[:self.max_det_results])
boxes = list(boxes) boxes = list(boxes)
for i in range(len(boxes)): for i in range(len(boxes)):
boxes[i]["bbox"][2] += boxes[i]["bbox"][0] - 1 boxes[i]["bbox"][2] += boxes[i]["bbox"][0] - 1
boxes[i]["bbox"][3] += boxes[i]["bbox"][1] - 1 boxes[i]["bbox"][3] += boxes[i]["bbox"][1] - 1
result = json.dumps(boxes) result = json.dumps(boxes)
res_dict = {"bbox_result": result, "image": self.raw_img} res_dict = {"bbox_result": result, "image": self.raw_img}
return res_dict, None, "" return res_dict, None, ""
class RecOp(Op): class RecOp(Op):
def init_op(self): def init_op(self):
self.seq = Sequential([ self.seq = Sequential([
BGR2RGB(), Resize((224, 224)), BGR2RGB(), Resize((224, 224)), Div(255),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
False), Transpose((2, 0, 1)) False), Transpose((2, 0, 1))
]) ])
index_dir = "../../drink_dataset_v1.0/index" index_dir = "../../drink_dataset_v1.0/index"
...@@ -102,10 +110,10 @@ class RecOp(Op): ...@@ -102,10 +110,10 @@ class RecOp(Op):
index_dir, "vector.index")), "vector.index not found ..." index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join( assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... " index_dir, "id_map.pkl")), "id_map.pkl not found ... "
self.searcher = faiss.read_index( self.searcher = faiss.read_index(
os.path.join(index_dir, "vector.index")) os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
self.id_map = pickle.load(fd) self.id_map = pickle.load(fd)
...@@ -121,24 +129,25 @@ class RecOp(Op): ...@@ -121,24 +129,25 @@ class RecOp(Op):
origin_img = cv2.imdecode(data, cv2.IMREAD_COLOR) origin_img = cv2.imdecode(data, cv2.IMREAD_COLOR)
dt_boxes = input_dict["bbox_result"] dt_boxes = input_dict["bbox_result"]
boxes = json.loads(dt_boxes) boxes = json.loads(dt_boxes)
boxes.append({"category_id": 0, boxes.append({
"score": 1.0, "category_id": 0,
"bbox": [0, 0, origin_img.shape[1], origin_img.shape[0]] "score": 1.0,
}) "bbox": [0, 0, origin_img.shape[1], origin_img.shape[0]]
})
self.det_boxes = boxes self.det_boxes = boxes
#construct batch images for rec #construct batch images for rec
imgs = [] imgs = []
for box in boxes: for box in boxes:
box = [int(x) for x in box["bbox"]] box = [int(x) for x in box["bbox"]]
im = origin_img[box[1]: box[3], box[0]: box[2]].copy() im = origin_img[box[1]:box[3], box[0]:box[2]].copy()
img = self.seq(im) img = self.seq(im)
imgs.append(img[np.newaxis, :].copy()) imgs.append(img[np.newaxis, :].copy())
input_imgs = np.concatenate(imgs, axis=0) input_imgs = np.concatenate(imgs, axis=0)
return {"x": input_imgs}, False, None, "" return {"x": input_imgs}, False, None, ""
def nms_to_rec_results(self, results, thresh = 0.1): def nms_to_rec_results(self, results, thresh=0.1):
filtered_results = [] filtered_results = []
x1 = np.array([r["bbox"][0] for r in results]).astype("float32") x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
y1 = np.array([r["bbox"][1] for r in results]).astype("float32") y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
...@@ -172,7 +181,7 @@ class RecOp(Op): ...@@ -172,7 +181,7 @@ class RecOp(Op):
np.sum(np.square(batch_features), axis=1, keepdims=True)) np.sum(np.square(batch_features), axis=1, keepdims=True))
batch_features = np.divide(batch_features, feas_norm) batch_features = np.divide(batch_features, feas_norm)
scores, docs = self.searcher.search(batch_features, self.return_k) scores, docs = self.searcher.search(batch_features, self.return_k)
results = [] results = []
for i in range(scores.shape[0]): for i in range(scores.shape[0]):
...@@ -182,17 +191,19 @@ class RecOp(Op): ...@@ -182,17 +191,19 @@ class RecOp(Op):
pred["rec_docs"] = self.id_map[docs[i][0]].split()[1] pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
pred["rec_scores"] = scores[i][0] pred["rec_scores"] = scores[i][0]
results.append(pred) results.append(pred)
#do nms #do nms
results = self.nms_to_rec_results(results, self.rec_nms_thresold) results = self.nms_to_rec_results(results, self.rec_nms_thresold)
return {"result": str(results)}, None, "" return {"result": str(results)}, None, ""
class RecognitionService(WebService): class RecognitionService(WebService):
def get_pipeline_response(self, read_op): def get_pipeline_response(self, read_op):
det_op = DetOp(name="det", input_ops=[read_op]) det_op = DetOp(name="det", input_ops=[read_op])
rec_op = RecOp(name="rec", input_ops=[det_op]) rec_op = RecOp(name="rec", input_ops=[det_op])
return rec_op return rec_op
product_recog_service = RecognitionService(name="recognition") product_recog_service = RecognitionService(name="recognition")
product_recog_service.prepare_pipeline_config("config.yml") product_recog_service.prepare_pipeline_config("config.yml")
product_recog_service.run_service() product_recog_service.run_service()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册