character.py 8.9 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import string
import re
from .check import check_config_params
import sys


class CharacterOps(object):
    """
    Convert between text-label and text-index.

    Args:
        config: config mapping (e.g. loaded from a yaml file). Keys read:
            'character_type', 'loss_type', 'max_text_length', and for the
            custom-dictionary types ("ch", 'japan', 'korean', 'french',
            'german') also 'character_dict_path' and, optionally,
            'use_space_char'.
    """

    def __init__(self, config):
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        self.max_text_len = config['max_text_length']
        if self.character_type == "en":
            # use the default dictionary (36 chars: digits + lowercase)
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        elif self.character_type in [
                "ch", 'japan', 'korean', 'french', 'german'
        ]:
            # use the custom dictionary: the file's lines (with line breaks
            # removed) are concatenated, and every resulting character
            # becomes one dictionary entry
            character_dict_path = config['character_dict_path']
            add_space = False
            if 'use_space_char' in config:
                add_space = config['use_space_char']
            with open(character_dict_path, "rb") as fin:
                # join instead of repeated += to avoid quadratic building
                self.character_str = "".join(
                    line.decode('utf-8').strip("\r\n") for line in fin)
            if add_space:
                self.character_str += " "
            dict_character = list(self.character_str)
        elif self.character_type == "en_sensitive":
            # same with ASTER setting (use 94 char).
            self.character_str = string.printable[:-6]
            dict_character = list(self.character_str)
        else:
            self.character_str = None
        # Report the offending character_type; character_str is always None
        # when this assertion fires, so it carries no information.
        assert self.character_str is not None, \
            "Nonsupport type of the character: {}".format(self.character_type)
        self.beg_str = "sos"
        self.end_str = "eos"
        # add start and end str: prepended for attention (so "sos" -> 0 and
        # "eos" -> 1), appended for srn (so "eos" is the last index)
        if self.loss_type == "attention":
            dict_character = [self.beg_str, self.end_str] + dict_character
        elif self.loss_type == "srn":
            dict_character = dict_character + [self.beg_str, self.end_str]
        # create char dict; for a duplicated entry the last index wins
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def encode(self, text):
        """
        Convert a text label into text indices.

        Args:
            text: text label of one image; it is iterated character by
                character (and lower-cased first when character_type is
                "en", whose dictionary has only lowercase letters).
        Return:
            np.ndarray with the dictionary index of every character of
            `text` that is in the dictionary; unknown characters are
            silently dropped.
        """
        if self.character_type == "en":
            text = text.lower()
        text_list = [self.dict[char] for char in text if char in self.dict]
        return np.array(text_list)

    def decode(self, text_index, is_remove_duplicate=False):
        """
        Convert text indices into a text label.

        Args:
            text_index: indexable sequence of indices for one image
            is_remove_duplicate: if True, an index equal to the one right
                before it is dropped (applied after ignored tokens are
                skipped, but compared against the raw previous index).
                The default is False.
        Return:
            text: the decoded string. For the attention loss the "sos" and
            "eos" indices are skipped; otherwise index len(self.character)
            is skipped (presumably the CTC blank - confirm with the model).
        """
        if self.loss_type == "attention":
            ignored_tokens = [
                self.get_beg_end_flag_idx("beg"),
                self.get_beg_end_flag_idx("end"),
            ]
        else:
            ignored_tokens = [self.get_char_num()]

        char_list = []
        for idx, char_idx in enumerate(text_index):
            if char_idx in ignored_tokens:
                continue
            if is_remove_duplicate and idx > 0 and \
                    text_index[idx - 1] == char_idx:
                continue
            char_list.append(self.character[int(char_idx)])
        return ''.join(char_list)

    def get_char_num(self):
        """
        Get the number of dictionary entries, including "sos"/"eos" when the
        loss type adds them.
        """
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        """
        Return the index of the start ("beg") or end ("end") token as a
        0-d np.ndarray. Only valid for the attention loss; any other loss
        type or flag name raises AssertionError.
        """
        if self.loss_type == "attention":
            if beg_or_end == "beg":
                idx = np.array(self.dict[self.beg_str])
            elif beg_or_end == "end":
                idx = np.array(self.dict[self.end_str])
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx"\
                    % beg_or_end
            return idx
        else:
            err = "error in get_beg_end_flag_idx when using the loss %s"\
                % (self.loss_type)
            assert False, err


