recognition_web_service.py 7.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from paddle_serving_server.web_service import WebService, Op
15 16
import logging
import numpy as np
17 18 19 20
import sys
import cv2
from paddle_serving_app.reader import *
import base64
21 22 23
import os
import faiss
import pickle
24
import json
25

S
stephon 已提交
26

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
class DetOp(Op):
    def init_op(self):
        self.img_preprocess = Sequential([
            BGR2RGB(), Div(255.0),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
            Resize((640, 640)), Transpose((2, 0, 1))
        ])

        self.img_postprocess = RCNNPostprocess("label_list.txt", "output")
        self.threshold = 0.2
        self.max_det_results = 5

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_x: the resize ratio of X
            im_scale_y: the resize ratio of Y
        """
        target_size = [640, 640]
        origin_shape = im.shape[:2]
        resize_h, resize_w = target_size
        im_scale_y = resize_h / float(origin_shape[0])
        im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x

    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
        imgs = []
        raw_imgs = []
        for key in input_dict.keys():
            data = base64.b64decode(input_dict[key].encode('utf8'))
            raw_imgs.append(data)
            data = np.fromstring(data, np.uint8)
            raw_im = cv2.imdecode(data, cv2.IMREAD_COLOR)

            im_scale_y, im_scale_x = self.generate_scale(raw_im)
            im = self.img_preprocess(raw_im)
S
stephon 已提交
66

S
stephon 已提交
67
            im_shape = np.array(im.shape[1:]).reshape(-1)
S
stephon 已提交
68
            scale_factor = np.array([im_scale_y, im_scale_x]).reshape(-1)
69
            imgs.append({
S
stephon 已提交
70
                "image": im[np.newaxis, :],
S
stephon 已提交
71 72
                "im_shape": im_shape[np.newaxis, :],
                "scale_factor": scale_factor[np.newaxis, :],
73 74 75 76
            })
        self.raw_img = raw_imgs

        feed_dict = {
S
stephon 已提交
77 78 79 80 81 82
            "image": np.concatenate(
                [x["image"] for x in imgs], axis=0),
            "im_shape": np.concatenate(
                [x["im_shape"] for x in imgs], axis=0),
            "scale_factor": np.concatenate(
                [x["scale_factor"] for x in imgs], axis=0)
83
        }
S
stephon 已提交
84
        return feed_dict, False, None, ""
85

S
stephon 已提交
86
    def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
87
        boxes = self.img_postprocess(fetch_dict, visualize=False)
S
stephon 已提交
88 89 90
        boxes.sort(key=lambda x: x["score"], reverse=True)
        boxes = filter(lambda x: x["score"] >= self.threshold,
                       boxes[:self.max_det_results])
91 92 93 94 95 96
        boxes = list(boxes)
        for i in range(len(boxes)):
            boxes[i]["bbox"][2] += boxes[i]["bbox"][0] - 1
            boxes[i]["bbox"][3] += boxes[i]["bbox"][1] - 1
        result = json.dumps(boxes)
        res_dict = {"bbox_result": result, "image": self.raw_img}
S
stephon 已提交
97 98
        return res_dict, None, ""

99 100

class RecOp(Op):
101 102
    def init_op(self):
        self.seq = Sequential([
S
stephon 已提交
103 104 105
            BGR2RGB(), Resize((224, 224)), Div(255),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
                      False), Transpose((2, 0, 1))
106 107
        ])

B
Bin Lu 已提交
108
        index_dir = "../../drink_dataset_v1.0/index"
109 110 111 112
        assert os.path.exists(os.path.join(
            index_dir, "vector.index")), "vector.index not found ..."
        assert os.path.exists(os.path.join(
            index_dir, "id_map.pkl")), "id_map.pkl not found ... "
S
stephon 已提交
113

B
Bin Lu 已提交
114
        self.searcher = faiss.read_index(
115
            os.path.join(index_dir, "vector.index"))
S
stephon 已提交
116

117 118 119
        with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
            self.id_map = pickle.load(fd)

120 121
        self.rec_nms_thresold = 0.05
        self.rec_score_thres = 0.5
122 123
        self.feature_normalize = True
        self.return_k = 1
124

125 126
    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
127 128 129 130 131
        raw_img = input_dict["image"][0]
        data = np.frombuffer(raw_img, np.uint8)
        origin_img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        dt_boxes = input_dict["bbox_result"]
        boxes = json.loads(dt_boxes)
S
stephon 已提交
132 133 134 135 136
        boxes.append({
            "category_id": 0,
            "score": 1.0,
            "bbox": [0, 0, origin_img.shape[1], origin_img.shape[0]]
        })
137 138 139
        self.det_boxes = boxes

        #construct batch images for rec
140
        imgs = []
141 142
        for box in boxes:
            box = [int(x) for x in box["bbox"]]
S
stephon 已提交
143
            im = origin_img[box[1]:box[3], box[0]:box[2]].copy()
144 145
            img = self.seq(im)
            imgs.append(img[np.newaxis, :].copy())
146

147
        input_imgs = np.concatenate(imgs, axis=0)
S
stephon 已提交
148
        return {"x": input_imgs}, False, None, ""
149

S
stephon 已提交
150
    def nms_to_rec_results(self, results, thresh=0.1):
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
        filtered_results = []
        x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
        y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
        x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
        y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
        scores = np.array([r["rec_scores"] for r in results])

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        while order.size > 0:
            i = order[0]
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]
            filtered_results.append(results[i])
        return filtered_results
175

S
stephon 已提交
176
    def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
177 178 179 180 181 182 183
        batch_features = fetch_dict["features"]

        if self.feature_normalize:
            feas_norm = np.sqrt(
                np.sum(np.square(batch_features), axis=1, keepdims=True))
            batch_features = np.divide(batch_features, feas_norm)

S
stephon 已提交
184
        scores, docs = self.searcher.search(batch_features, self.return_k)
185

186 187 188 189
        results = []
        for i in range(scores.shape[0]):
            pred = {}
            if scores[i][0] >= self.rec_score_thres:
190
                pred["bbox"] = [int(x) for x in self.det_boxes[i]["bbox"]]
191 192 193
                pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
                pred["rec_scores"] = scores[i][0]
                results.append(pred)
S
stephon 已提交
194

195 196 197
        #do nms
        results = self.nms_to_rec_results(results, self.rec_nms_thresold)
        return {"result": str(results)}, None, ""
198

S
stephon 已提交
199

200
class RecognitionService(WebService):
201
    def get_pipeline_response(self, read_op):
202 203 204
        det_op = DetOp(name="det", input_ops=[read_op])
        rec_op = RecOp(name="rec", input_ops=[det_op])
        return rec_op
205

S
stephon 已提交
206

207 208 209
product_recog_service = RecognitionService(name="recognition")
product_recog_service.prepare_pipeline_config("config.yml")
product_recog_service.run_service()