diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py
index 261462044a9000561517c3657f5b5a6090fd107a..d0441c05da0ed70e088c15c886c7864249ed76d0 100755
--- a/ppocr/modeling/architectures/rec_model.py
+++ b/ppocr/modeling/architectures/rec_model.py
@@ -25,6 +25,12 @@ from copy import deepcopy
 
 
 class RecModel(object):
+    """
+    Recognition model architecture
+    Args:
+        params(object): params from the yaml file and settings from the command line
+    """
+
     def __init__(self, params):
         super(RecModel, self).__init__()
         global_params = params['Global']
@@ -64,6 +70,12 @@ class RecModel(object):
             self.num_heads = None
 
     def create_feed(self, mode):
+        """
+        Create the feed dict and DataLoader object
+        Args:
+            mode(str): runtime mode, one of "train", "eval" or "test"
+        Return: image, labels, loader
+        """
         image_shape = deepcopy(self.image_shape)
         image_shape.insert(0, -1)
         if mode == "train":
@@ -189,9 +201,12 @@ class RecModel(object):
             inputs = image
         else:
             inputs = self.tps(image)
+        # extract visual features with the backbone
         conv_feas = self.backbone(inputs)
+        # predict with the recognition head
         predicts = self.head(conv_feas, labels, mode)
         decoded_out = predicts['decoded_out']
+        # compute the loss in train mode
         if mode == "train":
             loss = self.loss(predicts, labels)
             if self.loss_type == "attention":
@@ -211,7 +226,7 @@ class RecModel(object):
             outputs = {'total_loss':loss, 'decoded_out':\
                 decoded_out, 'label':label}
             return loader, outputs
-
+        # export mode
         elif mode == "export":
             predict = predicts['predict']
             if self.loss_type == "ctc":
@@ -225,6 +240,7 @@ class RecModel(object):
             ]
             return [image, {'decoded_out': decoded_out, 'predicts': predict}]
 
+        # eval or test mode
         else:
             predict = predicts['predict']
             if self.loss_type == "ctc":
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 84948c2b20933d0f2086a42442a420d1b6b1eeee..67adcce0fd0e0ad4bef92d2f04e5de1a9e0193b0 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -27,6 +27,12 @@ import numpy as np
 
 
 class CTCPredict(object):
+    """
+    CTC prediction head
+    Args:
+        params(object): params from the yaml file and settings from the command line
+    """
+
     def __init__(self, params):
         super(CTCPredict, self).__init__()
         self.char_num = params['char_num']
diff --git a/ppocr/modeling/losses/rec_attention_loss.py b/ppocr/modeling/losses/rec_attention_loss.py
index 8d8d7c1359f5f5edf79aed39092fa637a6cbde03..c8d8055627f9e364f0342a5740d7bf9b0297de59 100755
--- a/ppocr/modeling/losses/rec_attention_loss.py
+++ b/ppocr/modeling/losses/rec_attention_loss.py
@@ -33,6 +33,7 @@ class AttentionLoss(object):
         predict = predicts['predict']
         label_out = labels['label_out']
         label_out = fluid.layers.cast(x=label_out, dtype='int64')
+        # calculate the attention loss: per-step cross-entropy, summed
         cost = fluid.layers.cross_entropy(input=predict, label=label_out)
         sum_cost = fluid.layers.reduce_sum(cost)
         return sum_cost
diff --git a/ppocr/modeling/losses/rec_ctc_loss.py b/ppocr/modeling/losses/rec_ctc_loss.py
index 3552d320978f33ec3eb032c96654eb3b7886d8c0..d443b8d50f842ea738ec8448af17667bb6bea193 100755
--- a/ppocr/modeling/losses/rec_ctc_loss.py
+++ b/ppocr/modeling/losses/rec_ctc_loss.py
@@ -30,6 +30,7 @@ class CTCLoss(object):
     def __call__(self, predicts, labels):
         predict = predicts['predict']
         label = labels['label']
+        # calculate the CTC loss with warpctc, summed over the batch
         cost = fluid.layers.warpctc(
             input=predict, label=label, blank=self.char_num, norm_by_times=True)
         sum_cost = fluid.layers.reduce_sum(cost)
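As an aside on the two loss comments above: a rough numpy mirror of the attention branch (cross_entropy followed by reduce_sum) is sketched below. Shapes and names are assumed for illustration; the real computation runs on fluid tensors inside a program.

    import numpy as np

    # assumed shapes: predict is (num_steps, num_classes) softmax output,
    # label_out is (num_steps,) of int64 class indices
    def attention_loss(predict, label_out):
        # cross_entropy takes -log(p[label]) at each step; reduce_sum adds them up
        per_step = -np.log(predict[np.arange(len(label_out)), label_out])
        return per_step.sum()

    predict = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1]])
    print(attention_loss(predict, np.array([0, 1])))  # -log(0.7) - log(0.8)

The CTC branch is different in kind: warpctc marginalizes over all blank-augmented alignments of the label sequence rather than scoring one fixed per-step labeling.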
diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py
index 97237cfa71a3d3ae0684ecbefbb2511f09bcd3a2..cd4f87568a30329e0394bd77bf879be35a54ef80 100755
--- a/ppocr/utils/character.py
+++ b/ppocr/utils/character.py
@@ -20,15 +20,21 @@ import sys
 
 
 class CharacterOps(object):
-    """ Convert between text-label and text-index """
+    """
+    Convert between text-label and text-index
+    Args:
+        config: config from the yaml file
+    """
 
     def __init__(self, config):
         self.character_type = config['character_type']
         self.loss_type = config['loss_type']
         self.max_text_len = config['max_text_length']
+        # use the default dictionary (36 characters) for English
        if self.character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
+        # use a custom dictionary file for the other character types
         elif self.character_type in [
                 "ch", 'japan', 'korean', 'french', 'german'
         ]:
@@ -55,25 +61,27 @@ class CharacterOps(object):
             "Nonsupport type of the character: {}".format(self.character_str)
         self.beg_str = "sos"
         self.end_str = "eos"
+        # add start and end tokens for the attention and srn loss types
         if self.loss_type == "attention":
             dict_character = [self.beg_str, self.end_str] + dict_character
         elif self.loss_type == "srn":
             dict_character = dict_character + [self.beg_str, self.end_str]
+        # create the char-to-index dict
         self.dict = {}
         for i, char in enumerate(dict_character):
             self.dict[char] = i
         self.character = dict_character
 
     def encode(self, text):
-        """convert text-label into text-index.
-        input:
+        """
+        convert text-label into text-index.
+        Args:
             text: text labels of each image. [batch_size]
-
-        output:
+        Return:
             text: concatenated text index for CTCLoss.
                 [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
-            length: length of each text. [batch_size]
         """
+        # ignore case for the English dictionary
         if self.character_type == "en":
             text = text.lower()
 
@@ -86,7 +94,15 @@ class CharacterOps(object):
         return text
 
     def decode(self, text_index, is_remove_duplicate=False):
-        """ convert text-index into text-label. """
+        """
+        convert text-index into text-label.
+        Args:
+            text_index: text index for each image
+            is_remove_duplicate: whether to remove duplicate characters;
+                the default is False
+        Return:
+            text: text label
+        """
         char_list = []
         char_num = self.get_char_num()
 
@@ -108,6 +124,9 @@ class CharacterOps(object):
         return text
 
     def get_char_num(self):
+        """
+        Get the number of characters in the dictionary
+        """
         return len(self.character)
 
     def get_beg_end_flag_idx(self, beg_or_end):
@@ -132,6 +151,21 @@ def cal_predicts_accuracy(char_ops,
                           labels,
                           labels_lod,
                           is_remove_duplicate=False):
+    """
+    Calculate the accuracy of the predictions
+    Args:
+        char_ops: an instance of CharacterOps
+        preds: prediction results, text index
+        preds_lod: lod tensor of preds
+        labels: labels of the input images, text index
+        labels_lod: lod tensor of labels
+        is_remove_duplicate: whether to remove duplicate characters;
+            the default is False
+    Return:
+        acc: the accuracy of the test set
+        acc_num: the number of correctly predicted samples
+        img_num: the total number of samples in the test set
+    """
     acc_num = 0
     img_num = 0
     for ino in range(len(labels_lod) - 1):
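The is_remove_duplicate flag documented above drives the CTC path of decode(); below is a minimal standalone sketch of that path, assuming the default "en" dictionary with get_char_num() as the blank index (the real method additionally strips the attention sos/eos tokens):

    # mirrors the default "en" dictionary; for CTC the blank sits past the last char
    chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
    blank = len(chars)  # get_char_num() serves as the CTC blank index

    def ctc_decode(text_index, is_remove_duplicate=False):
        out = []
        for i, idx in enumerate(text_index):
            if idx == blank:
                continue  # drop CTC blanks
            if is_remove_duplicate and i > 0 and text_index[i - 1] == idx:
                continue  # collapse consecutive repeats
            out.append(chars[idx])
        return "".join(out)

    print(ctc_decode([1, 1, 36, 1, 36, 10], is_remove_duplicate=True))  # "11a"

Note how the blank between the second and third 1 keeps them from being collapsed, which is exactly why CTC can emit repeated characters.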
@@ -189,6 +223,14 @@
 
 
 def convert_rec_attention_infer_res(preds):
+    """
+    Convert the attention recognition result into ids with lod information
+    Args:
+        preds: the output of the model
+    Return:
+        convert_ids: a 1-D tensor holding all the predicted ids
+        target_lod: the lod information of the predicted results
+    """
     img_num = preds.shape[0]
     target_lod = [0]
     convert_ids = []
diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py
index 4479d9dff2c352c54a26ac1bfdbddab497fff418..180bfcfe679d1b74a643b1c97ce46d23d7c184cf 100644
--- a/tools/eval_utils/eval_rec_utils.py
+++ b/tools/eval_utils/eval_rec_utils.py
@@ -122,7 +122,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
 
 
 def test_rec_benchmark(exe, config, eval_info_dict):
-    " Evaluate lmdb dataset "
+    """
+    Evaluate the recognition model on the lmdb benchmark datasets
+    """
     eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
         'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
     eval_data_dir = config['TestReader']['lmdb_sets_dir']
diff --git a/tools/program.py b/tools/program.py
index 20d0b0b2cab51c5c21905bcfb732239594439311..9aca29afac8e265aa60580c1027e518d11999888 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -150,19 +150,20 @@ def check_gpu(use_gpu):
 def build(config, main_prog, startup_prog, mode):
     """
     Build a program using a model and an optimizer
-        1. create feeds
-        2. create a dataloader
-        3. create a model
-        4. create fetchs
-        5. create an optimizer
+        1. create a dataloader
+        2. create a model
+        3. create the fetch variables
+        4. create an optimizer
     Args:
         config(dict): config
         main_prog(): main program
         startup_prog(): startup program
-        is_train(bool): train or valid
+        mode(str): train or valid
     Returns:
         dataloader(): a bridge between the model and the data
-        fetchs(dict): dict of model outputs(included loss and measures)
+        fetch_name_list(list): names of the model outputs (including loss and measures)
+        fetch_varname_list(list): variable names of the model outputs
+        opt_loss_name(str): name of the optimized loss
     """
     with fluid.program_guard(main_prog, startup_prog):
         with fluid.unique_name.guard():
@@ -257,9 +258,14 @@ def train_eval_det_run(config,
                        train_info_dict,
                        eval_info_dict,
                        is_slim=None):
-    '''
-    main program of evaluation for detection
-    '''
+    """
+    Feed data to the model and fetch the measures and loss for detection
+    Args:
+        config: the config dict
+        exe: the executor to run the program
+        train_info_dict: information dict for training
+        eval_info_dict: information dict for evaluation
+    """
     train_batch_id = 0
     log_smooth_window = config['Global']['log_smooth_window']
     epoch_num = config['Global']['epoch_num']
@@ -376,9 +382,14 @@ def train_eval_rec_run(config,
                        train_info_dict,
                        eval_info_dict,
                        is_slim=None):
-    '''
-    main program of evaluation for recognition
-    '''
+    """
+    Feed data to the model and fetch the measures and loss for recognition
+    Args:
+        config: the config dict
+        exe: the executor to run the program
+        train_info_dict: information dict for training
+        eval_info_dict: information dict for evaluation
+    """
     train_batch_id = 0
     log_smooth_window = config['Global']['log_smooth_window']
     epoch_num = config['Global']['epoch_num']
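To make the four return values documented for build() concrete, here is a hypothetical training loop; use_gpu, config and the surrounding setup are placeholders for the usual yaml-driven boilerplate, not code from this PR:

    import paddle.fluid as fluid

    use_gpu = True  # placeholder; normally read from config['Global']
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    train_prog, startup_prog = fluid.Program(), fluid.Program()
    # build() wires the model and optimizer into train_prog
    dataloader, fetch_name_list, fetch_varname_list, opt_loss_name = build(
        config, train_prog, startup_prog, mode='train')

    exe.run(startup_prog)
    for data in dataloader():
        fetch_vals = exe.run(train_prog, feed=data, fetch_list=fetch_varname_list)
        metrics = dict(zip(fetch_name_list, fetch_vals))
        # metrics[opt_loss_name] is the loss being optimized

Pairing fetch_name_list with fetch_varname_list this way is what lets the train/eval loops log measures by name without caring which variables the model actually defined.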