# ocr_rpc_server.py — recovered from a repository blame view (commits by wangjiawei04).
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle_serving_client import Client
import cv2
import sys
import numpy as np
import os
import time
import re
import base64
from clas_rpc_server import TextClassifierHelper
from det_rpc_server import TextDetectorHelper
from rec_rpc_server import TextRecognizerHelper
from tools.infer.predict_system import TextSystem, sorted_boxes
import copy
from params import read_params
# Parameters are parsed once at import time; the helper, the service and the
# __main__ entry point below all read them from this module-level object.
global_args = read_params()

# Choose the GPU or CPU build of Paddle Serving's WebService according to the
# parsed `use_gpu` flag.
if global_args.use_gpu:
    from paddle_serving_server_gpu.web_service import WebService
else:
    from paddle_serving_server.web_service import WebService


class TextSystemHelper(TextSystem):
    """OCR pipeline that runs detection and, optionally, angle
    classification through Paddle Serving RPC clients. It then prepares the
    recognition request for the caller to execute.

    ``TextSystem.__init__`` is not called. Only the helpers and clients set
    up here are used, plus ``get_rotate_crop_image`` (presumably inherited
    from ``TextSystem``).

    NOTE(review): per-request results (``tmp_args``, ``dt_boxes``) are kept
    on the instance and read back by the service's postprocess. One
    instance should therefore not serve overlapping requests — confirm how
    the serving framework dispatches calls.
    """

    def __init__(self, args):
        """Build the stage helpers and connect the RPC clients.

        Args:
            args: parsed parameters (see ``read_params``). ``use_angle_cls``
                enables the angle-classification stage.
        """
        self.text_detector = TextDetectorHelper(args)
        self.text_recognizer = TextRecognizerHelper(args)
        self.use_angle_cls = args.use_angle_cls
        if self.use_angle_cls:
            # Angle-classification model, served locally on port 9294.
            self.clas_client = Client()
            self.clas_client.load_client_config(
                "cls_infer_client/serving_client_conf.prototxt")
            self.clas_client.connect(["127.0.0.1:9294"])
            self.text_classifier = TextClassifierHelper(args)
        # Detection model, served locally on port 9293.
        self.det_client = Client()
        self.det_client.load_client_config(
            "det_infer_client/serving_client_conf.prototxt")
        self.det_client.connect(["127.0.0.1:9293"])
        # Output variables requested together with the recognizer feed that
        # preprocess() returns.
        self.fetch = ["save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"]
        # Sorted boxes from the most recent preprocess() call.
        self.dt_boxes = []

    def preprocess(self, img):
        """Detect text boxes in ``img`` and build the recognition request.

        Steps:
        1. Run detection over RPC, then sort the boxes and crop each one.
        2. If enabled, run angle classification over RPC on the crops.
        3. Hand the crops to the recognizer's preprocess.

        Args:
            img: decoded image, as produced by ``cv2.imdecode`` in the
                service.

        Returns:
            ``(feed, fetch, tmp_args)`` for the recognition request, or
            ``(None, None, None)`` when detection yields no boxes.
        """
        feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
        fetch_map = self.det_client.predict(feed, fetch)
        outputs = [fetch_map[x] for x in fetch]
        dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
        if dt_boxes is None:
            # Keep the 3-tuple shape the caller unpacks. Also drop boxes from
            # an earlier request so they are not reported again.
            self.dt_boxes = []
            return None, None, None
        dt_boxes = sorted_boxes(dt_boxes)
        self.dt_boxes = dt_boxes
        # Crop from a deep copy so the stored boxes stay intact in case the
        # crop helper modifies its argument.
        img_crop_list = [
            self.get_rotate_crop_image(img, copy.deepcopy(box))
            for box in dt_boxes
        ]
        if self.use_angle_cls:
            feed, fetch, self.tmp_args = self.text_classifier.preprocess(
                img_crop_list)
            fetch_map = self.clas_client.predict(feed, fetch)
            outputs = [fetch_map[x] for x in self.text_classifier.fetch]
            # Forward ".lod" entries to the classifier's postprocess via tmp_args.
            for key, value in fetch_map.items():
                if ".lod" in key:
                    self.tmp_args[key] = value
            img_crop_list, _ = self.text_classifier.postprocess(outputs,
                                                                self.tmp_args)
        # The recognizer's own fetch list is ignored; self.fetch is returned.
        feed, _, self.tmp_args = self.text_recognizer.preprocess(
            img_crop_list)
        return feed, self.fetch, self.tmp_args

    def postprocess(self, outputs, args):
        """Delegate decoding of recognition outputs to the recognizer helper."""
        return self.text_recognizer.postprocess(outputs, args)


