character.py 7.5 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import string
import re
from .check import check_config_params
import sys


class CharacterOps(object):
    """
    Convert between text-label and text-index

    Args:
        config: config from yaml file. Keys read here:
            'character_type': "en", "ch" or "en_sensitive";
            'loss_type': compared against "attention" to decide whether
                the "sos"/"eos" markers are prepended to the dictionary;
            'character_dict_path': only read when character_type is "ch".
    """

    def __init__(self, config):
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        if self.character_type == "en":
            # use the default dictionary(36 char)
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        elif self.character_type == "ch":
            # use the custom dictionary: the decoded lines of the file are
            # concatenated, so every character becomes one dictionary entry
            character_dict_path = config['character_dict_path']
            with open(character_dict_path, "rb") as fin:
                # Strip "\r" as well as "\n" so that dictionaries saved with
                # Windows (CRLF) line endings do not inject "\r" entries and
                # shift every following index.
                self.character_str = "".join(
                    line.decode('utf-8').strip("\r\n") for line in fin)
        elif self.character_type == "en_sensitive":
            # same with ASTER setting (use 94 char).
            self.character_str = string.printable[:-6]
        else:
            self.character_str = None
        # Report the offending character_type (character_str is always None
        # on this path, so formatting it would hide the real problem).
        assert self.character_str is not None, \
            "Nonsupport type of the character: {}".format(self.character_type)
        dict_character = list(self.character_str)

        self.beg_str = "sos"
        self.end_str = "eos"
        # add start and end str for attention (they take indices 0 and 1)
        if self.loss_type == "attention":
            dict_character = [self.beg_str, self.end_str] + dict_character
        # create char dict; a repeated character keeps its last index
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def encode(self, text):
        """
        convert text-label into text-index.

        Args:
            text: the text label of one image (a str)

        Return:
            text: np.ndarray with the dictionary index of every character
                  of `text`; characters missing from the dictionary are
                  silently dropped.
        """
        # Ignore capital
        if self.character_type == "en":
            text = text.lower()
        return np.array([self.dict[char] for char in text if char in self.dict])

    def decode(self, text_index, is_remove_duplicate=False):
        """
        convert text-index into text-label.

        Args:
            text_index: index sequence of one image
            is_remove_duplicate: Whether to skip an index equal to the one
                                 right before it, The default is False
        Return:
            text: text label (str)
        """
        if self.loss_type == "attention":
            # the sos/eos markers carry no text
            ignored_tokens = [
                self.get_beg_end_flag_idx("beg"),
                self.get_beg_end_flag_idx("end")
            ]
        else:
            # the index one past the last dictionary entry is skipped
            # (presumably the CTC blank label -- TODO confirm)
            ignored_tokens = [self.get_char_num()]

        char_list = []
        for idx in range(len(text_index)):
            if text_index[idx] in ignored_tokens:
                continue
            # compares with the previous raw index, including ignored ones
            if is_remove_duplicate and idx > 0 and \
                    text_index[idx - 1] == text_index[idx]:
                continue
            char_list.append(self.character[text_index[idx]])
        return ''.join(char_list)

    def get_char_num(self):
        """
        Get character num (dictionary size, including the "sos"/"eos"
        markers when loss_type is "attention")
        """
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        """
        Get the dictionary index of the "sos" ("beg") or "eos" ("end") marker.

        Args:
            beg_or_end: "beg" or "end"
        Return:
            idx: 0-d np.ndarray holding the index
        Raises:
            AssertionError: for any other beg_or_end value, or when
                loss_type is not "attention" (no markers exist then)
        """
        if self.loss_type == "attention":
            if beg_or_end == "beg":
                idx = np.array(self.dict[self.beg_str])
            elif beg_or_end == "end":
                idx = np.array(self.dict[self.end_str])
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx"\
                    % beg_or_end
            return idx
        else:
            err = "error in get_beg_end_flag_idx when using the loss %s"\
                % (self.loss_type)
            assert False, err


