diff --git a/doc/doc_ch/enhanced_ctc_loss.md b/doc/doc_ch/enhanced_ctc_loss.md new file mode 100644 index 0000000000000000000000000000000000000000..309dc712dc0242b859f934338be96e6648f81031 --- /dev/null +++ b/doc/doc_ch/enhanced_ctc_loss.md @@ -0,0 +1,78 @@ +# Enhanced CTC Loss + +在OCR识别中, CRNN是一种在工业界广泛使用的文字识别算法。 在训练阶段,其采用CTCLoss来计算网络损失; 在推理阶段,其采用CTCDecode来获得解码结果。虽然CRNN算法在实际业务中被证明能够获得很好的识别效果, 然而用户对识别准确率的要求却是无止境的,如何进一步提升文字识别的准确率呢? 本文以CTCLoss为切人点,分别从难例挖掘、 多任务学习、 Metric Learning 3个不同的角度探索了CTCLoss的改进融合方案,提出了EnhancedCTCLoss,其包括如下3个组成部分: Focal-CTC Loss,A-CTC Loss, C-CTC Loss。 + +## 1. Focal-CTC Loss +Focal Loss 出自论文《Focal Loss for Dense Object Detection》, 该loss最先提出的时候主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。 +其损失函数形式如下: +
+ +
+ +其中, y' 是经过激活函数的输出,取值在0-1之间。其在原始的交叉熵损失的基础上加了一个调制系数(1 – y’)^ γ和平衡因子α。 当α = 1,y=1时,其损失函数与交叉熵损失的对比如下图所示: +
+ +
+ +从上图可以看到, 当γ> 0时,调整系数(1-y’)^γ 赋予易分类样本损失一个更小的权重,使得网络更关注于困难的、错分的样本。 调整因子γ用于调节简单样本权重降低的速率,当γ为0时即为交叉熵损失函数,当γ增加时,调整因子的影响也会随之增大。实验发现γ为2是最优。平衡因子α用来平衡正负样本本身的比例不均,文中α取0.25。 + +对于经典的CTC算法,假设某个特征序列(f1, f2, ......ft), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现 CTCLoss值和y’有如下关系: +
+ +
+ +结合Focal Loss的思想,赋予困难样本较大的权重,简单样本较小的权重,可以使网络更加聚焦于对困难样本的挖掘,进一步提升识别的准确率,由此我们提出了Focal-CTC Loss; 其定义如下所示: +
+ +
+ +实验中,γ取值为2, α= 1, 具体实现见: [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py) + +## 2. A-CTC Loss +A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggregation Cross-Entropy for Sequence Recognition>. ACE Loss相比于CTCLoss,主要有如下两点优势: ++ ACE Loss能够解决2-D文本的识别问题; CTCLoss只能够处理1-D文本 ++ ACE Loss 在时间复杂度和空间复杂度上优于CTC loss + +前人总结的OCR识别算法的优劣如下图所示: +
+ +
+ +虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行组合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。 +A_CTC Loss定义如下: +
+ +
+ +实验中,λ = 0.1. ACE loss实现代码见: [ace_loss.py](../../ppocr/losses/ace_loss.py) + +## 3. C-CTC Loss +C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大累间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。 +在中文OCR识别任务中,通过对badcase分析, 我们发现中文识别的一大难点是相似字符多,容易误识。 由此我们想到是否可以借鉴Metric Learing的想法, 增大相似字符的类间距,从而提高识别准确率。然而,MetricLearning主要用于图像识别领域,训练数据的标签为一个固定的值;而对于OCR识别来说,其本质上是一个序列识别任务,特征和label之间并不具有显式的对齐关系,因此两者如何结合依然是一个值得探索的方向。 +通过尝试Arcmargin, Cosmargin等方法, 我们最终发现Centerloss 有助于进一步提升识别的准确率。C_CTC Loss定义如下: +
+ +
+ +实验中,我们设置λ=0.25. center_loss实现代码见: [center_loss.py](../../ppocr/losses/center_loss.py) + +值得一提的是, 在C-CTC Loss中,选择随机初始化Center并不能够带来明显的提升. 我们的Center初始化方法如下: ++ 基于原始的CTCLoss, 训练得到一个网络N ++ 挑选出训练集中,识别完全正确的部分, 组成集合G ++ 将G中的每个样本送入网络,进行前向计算, 提取最后一个FC层的输入(即feature)及其经过argmax计算的结果(即index)之间的对应关系 ++ 将相同index的feature进行聚合,计算平均值,得到各自字符的初始center. + +以配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`为例, center提取命令如下所示: +``` +python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model: "./output/rec_mobile_pp-OCRv2/best_accuracy" +``` +运行完后,会在PaddleOCR主目录下生成`train_center.pkl`. + +## 4. 实验 +对于上述的三种方案,我们基于百度内部数据集进行了训练、评测,实验情况如下表所示: +|algorithm| Focal_CTC | A_CTC | C-CTC | +|:------| :------| ------: | :------: | +|gain| +0.3% | +0.7% | +1.7% | + +基于上述实验结论,我们在PP-OCRv2中,采用了C-CTC的策略。 值得一提的是,由于PP-OCRv2 处理的是6625个中文字符的识别任务,字符集比较大,形似字较多,所以在该任务上C-CTC 方案带来的提升较大。 但如果换做其他OCR识别任务,结论可能会有所不同。大家可以尝试Focal-CTC,A-CTC, C-CTC以及组合方案EnhancedCTC,相信会带来不同程度的提升效果。 +统一的融合方案见如下文件: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py) diff --git a/doc/doc_ch/equation_a_ctc.png b/doc/doc_ch/equation_a_ctc.png new file mode 100644 index 0000000000000000000000000000000000000000..ae097610d37a88e76edefdbeb81df8403e94215f Binary files /dev/null and b/doc/doc_ch/equation_a_ctc.png differ diff --git a/doc/doc_ch/equation_c_ctc.png b/doc/doc_ch/equation_c_ctc.png new file mode 100644 index 0000000000000000000000000000000000000000..67207a9937481f4920af3cbafbe1bfe8d27ee5dc Binary files /dev/null and b/doc/doc_ch/equation_c_ctc.png differ diff --git a/doc/doc_ch/equation_ctcloss.png b/doc/doc_ch/equation_ctcloss.png new file mode 100644 index 0000000000000000000000000000000000000000..33ad92c9e4567d2a4a0c8fc3b2a0bf3fba5ea8f2 Binary files /dev/null and b/doc/doc_ch/equation_ctcloss.png differ diff --git a/doc/doc_ch/equation_focal_ctc.png b/doc/doc_ch/equation_focal_ctc.png new file mode 100644 index 0000000000000000000000000000000000000000..6ba1e8715d5876705ef429e48b5c94388fd41398 Binary files /dev/null and b/doc/doc_ch/equation_focal_ctc.png differ diff --git a/doc/doc_ch/focal_loss_formula.png b/doc/doc_ch/focal_loss_formula.png new file mode 100644 index 0000000000000000000000000000000000000000..971cebcd082cf5e19f9246f02216c0c14896bdc9 Binary files /dev/null and b/doc/doc_ch/focal_loss_formula.png differ diff --git a/doc/doc_ch/focal_loss_image.png b/doc/doc_ch/focal_loss_image.png new file mode 100644 index 0000000000000000000000000000000000000000..430550a732d4e2769151771bc85ae889dfc78fda Binary files /dev/null and b/doc/doc_ch/focal_loss_image.png differ diff --git a/doc/doc_ch/rec_algo_compare.png b/doc/doc_ch/rec_algo_compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2dde496c75f327ca1c0c9ccb0dbe6949215a4a1b Binary files /dev/null and b/doc/doc_ch/rec_algo_compare.png differ diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py index 5d09802b46d7ddfa802461760b917267155b3923..063d68e30861e092e10fa3068e4b7f4755b6197f 100755 --- a/ppocr/losses/rec_ctc_loss.py +++ b/ppocr/losses/rec_ctc_loss.py @@ -38,7 +38,7 @@ class CTCLoss(nn.Layer): if self.use_focal_loss: weight = paddle.exp(-loss) weight = paddle.subtract(paddle.to_tensor([1.0]), weight) - weight = paddle.square(weight) * self.focal_loss_alpha + weight = paddle.square(weight) loss = paddle.multiply(loss, weight) - loss = loss.mean() # sum + loss = loss.mean() return {'loss': loss} diff --git a/ppocr/losses/rec_enhanced_ctc_loss.py b/ppocr/losses/rec_enhanced_ctc_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b57be6468e2ec75811442e7449525267e7d9e82e --- /dev/null +++ b/ppocr/losses/rec_enhanced_ctc_loss.py @@ -0,0 +1,70 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from .ace_loss import ACELoss +from .center_loss import CenterLoss +from .rec_ctc_loss import CTCLoss + + +class EnhancedCTCLoss(nn.Layer): + def __init__(self, + use_focal_loss=False, + use_ace_loss=False, + ace_loss_weight=0.1, + use_center_loss=False, + center_loss_weight=0.05, + num_classes=6625, + feat_dim=96, + init_center=False, + center_file_path=None, + **kwargs): + super(EnhancedCTCLoss, self).__init__() + self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss) + + self.use_ace_loss = False + if use_ace_loss: + self.use_ace_loss = use_ace_loss + self.ace_loss_func = ACELoss() + self.ace_loss_weight = ace_loss_weight + + self.use_center_loss = False + if use_center_loss: + self.use_center_loss = use_center_loss + self.center_loss_func = CenterLoss( + num_classes=num_classes, + feat_dim=feat_dim, + init_center=init_center, + center_file_path=center_file_path) + self.center_loss_weight = center_loss_weight + + def __call__(self, predicts, batch): + loss = self.ctc_loss_func(predicts, batch)["loss"] + + if self.use_center_loss: + center_loss = self.center_loss_func( + predicts, batch)["loss_center"] * self.center_loss_weight + loss = loss + center_loss + + if self.use_ace_loss: + ace_loss = self.ace_loss_func( + predicts, batch)["loss_ace"] * self.ace_loss_weight + loss = loss + ace_loss + + return {'enhanced_ctc_loss': loss} diff --git a/tools/export_center.py b/tools/export_center.py new file mode 100644 index 0000000000000000000000000000000000000000..c46e8b9d58997b9b66c6ce81b2558ecd4cad0e81 --- /dev/null +++ b/tools/export_center.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import pickle + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from ppocr.data import build_dataloader +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import init_model, load_dygraph_params +from ppocr.utils.utility import print_dict +import tools.program as program + + +def main(): + global_config = config['Global'] + # build dataloader + config['Eval']['dataset']['name'] = config['Train']['dataset']['name'] + config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][ + 'data_dir'] + config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][ + 'label_file_list'] + eval_dataloader = build_dataloader(config, 'Eval', device, logger) + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + # for rec algorithm + if hasattr(post_process_class, 'character'): + char_num = len(getattr(post_process_class, 'character')) + config['Architecture']["Head"]['out_channels'] = char_num + + #set return_features = True + config['Architecture']["Head"]["return_feats"] = True + + model = build_model(config['Architecture']) + + best_model_dict = load_dygraph_params(config, model, logger, None) + if len(best_model_dict): + logger.info('metric in ckpt ***************') + for k, v in best_model_dict.items(): + logger.info('{}:{}'.format(k, v)) + + # get features from train data + char_center = program.get_center(model, eval_dataloader, post_process_class) + + #serialize to disk + with open("train_center.pkl", 'wb') as f: + pickle.dump(char_center, f) + return + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/tools/program.py b/tools/program.py index 0b766928ce9c3ff393a57915461e207087e4379e..798e6dff297ad1149942488cca1d5540f1924867 100755 --- a/tools/program.py +++ b/tools/program.py @@ -399,6 +399,57 @@ def eval(model, return metric +def update_center(char_center, post_result, preds): + result, label = post_result + feats, logits = preds + logits = paddle.argmax(logits, axis=-1) + feats = feats.numpy() + logits = logits.numpy() + + for idx_sample in range(len(label)): + if result[idx_sample][0] == label[idx_sample][0]: + feat = feats[idx_sample] + logit = logits[idx_sample] + for idx_time in range(len(logit)): + index = logit[idx_time] + if index in char_center.keys(): + char_center[index][0] = ( + char_center[index][0] * char_center[index][1] + + feat[idx_time]) / (char_center[index][1] + 1) + char_center[index][1] += 1 + else: + char_center[index] = [feat[idx_time], 1] + return char_center + + +def get_center(model, eval_dataloader, post_process_class): + pbar = tqdm(total=len(eval_dataloader), desc='get center:') + max_iter = len(eval_dataloader) - 1 if platform.system( + ) == "Windows" else len(eval_dataloader) + char_center = dict() + for idx, batch in enumerate(eval_dataloader): + if idx >= max_iter: + break + images = batch[0] + start = time.time() + preds = model(images) + + batch = [item.numpy() for item in batch] + # Obtain usable results from post-processing methods + total_time += time.time() - start + # Evaluate the results of the current batch + post_result = post_process_class(preds, batch[1]) + + #update char_center + char_center = update_center(char_center, post_result, preds) + pbar.update(1) + + pbar.close() + for key in char_center.keys(): + char_center[key] = char_center[key][0] + return char_center + + def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() profiler_options = FLAGS.profiler_options @@ -427,7 +478,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED'] + 'SEED' + ] windows_not_support_list = ['PSE'] if platform.system() == "Windows" and alg in windows_not_support_list: logger.warning('{} is not support in Windows now'.format(