From 6dd494b60a65c31ac2c1b4c26093bd4b205e7c6c Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 20 May 2020 18:15:36 +0800 Subject: [PATCH] add anno --- ppocr/modeling/architectures/rec_model.py | 21 +++++++++ ppocr/modeling/losses/rec_attention_loss.py | 1 + ppocr/modeling/losses/rec_ctc_loss.py | 1 + ppocr/utils/character.py | 49 ++++++++++++++++++--- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index d88c620b..a3ca606c 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -25,6 +25,14 @@ from copy import deepcopy class RecModel(object): + """ + Rec model architecture + + Args: + params(object): Params from yaml file and settings from command line + + """ + def __init__(self, params): super(RecModel, self).__init__() global_params = params['Global'] @@ -58,6 +66,13 @@ class RecModel(object): self.max_text_length = global_params['max_text_length'] def create_feed(self, mode): + """ + Create feed dict and DataLoader object + Args: + mode(str): runtime mode, can be "train", "eval" or "test" + Return: image, labels, loader + + """ image_shape = deepcopy(self.image_shape) image_shape.insert(0, -1) image = fluid.data(name='image', shape=image_shape, dtype='float32') @@ -96,9 +111,13 @@ class RecModel(object): inputs = image else: inputs = self.tps(image) + + # backbone conv_feas = self.backbone(inputs) + # predict predicts = self.head(conv_feas, labels, mode) decoded_out = predicts['decoded_out'] + #loss if mode == "train": loss = self.loss(predicts, labels) if self.loss_type == "attention": @@ -108,9 +127,11 @@ class RecModel(object): outputs = {'total_loss':loss, 'decoded_out':\ decoded_out, 'label':label} return loader, outputs + # export_model elif mode == "export": predict = predicts['predict'] predict = fluid.layers.softmax(predict) return [image, {'decoded_out': decoded_out, 'predicts': predict}] + # eval or 
test else: return loader, {'decoded_out': decoded_out} diff --git a/ppocr/modeling/losses/rec_attention_loss.py b/ppocr/modeling/losses/rec_attention_loss.py index 8d8d7c13..c8d80556 100755 --- a/ppocr/modeling/losses/rec_attention_loss.py +++ b/ppocr/modeling/losses/rec_attention_loss.py @@ -33,6 +33,7 @@ class AttentionLoss(object): predict = predicts['predict'] label_out = labels['label_out'] label_out = fluid.layers.cast(x=label_out, dtype='int64') + # calculate attention loss cost = fluid.layers.cross_entropy(input=predict, label=label_out) sum_cost = fluid.layers.reduce_sum(cost) return sum_cost diff --git a/ppocr/modeling/losses/rec_ctc_loss.py b/ppocr/modeling/losses/rec_ctc_loss.py index 3552d320..d443b8d5 100755 --- a/ppocr/modeling/losses/rec_ctc_loss.py +++ b/ppocr/modeling/losses/rec_ctc_loss.py @@ -30,6 +30,7 @@ class CTCLoss(object): def __call__(self, predicts, labels): predict = predicts['predict'] label = labels['label'] + # calculate ctc loss cost = fluid.layers.warpctc( input=predict, label=label, blank=self.char_num, norm_by_times=True) sum_cost = fluid.layers.reduce_sum(cost) diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index b4075039..42b63173 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -20,14 +20,22 @@ import sys class CharacterOps(object): - """ Convert between text-label and text-index """ + """ + Convert between text-label and text-index + + Args: + config: config from yaml file + + """ def __init__(self, config): self.character_type = config['character_type'] self.loss_type = config['loss_type'] + # use the default dictionary(36 char) if self.character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) + # use the custom dictionary elif self.character_type == "ch": character_dict_path = config['character_dict_path'] self.character_str = "" @@ -47,26 +55,29 @@ class CharacterOps(object): "Nonsupport type of the character: 
{}".format(self.character_str) self.beg_str = "sos" self.end_str = "eos" + # add start and end str for attention if self.loss_type == "attention": dict_character = [self.beg_str, self.end_str] + dict_character + # create char dict self.dict = {} for i, char in enumerate(dict_character): self.dict[char] = i self.character = dict_character def encode(self, text): - """convert text-label into text-index. - input: + """ + convert text-label into text-index. + + Args: text: text labels of each image. [batch_size] - output: + Return: text: concatenated text index for CTCLoss. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] - length: length of each text. [batch_size] """ + # Ignore capital if self.character_type == "en": text = text.lower() - text_list = [] for char in text: if char not in self.dict: @@ -76,7 +87,15 @@ class CharacterOps(object): return text def decode(self, text_index, is_remove_duplicate=False): - """ convert text-index into text-label. """ + """ + convert text-index into text-label. + Args: + text_index: text index for each image + is_remove_duplicate: Whether to remove duplicate characters, + The default is False + Return: + text: text label + """ char_list = [] char_num = self.get_char_num() @@ -98,6 +117,9 @@ class CharacterOps(object): return text def get_char_num(self): + """ + Get character num + """ return len(self.character) def get_beg_end_flag_idx(self, beg_or_end): @@ -122,6 +144,19 @@ def cal_predicts_accuracy(char_ops, labels, labels_lod, is_remove_duplicate=False): + """ + Calculate predicts accuracy + Args: + char_ops: CharacterOps + preds: preds result, text index + preds_lod: + labels: + labels_lod: + is_remove_duplicate: + + Return: + + """ acc_num = 0 img_num = 0 for ino in range(len(labels_lod) - 1): -- GitLab