def cal_predicts_accuracy(char_ops,
                          preds,
                          preds_lod,
                          labels,
                          labels_lod,
                          is_remove_duplicate=False):
    """
    Calculate predicts accuracy: a sample counts as correct only when its
    decoded prediction string equals its decoded label string exactly.

    Args:
        char_ops: CharacterOps used to decode index sequences into text
        preds: preds result, concatenated text index of all samples;
               sample i is preds[preds_lod[i]:preds_lod[i + 1]]
        preds_lod: lod tensor of preds (offsets into preds)
        labels: label of input image, concatenated text index
        labels_lod: lod tensor of label (offsets into labels); the number of
                    samples evaluated is len(labels_lod) - 1
        is_remove_duplicate: Whether to remove duplicate characters,
                             The default is False.
                             NOTE(review): this flag is applied to the labels
                             as well as the preds, so repeated consecutive
                             label characters are collapsed too when True --
                             confirm this matches what callers pass in preds.

    Return:
        acc: The accuracy of test set (0.0 when there are no samples)
        acc_num: The correct number of samples predicted
        img_num: The total sample number of the test set
    """
    acc_num = 0
    img_num = 0
    for ino in range(len(labels_lod) - 1):
        preds_text = preds[preds_lod[ino]:preds_lod[ino + 1]].reshape(-1)
        preds_text = char_ops.decode(preds_text, is_remove_duplicate)

        labels_text = labels[labels_lod[ino]:labels_lod[ino + 1]].reshape(-1)
        labels_text = char_ops.decode(labels_text, is_remove_duplicate)
        img_num += 1

        if preds_text == labels_text:
            acc_num += 1
    # An empty batch (labels_lod == [0]) previously raised ZeroDivisionError.
    acc = acc_num * 1.0 / img_num if img_num > 0 else 0.0
    return acc, acc_num, img_num


def convert_rec_attention_infer_res(preds):
    """
    Convert recognition attention predict result with lod information

    Args:
        preds: the output of the model, a 2-D array indexed as
               preds[image, step]

    Return:
        convert_ids: array of shape (-1, 1) holding all the kept predicted
                     indices of every image, concatenated.
        target_lod: The lod information of the predicted results; image i
                    occupies convert_ids[target_lod[i]:target_lod[i + 1]]
    """
    img_num = preds.shape[0]
    target_lod = [0]
    convert_ids = []
    for ino in range(img_num):
        # Steps whose value is 1 -- presumably the "eos" index (CharacterOps
        # places ["sos", "eos"] at indices 0 and 1 for attention).
        end_pos = np.where(preds[ino, :] == 1)[0]
        # Column 0 is always dropped, and the slice stops at the SECOND
        # match: the code appears to assume column 0 itself holds 1.
        # NOTE(review): confirm against the decoder's initial token.
        if len(end_pos) <= 1:
            text_list = preds[ino, 1:]
        else:
            text_list = preds[ino, 1:end_pos[1]]
        target_lod.append(target_lod[-1] + len(text_list))
        # extend() appends in place; re-concatenating lists every iteration
        # made the whole loop quadratic in the total output length.
        convert_ids.extend(text_list)
    convert_ids = np.array(convert_ids).reshape((-1, 1))
    return convert_ids, target_lod


def convert_rec_label_to_lod(ori_labels):
    """
    Convert recognition label to lod tensor

    Args:
        ori_labels: origin labels of total images, one index sequence
                    per image
    Return:
        convert_ids: array of shape (-1, 1) holding all labels concatenated
        target_lod: The lod information of the labels; image i occupies
                    convert_ids[target_lod[i]:target_lod[i + 1]]
    """
    target_lod = [0]
    convert_ids = []
    for label in ori_labels:
        target_lod.append(target_lod[-1] + len(label))
        # extend() appends in place instead of rebuilding the whole list
        # each iteration (the old `list + list` form was quadratic).
        convert_ids.extend(label)
    convert_ids = np.array(convert_ids).reshape((-1, 1))
    return convert_ids, target_lod