# ocr_reader.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import math
import re
import string
import sys

import cv2
import numpy as np
from paddle_serving_app.reader import Sequential, Resize, Transpose, Div, Normalize


class CharacterOps(object):
    """Convert between text labels and integer index sequences.

    The character set is selected by ``config['character_type']``:

    * ``"en"``: digits followed by lowercase ASCII letters (36 characters).
    * ``"ch"``: characters read from the UTF-8 text file named by
      ``config['character_dict_path']``; the lines are concatenated with
      their line terminators removed.
    * ``"en_sensitive"``: ``string.printable[:-6]`` (94 characters).

    When ``config['loss_type']`` is ``"attention"`` the two marker entries
    ``"sos"`` and ``"eos"`` are placed in front of the character list, so
    they take indices 0 and 1.
    """

    def __init__(self, config):
        """Build the character list and the char -> index lookup table.

        Args:
            config: mapping with the keys ``'character_type'`` and
                ``'loss_type'``; ``'character_dict_path'`` is read only
                when the character type is ``"ch"``.

        Raises:
            AssertionError: if the character type is not supported.
        """
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        if self.character_type == "en":
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        elif self.character_type == "ch":
            character_dict_path = config['character_dict_path']
            # Read as bytes and decode explicitly so the result does not
            # depend on the platform's default text encoding.
            with open(character_dict_path, "rb") as fin:
                self.character_str = "".join(
                    raw.decode('utf-8').strip("\r\n") for raw in fin)
        elif self.character_type == "en_sensitive":
            # same with ASTER setting (use 94 char).
            self.character_str = string.printable[:-6]
        else:
            self.character_str = None
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is unchanged.
            raise AssertionError(
                "Nonsupport type of the character: {}".format(
                    self.character_type))
        dict_character = list(self.character_str)
        self.beg_str = "sos"
        self.end_str = "eos"
        if self.loss_type == "attention":
            dict_character = [self.beg_str, self.end_str] + dict_character
        # If a character occurs more than once, the last position wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def encode(self, text):
        """Convert a text label into an array of character indices.

        For the ``"en"`` character type the text is lowercased first.
        Characters that are not in the dictionary are silently dropped.

        Args:
            text: the label string to encode.

        Returns:
            A numpy array holding the index of every known character, in
            order of appearance.
        """
        if self.character_type == "en":
            text = text.lower()
        return np.array([self.dict[char] for char in text if char in self.dict])

    def decode(self, text_index, is_remove_duplicate=False):
        """Convert a sequence of character indices back into a string.

        Args:
            text_index: indexable sequence of character indices.
            is_remove_duplicate: when true, an index equal to the one
                immediately before it in ``text_index`` is skipped.

        Returns:
            The decoded string. With the attention loss the ``sos``/``eos``
            indices are skipped; otherwise the index equal to the number of
            characters (one past the last valid index) is skipped.
        """
        if self.loss_type == "attention":
            ignored_tokens = [
                self.get_beg_end_flag_idx("beg"),
                self.get_beg_end_flag_idx("end"),
            ]
        else:
            ignored_tokens = [self.get_char_num()]

        char_list = []
        for pos, index in enumerate(text_index):
            if index in ignored_tokens:
                continue
            # The comparison is against the raw previous entry, even when
            # that entry was itself an ignored token.
            if is_remove_duplicate and pos > 0 \
                    and text_index[pos - 1] == index:
                continue
            char_list.append(self.character[index])
        return ''.join(char_list)

    def get_char_num(self):
        """Return the number of entries in the character list."""
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the ``sos`` or ``eos`` marker.

        Args:
            beg_or_end: ``"beg"`` for the start marker, ``"end"`` for the
                end marker.

        Returns:
            A 0-d numpy array holding the marker's index.

        Raises:
            AssertionError: if the loss type is not ``"attention"`` or
                ``beg_or_end`` is neither ``"beg"`` nor ``"end"``.
        """
        if self.loss_type != "attention":
            err = "error in get_beg_end_flag_idx when using the loss %s"\
                % (self.loss_type)
            raise AssertionError(err)
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise AssertionError("Unsupport type %s in get_beg_end_flag_idx"
                             % beg_or_end)


class OCRReader(object):
    """Pre- and post-processing helper for a text-recognition model.

    Settings come from command-line options (see ``parse_args``). The
    reader resizes and pads cropped text images into fixed-height arrays
    and turns the model's index output back into strings via
    ``CharacterOps``.
    """

    def __init__(self):
        """Read the command-line options and build the character decoder."""
        args = self.parse_args()
        # "3, 32, 320" -> [3, 32, 320]; int() tolerates the spaces.
        self.rec_image_shape = [
            int(v) for v in args.rec_image_shape.split(",")
        ]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        char_ops_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "loss_type": 'ctc',
        }
        self.char_ops = CharacterOps(char_ops_params)

    def parse_args(self):
        """Parse the recognition options from ``sys.argv``.

        Arguments this parser does not know are ignored rather than
        treated as an error, so the reader can be created inside a program
        that has command-line arguments of its own.

        Returns:
            An ``argparse.Namespace`` with ``rec_algorithm``,
            ``rec_model_dir``, ``rec_image_shape``, ``rec_char_type``,
            ``rec_batch_num`` and ``rec_char_dict_path``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--rec_algorithm", type=str, default='CRNN')
        parser.add_argument("--rec_model_dir", type=str)
        parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
        parser.add_argument("--rec_char_type", type=str, default='ch')
        parser.add_argument("--rec_batch_num", type=int, default=1)
        parser.add_argument(
            "--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt")
        known_args, _unknown = parser.parse_known_args()
        return known_args

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize one image, normalize it and right-pad it with zeros.

        Args:
            img: image array; ``img.shape[0]`` is used as its height and
                ``img.shape[1]`` as its width.
            max_wh_ratio: largest width/height ratio in the batch. Used
                only for the ``"ch"`` character type, where it sets the
                padded width.

        Returns:
            A float32 array of shape ``(imgC, imgH, imgW)`` whose first
            ``resized_w`` columns hold the transformed image; the rest is
            zero.
        """
        imgC, imgH, imgW = self.rec_image_shape
        if self.character_type == "ch":
            # NOTE(review): the factor 32 is hard-coded rather than taken
            # from imgH; the two agree for the default image shape.
            imgW = int(32 * max_wh_ratio)
        h, w = img.shape[0], img.shape[1]
        ratio = w / float(h)
        # Keep the aspect ratio at height imgH, but never exceed imgW.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))

        # Presumably yields a channel-first array scaled to [-1, 1]; the
        # exact semantics are defined by paddle_serving_app.reader.
        seq = Sequential([
            Resize(imgH, resized_w), Transpose((2, 0, 1)), Div(255),
            Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], True)
        ])
        resized_image = seq(img)
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def preprocess(self, img_list):
        """Normalize a list of images and return the first result.

        Every image is passed through ``resize_norm_img`` using the largest
        width/height ratio found in ``img_list``; the results are
        concatenated into one batch array.

        Args:
            img_list: non-empty sequence of image arrays.

        Returns:
            The first element of the batch, i.e. the array produced for
            ``img_list[0]``.
        """
        max_wh_ratio = 0
        for img in img_list:
            h, w = img.shape[0:2]
            max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
        norm_img_batch = np.concatenate([
            self.resize_norm_img(img, max_wh_ratio)[np.newaxis, :]
            for img in img_list
        ])
        return norm_img_batch[0]

    def postprocess(self, outputs, with_score=False):
        """Decode the model outputs into text, optionally with a score.

        Args:
            outputs: mapping that provides
                ``"ctc_greedy_decoder_0.tmp_0"`` (rows whose first element
                is a character index) and
                ``"ctc_greedy_decoder_0.tmp_0.lod"`` (offsets; consecutive
                entries delimit one sequence's rows). With ``with_score``
                it must also provide ``"softmax_0.tmp_0"`` and
                ``"softmax_0.tmp_0.lod"``.
            with_score: when true, a score is attached to every text.

        Returns:
            A list with one entry per sequence: ``[text]``, or
            ``[text, score]`` when ``with_score`` is true.
        """
        rec_res = []
        rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
        rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
        if with_score:
            predict_lod = outputs["softmax_0.tmp_0.lod"]
        for rno in range(len(rec_idx_lod) - 1):
            beg = rec_idx_lod[rno]
            end = rec_idx_lod[rno + 1]
            rec_idx_tmp = [rec_idx_batch[i][0] for i in range(beg, end)]
            preds_text = self.char_ops.decode(rec_idx_tmp)
            if not with_score:
                rec_res.append([preds_text])
                continue
            beg = predict_lod[rno]
            end = predict_lod[rno + 1]
            probs = outputs["softmax_0.tmp_0"][beg:end, :]
            ind = np.argmax(probs, axis=1)
            # The last class is treated as the blank label and excluded.
            blank = probs.shape[1]
            valid_ind = np.where(ind != (blank - 1))[0]
            # Mean of the winning probabilities at non-blank rows; this is
            # nan when every row is blank.
            score = np.mean(probs[valid_ind, ind[valid_ind]])
            rec_res.append([preds_text, score])
        return rec_res