From 622583dfd92ff128f3b3bac2e99a167b0202813d Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Sat, 17 Apr 2021 11:16:06 +0800 Subject: [PATCH] fix ctc loss and data input md --- doc/doc_ch/pgnet.md | 4 +++- doc/doc_en/pgnet_en.md | 4 +++- ppocr/losses/e2e_pg_loss.py | 12 +++++------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md index 83045b8b..3874c623 100644 --- a/doc/doc_ch/pgnet.md +++ b/doc/doc_ch/pgnet.md @@ -94,10 +94,12 @@ total_text.txt标注文件格式如下,文件名和标注信息中间用"\t" " 图像文件名 json.dumps编码的图像标注信息" rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] ``` -json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的十四个点的坐标(x, y),从左上角的点开始顺时针排列。 `transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** 如果您想在其他数据集上训练,可以按照上述形式构建标注文件。 +*PGNet支持任意点的数据输入,但是需要保证均匀标注(上下对称,左右距离一致)。在我们实验中,十四点标注要比四点标注训练效果好,可以尝试在四点标注和十四点标注上作两阶段训练* + ### 启动训练 PGNet训练分为两个步骤:step1: 在合成数据上训练,得到预训练模型,此时模型精度依然较低;step2: 加载预训练模型,在totaltext数据集上训练;为快速训练,我们直接提供了step1的预训练模型。 diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md index 1352fbbe..942b8de3 100644 --- a/doc/doc_en/pgnet_en.md +++ b/doc/doc_en/pgnet_en.md @@ -93,12 +93,14 @@ rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698. ``` The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries. -The `points` in the dictionary represent the coordinates (x, y) of the four points of the text box, arranged clockwise from the point at the upper left corner. +The `points` in the dictionary represent the coordinates (x, y) of the fourteen points of the text box, arranged clockwise from the point at the upper left corner. 
`transcription` represents the text of the current text box. **When its content is "###" it means that the text box is invalid and will be skipped during training.** If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format. +*PGNet supports annotations with an arbitrary number of points, but the labeling must be uniform (symmetric between the top and bottom edges, with evenly spaced points from left to right). In our experiments, fourteen-point annotation trains better than four-point annotation; you can also try two-stage training, first on four-point and then on fourteen-point annotations.* + ### Start Training diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py index 10a8ed0a..56b9333d 100644 --- a/ppocr/losses/e2e_pg_loss.py +++ b/ppocr/losses/e2e_pg_loss.py @@ -102,13 +102,11 @@ class PGLoss(nn.Layer): f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2)) N, B, _ = f_tcl_char_ld.shape input_lengths = paddle.to_tensor([N] * B, dtype='int64') - cost = paddle.nn.functional.ctc_loss( - log_probs=f_tcl_char_ld, - labels=tcl_label, - input_lengths=input_lengths, - label_lengths=label_t, - blank=self.pad_num, - reduction='none') + loss_out = paddle.fluid.layers.warpctc(f_tcl_char_ld, tcl_label, + self.pad_num, True, + input_lengths, label_t) + + cost = paddle.fluid.layers.squeeze(loss_out, [-1]) cost = cost.mean() return cost -- GitLab