def cal_predicts_accuracy(char_ops,
                          preds,
                          preds_lod,
                          labels,
                          labels_lod,
                          is_remove_duplicate=False):
    """
    Calculate prediction accuracy.

    Args:
        char_ops: CharacterOps (only its decode method is used)
        preds: preds result, text index; rows preds_lod[i]:preds_lod[i+1]
            belong to image i
        preds_lod: lod offsets of preds
        labels: label of input image, text index; rows
            labels_lod[i]:labels_lod[i+1] belong to image i
        labels_lod: lod offsets of labels
        is_remove_duplicate: Whether to remove duplicate characters,
                             The default is False
    Return:
        acc: fraction of images whose decoded prediction equals the decoded
            label (0.0 when there are no images)
        acc_num: The correct number of samples predicted
        img_num: The total sample number of the test set
    """
    acc_num = 0
    img_num = 0
    for ino in range(len(labels_lod) - 1):
        pred_seq = preds[preds_lod[ino]:preds_lod[ino + 1]].reshape(-1)
        preds_text = char_ops.decode(pred_seq, is_remove_duplicate)

        label_seq = labels[labels_lod[ino]:labels_lod[ino + 1]].reshape(-1)
        labels_text = char_ops.decode(label_seq, is_remove_duplicate)
        img_num += 1

        if preds_text == labels_text:
            acc_num += 1
    # An empty batch used to raise ZeroDivisionError; report 0.0 instead.
    acc = acc_num * 1.0 / img_num if img_num > 0 else 0.0
    return acc, acc_num, img_num

T
tink2123 已提交
188

T
tink2123 已提交
189
def cal_predicts_accuracy_srn(char_ops,
                              preds,
                              labels,
                              max_text_len,
                              is_debug=False):
    """
    Calculate sequence accuracy for SRN predictions.

    Args:
        char_ops: CharacterOps; only get_char_num() is used, and index
            get_char_num() - 1 is treated as the end token (CharacterOps
            appends "eos" last for the srn loss)
        preds: predicted indices accessed as preds[k][0], holding
            max_text_len consecutive rows per image
        labels: label indices with the same layout as preds
        max_text_len: number of rows per image
        is_debug: unused
    Return:
        acc: fraction of images predicted correctly (0.0 when there are none)
        acc_num: number of images predicted correctly
        img_num: number of images (preds.shape[0] // max_text_len)
    """
    acc_num = 0
    end_idx = int(char_ops.get_char_num() - 1)
    img_num = preds.shape[0] // max_text_len
    for i in range(img_num):
        offset = i * max_text_len

        # label = rows up to (not including) the first end token
        cur_label = []
        for j in range(max_text_len):
            if labels[offset + j] != end_idx:
                cur_label.append(labels[offset + j][0])
            else:
                break
        label_len = len(cur_label)

        # correct = prediction matches the label prefix and then either the
        # sequence is full or the prediction emits the end token right after
        prefix_ok = not any(preds[offset + j][0] != cur_label[j]
                            for j in range(label_len))
        if not prefix_ok:
            continue
        if label_len == max_text_len or \
                preds[offset + label_len][0] == end_idx:
            acc_num += 1
    # An empty input used to raise ZeroDivisionError; report 0.0 instead.
    acc = acc_num * 1.0 / img_num if img_num > 0 else 0.0
    return acc, acc_num, img_num


def convert_rec_attention_infer_res(preds):
    """
    Convert recognition attention predict result with lod information.

    Each row of `preds` loses its first column and is truncated just before
    the second occurrence of the value 1 (the index CharacterOps assigns to
    "eos" for the attention loss). With at most one such occurrence, the
    whole row after column 0 is kept.

    Args:
        preds: the output of the model, a 2-D array (one row per image)
    Return:
        convert_ids: array of shape (total_len, 1) with all kept indices
        target_lod: lod offsets of the predicted results, starting at 0
    """
    img_num = preds.shape[0]
    target_lod = [0]
    convert_ids = []
    for ino in range(img_num):
        end_pos = np.where(preds[ino, :] == 1)[0]
        if len(end_pos) <= 1:
            text_list = preds[ino, 1:]
        else:
            text_list = preds[ino, 1:end_pos[1]]
        target_lod.append(target_lod[-1] + len(text_list))
        # extend in place; `list + list` in the loop was quadratic
        convert_ids.extend(text_list)
    convert_ids = np.array(convert_ids).reshape((-1, 1))
    return convert_ids, target_lod


def convert_rec_label_to_lod(ori_labels):
    """
    Flatten per-image label sequences into one column with lod offsets.

    Args:
        ori_labels: sequence of per-image label index sequences
    Return:
        convert_ids: array of shape (total_len, 1) with all labels in order
        target_lod: lod offsets, starting at 0, one extra entry per image
    """
    target_lod = [0]
    convert_ids = []
    for label in ori_labels:
        target_lod.append(target_lod[-1] + len(label))
        # extend in place; `list + list` in the loop was quadratic
        convert_ids.extend(label)
    convert_ids = np.array(convert_ids).reshape((-1, 1))
    return convert_ids, target_lod