class OCRService(WebService):
    """Paddle Serving web service that runs the full OCR pipeline.

    The model loaded into the service (``rec_model_dir`` in ``__main__``)
    serves recognition. Detection and optional angle classification are
    delegated to ``TextSystemHelper``.

    NOTE(review): ``tmp_args`` and the helper's ``dt_boxes`` carry state from
    preprocess to postprocess on shared objects. Concurrent requests could
    interleave — confirm the framework serialises calls.
    """

    def init_rec(self):
        """Create the OCR pipeline helper from the parsed ``global_args``."""
        self.text_system = TextSystemHelper(global_args)

    def preprocess(self, feed=None, fetch=None):
        """Decode the request image and build the recognition feed.

        Args:
            feed: request payload list; ``feed[0]["image"]`` holds a
                base64-encoded image. Only the first entry is used.
            fetch: ignored; the fetch list comes from the pipeline helper.

        Returns:
            ``(feed, fetch)`` for the recognition model.

        Raises:
            ValueError: if the image bytes cannot be decoded.
        """
        if feed is None:
            feed = []
        # TODO: handle batches of images (only feed[0] is used today).
        data = base64.b64decode(feed[0]["image"].encode('utf8'))
        # frombuffer replaces np.fromstring, whose binary mode is deprecated.
        data = np.frombuffer(data, np.uint8)
        im = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if im is None:
            # imdecode signals undecodable input by returning None.
            raise ValueError("failed to decode the request image")
        # NOTE(review): if detection finds nothing, feed/fetch may be None
        # here; how WebService handles that is not visible in this file.
        feed, fetch, self.tmp_args = self.text_system.preprocess(im)
        return feed, fetch

    def postprocess(self, feed=None, fetch=None, fetch_map=None):
        """Convert recognition output into the response dictionary.

        Args:
            feed: unused.
            fetch: unused.
            fetch_map: recognition-model outputs keyed by variable name.

        Returns:
            dict with three keys:
            - ``pred_text``: recognized strings.
            - ``score``: scores converted to strings.
            - ``pos``: the boxes from the last preprocess call, as nested lists.
        """
        outputs = [fetch_map[x] for x in self.text_system.fetch]
        # Forward ".lod" entries (presumably Paddle LoD offsets) to the
        # recognizer's postprocess via tmp_args.
        for key, value in fetch_map.items():
            if ".lod" in key:
                self.tmp_args[key] = value
        rec_res = self.text_system.postprocess(outputs, self.tmp_args)
        return {
            "pred_text": [x[0] for x in rec_res],
            "score": [str(x[1]) for x in rec_res],
            "pos": [box.tolist() for box in self.text_system.dt_boxes],
        }


if __name__ == "__main__":
    # Create the service with the recognition model and the OCR pipeline helper.
    ocr_service = OCRService(name="ocr")
    ocr_service.load_model_config(global_args.rec_model_dir)
    ocr_service.init_rec()

    # Options common to both devices. GPU mode additionally selects card 0.
    server_options = {"workdir": "workdir", "port": 9292}
    if global_args.use_gpu:
        server_options.update(device="gpu", gpuid=0)
    else:
        server_options["device"] = "cpu"
    ocr_service.prepare_server(**server_options)

    # Start the RPC service, then the web service.
    ocr_service.run_rpc_service()
    ocr_service.run_web_service()