ocr_rpc_server.py 5.0 KB
Newer Older
W
wangjiawei04 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
import time
import re
import base64
from clas_rpc_server import TextClassifierHelper
from det_rpc_server import TextDetectorHelper
from rec_rpc_server import TextRecognizerHelper
W
wangjiawei04 已提交
27
from tools.infer.predict_system import TextSystem, sorted_boxes
W
wangjiawei04 已提交
28
import copy
29
from params import read_params
W
wangjiawei04 已提交
30

31
# Runtime parameters (as returned by read_params()) shared by every
# component defined below.
global_args = read_params()

# Pick the serving backend flavour that matches the use_gpu flag: the CPU
# package by default, the GPU package when GPU serving is requested.
if not global_args.use_gpu:
    from paddle_serving_server.web_service import WebService
else:
    from paddle_serving_server_gpu.web_service import WebService


class TextSystemHelper(TextSystem):
    """OCR pipeline whose detection / angle-classification run over RPC.

    Detection (and, when ``use_angle_cls`` is set, angle classification) is
    executed by calling separate Paddle Serving RPC services. Recognition is
    not executed here: ``preprocess`` returns the recognizer feed plus the
    names to fetch, so the hosting service can run that model itself, and
    ``postprocess`` decodes the resulting outputs.

    ``TextSystem.__init__`` is intentionally not called; presumably to avoid
    setting up the local predictors it creates (confirm against TextSystem).
    """

    # RPC endpoints of the detection and angle-classification services.
    # Class attributes so a deployment can override them in a subclass.
    DET_ENDPOINT = "127.0.0.1:9293"
    CLS_ENDPOINT = "127.0.0.1:9294"

    def __init__(self, args):
        """Create the pre/post-processing helpers and connect the RPC clients.

        Args:
            args: parsed parameters. Reads ``use_angle_cls``,
                ``det_client_dir`` and (only when ``use_angle_cls`` is true)
                ``cls_client_dir``; also passed to each helper.
        """
        self.text_detector = TextDetectorHelper(args)
        self.text_recognizer = TextRecognizerHelper(args)
        self.use_angle_cls = args.use_angle_cls
        # Boxes from the most recent preprocess() call; the web service pairs
        # them with the recognition results for the same request.
        self.dt_boxes = []
        if self.use_angle_cls:
            self.clas_client = Client()
            self.clas_client.load_client_config(
                os.path.join(args.cls_client_dir, "serving_client_conf.prototxt"))
            self.clas_client.connect([self.CLS_ENDPOINT])
            self.text_classifier = TextClassifierHelper(args)
        self.det_client = Client()
        self.det_client.load_client_config(
            os.path.join(args.det_client_dir, "serving_client_conf.prototxt"))
        self.det_client.connect([self.DET_ENDPOINT])
        # Output names the caller must fetch when running the feed that
        # preprocess() returns; postprocess() expects them in this order.
        self.fetch = [
            "save_infer_model/scale_0.tmp_0",
            "save_infer_model/scale_1.tmp_0",
        ]

    def preprocess(self, img):
        """Detect text, crop each box, optionally re-orient, build rec feed.

        Args:
            img: decoded image handed to the detector helper and to
                ``get_rotate_crop_image`` (presumably an HxWx3 array from
                ``cv2.imdecode`` -- confirm).

        Returns:
            ``(feed, fetch, tmp_args)`` for the recognition step, where
            ``fetch`` is ``self.fetch``. If detection yields ``None``,
            returns ``(None, None, None)`` so that callers unpacking three
            values do not crash.
        """
        feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
        fetch_map = self.det_client.predict(feed, fetch)
        outputs = [fetch_map[name] for name in fetch]
        dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
        if dt_boxes is None:
            # Never let boxes from an earlier request leak into this one.
            self.dt_boxes = []
            return None, None, None
        dt_boxes = sorted_boxes(dt_boxes)
        self.dt_boxes = dt_boxes
        # Crop from copies so any in-place change made while cropping cannot
        # alter the boxes stored in self.dt_boxes.
        img_crop_list = [
            self.get_rotate_crop_image(img, copy.deepcopy(box))
            for box in dt_boxes
        ]
        if self.use_angle_cls:
            feed, fetch, self.tmp_args = self.text_classifier.preprocess(
                img_crop_list)
            fetch_map = self.clas_client.predict(feed, fetch)
            outputs = [fetch_map[name] for name in self.text_classifier.fetch]
            # The classifier's postprocess reads the ".lod" outputs from
            # tmp_args, so forward them there.
            for name, value in fetch_map.items():
                if ".lod" in name:
                    self.tmp_args[name] = value
            img_crop_list, _ = self.text_classifier.postprocess(
                outputs, self.tmp_args)
        feed, _, self.tmp_args = self.text_recognizer.preprocess(img_crop_list)
        return feed, self.fetch, self.tmp_args

    def postprocess(self, outputs, args):
        """Decode recognition outputs via the recognizer helper."""
        return self.text_recognizer.postprocess(outputs, args)


