Commit 622583df authored by J Jethong

fix ctc loss and data input md

Parent c455034f
@@ -94,10 +94,12 @@ The total_text.txt annotation file format is as follows; the image file name and the annotation are separated by "\t"
" Image file name             Image annotation information encoded by json.dumps"
rgb/gt_0.png    [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
```
The image annotation before json.dumps encoding is a list of multiple dictionaries; `points` in each dictionary holds the (x, y) coordinates of the fourteen points of the text box, arranged clockwise starting from the upper-left point.
`transcription` is the text of the current text box. **When its content is "###", the text box is invalid and will be skipped during training.**
If you want to train on other datasets, you can build annotation files in the above format.
*PGNet supports input with any number of points, but the annotation must be uniform (symmetric top and bottom, evenly spaced left to right). In our experiments, training with fourteen-point annotations works better than with four-point annotations; you can try two-stage training on four-point and then fourteen-point annotations.*
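One such label line can be assembled with json.dumps. The following is a minimal sketch; the file name and point values are illustrative stand-ins, and the point lists are shortened rather than full fourteen-point annotations:

```python
import json

# Illustrative annotation for one image: each dict describes one text box.
# Point values here are made up; real labels come from the dataset.
boxes = [
    {
        "transcription": "EST",
        "points": [[1004.0, 689.0], [1019.0, 698.0], [1095.0, 748.0],
                   [1094.0, 774.0], [1079.0, 765.0], [1007.0, 721.0]],
    },
    # A box marked "###" is invalid and will be skipped during training.
    {
        "transcription": "###",
        "points": [[0.0, 0.0], [10.0, 0.0], [10.0, 10.0], [0.0, 10.0]],
    },
]

# One line of total_text.txt: file name and the json.dumps-encoded
# annotation, separated by a tab character.
line = "rgb/gt_0.png" + "\t" + json.dumps(boxes)
print(line.split("\t")[0])  # rgb/gt_0.png
```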
### Start Training
PGNet training consists of two steps. Step 1: train on synthetic data to obtain a pretrained model; at this point the model accuracy is still low. Step 2: load the pretrained model and train on the totaltext dataset. To speed up training, we directly provide the pretrained model from step 1.
......
@@ -93,12 +93,14 @@ rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.
```
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
The `points` in the dictionary represent the coordinates (x, y) of the fourteen points of the text box, arranged clockwise from the point at the upper left corner.
`transcription` represents the text of the current text box. **When its content is "###", it means that the text box is invalid and will be skipped during training.**
If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format.
*PGNet supports data input with any number of points, but the labeling must be uniform (symmetric top and bottom, evenly spaced left to right). In our experiments, training with fourteen-point labels performs better than with four-point labels; you can try two-stage training on four-point and then fourteen-point labels.*
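A reader for this format can be sketched as follows. This is a minimal illustration, not PaddleOCR's actual data-loading code, and the sample line is shortened:

```python
import json

# A shortened sample line in the tab-separated format described above.
sample = "rgb/gt_0.png\t" + json.dumps([
    {"transcription": "EST", "points": [[1004.0, 689.0], [1019.0, 698.0]]},
    {"transcription": "###", "points": [[0.0, 0.0], [1.0, 1.0]]},
])

def parse_label_line(line):
    """Split one annotation line into (file name, list of valid text boxes)."""
    name, annotation = line.rstrip("\n").split("\t")
    boxes = json.loads(annotation)
    # Boxes whose transcription is "###" are invalid and skipped in training.
    valid = [box for box in boxes if box["transcription"] != "###"]
    return name, valid

name, valid_boxes = parse_label_line(sample)
print(name, len(valid_boxes))  # rgb/gt_0.png 1
```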
### Start Training
......
@@ -102,13 +102,11 @@ class PGLoss(nn.Layer):
         f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
         N, B, _ = f_tcl_char_ld.shape
         input_lengths = paddle.to_tensor([N] * B, dtype='int64')
-        cost = paddle.nn.functional.ctc_loss(
-            log_probs=f_tcl_char_ld,
-            labels=tcl_label,
-            input_lengths=input_lengths,
-            label_lengths=label_t,
-            blank=self.pad_num,
-            reduction='none')
+        loss_out = paddle.fluid.layers.warpctc(f_tcl_char_ld, tcl_label,
+                                               self.pad_num, True,
+                                               input_lengths, label_t)
+        cost = paddle.fluid.layers.squeeze(loss_out, [-1])
         cost = cost.mean()
         return cost
......