class OCRService(WebService):
    """OCR web/RPC service built on ``TextSystemHelper``.

    Detection and optional angle classification go through the helper's RPC
    clients. The recognition model loaded into this service consumes the
    feed the helper prepares. Each response is a list of per-box dicts.
    """

    def init_rec(self):
        """Build the OCR pipeline helper from the module-level ``global_args``."""
        self.text_system = TextSystemHelper(global_args)

    def preprocess(self, feed=[], fetch=[]):
        """Decode the request image and produce the recognition feed.

        Args:
            feed: request payload. Only ``feed[0]["image"]`` (a
                base64-encoded image string) is read.
            fetch: ignored; the fetch list comes from the text system.

        Returns:
            ``(feed, fetch)`` to run on the recognition model.

        Raises:
            ValueError: if the payload does not decode to an image.
        """
        # TODO: to handle batch rec images
        raw = base64.b64decode(feed[0]["image"].encode('utf8'))
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        buf = np.frombuffer(raw, dtype=np.uint8)
        im = cv2.imdecode(buf, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imdecode signals undecodable data by returning None; fail
            # here with a clear message instead of deep inside detection.
            raise ValueError("failed to decode the input image")
        feed, fetch, self.tmp_args = self.text_system.preprocess(im)
        return feed, fetch

    def postprocess(self, feed={}, fetch=[], fetch_map=None):
        """Turn recognition outputs into a list of per-box results.

        Args:
            feed: unused.
            fetch: unused.
            fetch_map: outputs of the recognition model, keyed by name. Must
                contain every name in ``self.text_system.fetch``. Any key
                containing ".lod" is forwarded to the decoder.

        Returns:
            list of dicts with keys ``text_region`` (detected box via
            ``.tolist()``), ``text`` and ``confidence`` (float), in the order
            of the detected boxes.
        """
        outputs = [fetch_map[name] for name in self.text_system.fetch]
        for name, value in fetch_map.items():
            if ".lod" in name:
                self.tmp_args[name] = value
        rec_res = self.text_system.postprocess(outputs, self.tmp_args)
        return [
            {
                "text_region": box.tolist(),
                "text": result[0],
                "confidence": float(result[1]),
            }
            for box, result in zip(self.text_system.dt_boxes, rec_res)
        ]


if __name__ == "__main__":
    # Recognition runs inside this service; detection / classification are
    # reached over RPC by the helper that init_rec() builds.
    ocr_service = OCRService(name="ocr")
    ocr_service.load_model_config(global_args.rec_server_dir)
    ocr_service.init_rec()
    # Device-specific keyword arguments for prepare_server.
    if global_args.use_gpu:
        device_opts = {"device": "gpu", "gpuid": 0}
    else:
        device_opts = {"device": "cpu"}
    ocr_service.prepare_server(workdir="workdir", port=9292, **device_opts)
    ocr_service.run_rpc_service()
    ocr_service.run_web_service()