diff --git a/.gitignore b/.gitignore
index caf886a2b581b5976ea391aca5d7a56041bdbaa8..3300be325f1f6c8b2b58301fc87a4f9d241afb84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,8 @@ __pycache__/
inference/
inference_results/
output/
-
+train_data/
+log/
*.DS_Store
*.vs
*.user
diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 440c2d8c3eeef7379a5807db69107edfd2efdbb4..9e67f0cc54818aa5f6e6e7163277c05422f0754f 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
- QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView
+ QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView, QSizePolicy
__dir__ = os.path.dirname(os.path.abspath(__file__))
@@ -227,6 +227,21 @@ class MainWindow(QMainWindow):
listLayout.addWidget(leftTopToolBoxContainer)
# ================== Label List ==================
+ labelIndexListlBox = QHBoxLayout()
+
+ # Create and add a widget for showing current label item index
+ self.indexList = QListWidget()
+ self.indexList.setMaximumSize(40, 16777215) # limit max width
+ self.indexList.setEditTriggers(QAbstractItemView.NoEditTriggers) # no editable
+ self.indexList.itemSelectionChanged.connect(self.indexSelectionChanged)
+ self.indexList.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # no scroll Bar
+ self.indexListDock = QDockWidget('No.', self)
+ self.indexListDock.setWidget(self.indexList)
+ self.indexListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
+ labelIndexListlBox.addWidget(self.indexListDock, 1)
+ # no margin between two boxes
+ labelIndexListlBox.setSpacing(0)
+
# Create and add a widget for showing current label items
self.labelList = EditInList()
labelListContainer = QWidget()
@@ -240,8 +255,8 @@ class MainWindow(QMainWindow):
self.labelListDock = QDockWidget(self.labelListDockName, self)
self.labelListDock.setWidget(self.labelList)
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
- listLayout.addWidget(self.labelListDock)
-
+ labelIndexListlBox.addWidget(self.labelListDock, 10) # label list is wider than index list
+
# enable labelList drag_drop to adjust bbox order
# 设置选择模式为单选
self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
@@ -256,6 +271,17 @@ class MainWindow(QMainWindow):
# 触发放置
self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
+ labelIndexListContainer = QWidget()
+ labelIndexListContainer.setLayout(labelIndexListlBox)
+ listLayout.addWidget(labelIndexListContainer)
+
+ # labelList indexList同步滚动
+ self.labelListBar = self.labelList.verticalScrollBar()
+ self.indexListBar = self.indexList.verticalScrollBar()
+
+ self.labelListBar.valueChanged.connect(self.move_scrollbar)
+ self.indexListBar.valueChanged.connect(self.move_scrollbar)
+
# ================== Detection Box ==================
self.BoxList = QListWidget()
@@ -766,6 +792,7 @@ class MainWindow(QMainWindow):
self.shapesToItemsbox.clear()
self.labelList.clear()
self.BoxList.clear()
+ self.indexList.clear()
self.filePath = None
self.imageData = None
self.labelFile = None
@@ -1027,13 +1054,19 @@ class MainWindow(QMainWindow):
for shape in self.canvas.selectedShapes:
shape.selected = False
self.labelList.clearSelection()
+ self.indexList.clearSelection()
self.canvas.selectedShapes = selected_shapes
for shape in self.canvas.selectedShapes:
shape.selected = True
self.shapesToItems[shape].setSelected(True)
self.shapesToItemsbox[shape].setSelected(True)
+ index = self.labelList.indexFromItem(self.shapesToItems[shape]).row()
+ self.indexList.item(index).setSelected(True)
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
+ # map current label item to index item and select it
+ index = self.labelList.indexFromItem(self.currentItem()).row()
+ self.indexList.scrollToItem(self.indexList.item(index))
self.BoxList.scrollToItem(self.currentBox())
if self.kie_mode:
@@ -1066,12 +1099,18 @@ class MainWindow(QMainWindow):
shape.paintIdx = self.displayIndexOption.isChecked()
item = HashableQListWidgetItem(shape.label)
- item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
- item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
+ # current difficult checkbox is disenble
+ # item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+ # item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
+
# Checked means difficult is False
# item.setBackground(generateColorByText(shape.label))
self.itemsToShapes[item] = shape
self.shapesToItems[shape] = item
+ # add current label item index before label string
+ current_index = QListWidgetItem(str(self.labelList.count()))
+ current_index.setTextAlignment(Qt.AlignHCenter)
+ self.indexList.addItem(current_index)
self.labelList.addItem(item)
# print('item in add label is ',[(p.x(), p.y()) for p in shape.points], shape.label)
@@ -1105,6 +1144,7 @@ class MainWindow(QMainWindow):
del self.shapesToItemsbox[shape]
del self.itemsToShapesbox[item]
self.updateComboBox()
+ self.updateIndexList()
def loadLabels(self, shapes):
s = []
@@ -1156,6 +1196,13 @@ class MainWindow(QMainWindow):
# self.comboBox.update_items(uniqueTextList)
+ def updateIndexList(self):
+ self.indexList.clear()
+ for i in range(self.labelList.count()):
+ string = QListWidgetItem(str(i))
+ string.setTextAlignment(Qt.AlignHCenter)
+ self.indexList.addItem(string)
+
def saveLabels(self, annotationFilePath, mode='Auto'):
# Mode is Auto means that labels will be loaded from self.result_dic totally, which is the output of ocr model
annotationFilePath = ustr(annotationFilePath)
@@ -1211,6 +1258,10 @@ class MainWindow(QMainWindow):
# fix copy and delete
# self.shapeSelectionChanged(True)
+ def move_scrollbar(self, value):
+ self.labelListBar.setValue(value)
+ self.indexListBar.setValue(value)
+
def labelSelectionChanged(self):
if self._noSelectionSlot:
return
@@ -1223,6 +1274,21 @@ class MainWindow(QMainWindow):
else:
self.canvas.deSelectShape()
+ def indexSelectionChanged(self):
+ if self._noSelectionSlot:
+ return
+ if self.canvas.editing():
+ selected_shapes = []
+ for item in self.indexList.selectedItems():
+ # map index item to label item
+ index = self.indexList.indexFromItem(item).row()
+ item = self.labelList.item(index)
+ selected_shapes.append(self.itemsToShapes[item])
+ if selected_shapes:
+ self.canvas.selectShapes(selected_shapes)
+ else:
+ self.canvas.deSelectShape()
+
def boxSelectionChanged(self):
if self._noSelectionSlot:
# self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
@@ -1517,6 +1583,7 @@ class MainWindow(QMainWindow):
if self.labelList.count():
self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
self.labelList.item(self.labelList.count() - 1).setSelected(True)
+ self.indexList.item(self.labelList.count() - 1).setSelected(True)
# show file list image count
select_indexes = self.fileListWidget.selectedIndexes()
@@ -2015,12 +2082,14 @@ class MainWindow(QMainWindow):
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
+ self.canvas.repaint()
def togglePaintIndexOption(self):
self.displayLabelOption.setChecked(False)
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
+ self.canvas.repaint()
def toogleDrawSquare(self):
self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked())
@@ -2115,7 +2184,7 @@ class MainWindow(QMainWindow):
self.init_key_list(self.Cachelabel)
def reRecognition(self):
- img = cv2.imread(self.filePath)
+ img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes:
self.result_dic = []
@@ -2184,7 +2253,7 @@ class MainWindow(QMainWindow):
QMessageBox.information(self, "Information", "Draw a box!")
def singleRerecognition(self):
- img = cv2.imread(self.filePath)
+ img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
for shape in self.canvas.selectedShapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
if len(box) > 4:
@@ -2254,6 +2323,7 @@ class MainWindow(QMainWindow):
self.itemsToShapesbox.clear() # ADD
self.shapesToItemsbox.clear()
self.labelList.clear()
+ self.indexList.clear()
self.BoxList.clear()
self.result_dic = []
self.result_dic_locked = []
@@ -2665,6 +2735,7 @@ class MainWindow(QMainWindow):
def undoShapeEdit(self):
self.canvas.restoreShape()
self.labelList.clear()
+ self.indexList.clear()
self.BoxList.clear()
self.loadShapes(self.canvas.shapes)
self.actions.undo.setEnabled(self.canvas.isShapeRestorable)
@@ -2674,6 +2745,7 @@ class MainWindow(QMainWindow):
for shape in shapes:
self.addLabel(shape)
self.labelList.clearSelection()
+ self.indexList.clearSelection()
self._noSelectionSlot = False
self.canvas.loadShapes(shapes, replace=replace)
print("loadShapes") # 1
diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml
index 6edf0b9194ee59143e287394f505b60010ec6644..2f39fbd232fa4bcab4cd30622d21c56d11a72d31 100644
--- a/configs/det/det_mv3_db.yml
+++ b/configs/det/det_mv3_db.yml
@@ -101,7 +101,7 @@ Train:
drop_last: False
batch_size_per_card: 16
num_workers: 8
- use_shared_memory: False
+ use_shared_memory: True
Eval:
dataset:
@@ -129,4 +129,4 @@ Eval:
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 8
- use_shared_memory: False
+ use_shared_memory: True
diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aea71388f703376120af4d0caf2fa8ccd4d92cce
--- /dev/null
+++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,116 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 50
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+ save_epoch_step: 3
+ # evaluation is run every 2000 iterations after the 4000th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs: [3, 4, 5]
+ values: [0.001, 0.0003, 0.00009, 0.000027]
+ clip_norm: 5
+
+Architecture:
+ model_type: rec
+ algorithm: SPIN
+ in_channels: 1
+ Transform:
+ name: GA_SPIN
+ offsets: True
+ default_type: 6
+ loc_lr: 0.1
+ stn: True
+ Backbone:
+ name: ResNet32
+ out_channels: 512
+ Neck:
+ name: SequenceEncoder
+ encoder_type: cascadernn
+ hidden_size: 256
+ out_channels: [256, 512]
+ with_linear: True
+ Head:
+ name: SPINAttentionHead
+ hidden_size: 256
+
+
+Loss:
+ name: SPINAttentionLoss
+ ignore_index: 0
+
+PostProcess:
+ name: SPINLabelDecode
+ use_space_char: False
+
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 8
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 8
+ num_workers: 2
diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml
index 1e6efe32dbc077e910a417a7616db7bcc7f7c825..b8daf3630755e61322665b6fc5f830e4a45875b8 100755
--- a/configs/table/table_master.yml
+++ b/configs/table/table_master.yml
@@ -104,7 +104,7 @@ Train:
Eval:
dataset:
name: PubTabDataSet
- data_dir: train_data/table/pubtabnet/train/
+ data_dir: train_data/table/pubtabnet/val/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage:
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 5c7adc715f1a5e728d9320c62dc15c578d9f18bf..fbd3ce9ebccec0b1c2133b52e4aeb9d4d5e21114 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -69,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
+- [x] [SPIN](./algorithm_rec_spin.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
@@ -89,6 +90,7 @@
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
new file mode 100644
index 0000000000000000000000000000000000000000..c996992d2fa6297e6086ffae4bc36ad3e880873d
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE\ASTER\ESIR等只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构Structure-Preserving Inner Offset Network (SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。
+使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+训练
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+评估
+
+```
+# GPU 评估, Global.pretrained_model 为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+预测:
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+SPIN文本识别模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=Falsee
+```
+
+
+### 4.2 C++推理
+
+由于C++预处理后处理还未支持SPIN,所以暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@article{2020SPIN,
+ title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+ author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+ journal={AAAI2020},
+ year={2020},
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index f3c96b620c94c3b5f795b6117a7c6bcfcfa43b7a..a579d2447c52067e05d16af5e9d6cf50defc2b1c 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
+- [x] [SPIN](./algorithm_rec_spin_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -88,6 +89,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..43ab30ce7d96cbb64ddf87156fee3012d666b2bf
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets. The algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the SPIN text recognition training process is converted into an inference model. you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+
+For SPIN text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@article{2020SPIN,
+ title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+ author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+ journal={AAAI2020},
+ year={2020},
+}
+```
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index d82176282839bf76d34ed8a60d5e2e13ac6bbce6..d41eed9dfbd2980242e76fa8d8aae380a6594cd4 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -26,7 +26,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index a4087d53287fcd57f9c4992ba712c700f33b9981..97539faf232ec157340d3136d2efc0daca8deda8 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1216,3 +1216,36 @@ class ABINetLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character):
dict_character = [''] + dict_character
return dict_character
+
+class SPINLabelEncode(AttnLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(SPINLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+ self.lower = lower
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + [self.end_str] + dict_character
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ target = [0] + text + [1]
+ padded_text = [0 for _ in range(self.max_text_len + 2)]
+
+ padded_text[:len(target)] = target
+ data['label'] = np.array(padded_text)
+ return data
\ No newline at end of file
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 26773d0a516dfb0877453c7a5c8c8a2b5da92045..c5d8a3b2fd773a1877a788401a926d7fbca07adf 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -259,6 +259,49 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32)
return data
+class SPINRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ interpolation=2,
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5),
+ **kwargs):
+ self.image_shape = image_shape
+
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.interpolation = interpolation
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # different interpolation type corresponding the OpenCV
+ if self.interpolation == 0:
+ interpolation = cv2.INTER_NEAREST
+ elif self.interpolation == 1:
+ interpolation = cv2.INTER_LINEAR
+ elif self.interpolation == 2:
+ interpolation = cv2.INTER_CUBIC
+ elif self.interpolation == 3:
+ interpolation = cv2.INTER_AREA
+ else:
+ raise Exception("Unsupported interpolation type !!!")
+ # Deal with the image error during image loading
+ if img is None:
+ return None
+
+ img = cv2.resize(img, tuple(self.image_shape), interpolation)
+ img = np.array(img, np.float32)
+ img = np.expand_dims(img, -1)
+ img = img.transpose((2, 0, 1))
+ # normalize the image
+ img = img.copy().astype(np.float32)
+ mean = np.float64(self.mean.reshape(1, -1))
+ stdinv = 1 / np.float64(self.std.reshape(1, -1))
+ img -= mean
+ img *= stdinv
+ data['image'] = img
+ return data
class GrayRecResizeImg(object):
def __init__(self,
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 62e0544ea94daaaff7d019e6a48e65a2d508aca0..30120ac56756edd38676c40c39f0130f1b07c3ef 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
+from .rec_spin_att_loss import SPINAttentionLoss
# cls loss
from .cls_loss import ClsLoss
@@ -62,7 +63,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..195780c7bfaf4aae5dd23bd72ace268bed9c1d4f
--- /dev/null
+++ b/ppocr/losses/rec_spin_att_loss.py
@@ -0,0 +1,45 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR
+'''
+
+class SPINAttentionLoss(nn.Layer):
+ def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
+ super(SPINAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
+
+ def forward(self, predicts, batch):
+ targets = batch[1].astype("int64")
+ targets = targets[:, 1:] # remove [eos] in label
+
+ label_lengths = batch[2].astype('int64')
+ batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
+ 1], predicts.shape[2]
+ assert len(targets.shape) == len(list(predicts.shape)) - 1, \
+ "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+
+ inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
+ targets = paddle.reshape(targets, [-1])
+
+ return {'loss': self.loss_func(inputs, targets)}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f4094d796b1f14c955e5962936e86bd6b3f5ec78..d4f5b15f56d34a9f6a6501058179a643ac7e8318 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -32,6 +32,7 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
+ from .rec_resnet_32 import ResNet32
from .rec_resnet_45 import ResNet45
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
@@ -41,7 +42,7 @@ def build_backbone(config, model_type):
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd19251a3ed43a472d49f03743ead1491aa86ac
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_32.py
@@ -0,0 +1,269 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.nn as nn
+
+__all__ = ["ResNet32"]
+
+conv_weight_attr = nn.initializer.KaimingNormal()
+
+class ResNet32(nn.Layer):
+ """
+ Feature Extractor is proposed in FAN Ref [1]
+
+ Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017
+ """
+
+ def __init__(self, in_channels, out_channels=512):
+ """
+
+ Args:
+ in_channels (int): input channel
+ output_channel (int): output channel
+ """
+ super(ResNet32, self).__init__()
+ self.out_channels = out_channels
+ self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs: input feature
+
+ Returns:
+ output feature
+
+ """
+ return self.ConvNet(inputs)
+
+class BasicBlock(nn.Layer):
+ """Res-net Basic Block"""
+ expansion = 1
+
+ def __init__(self, inplanes, planes,
+ stride=1, downsample=None,
+ norm_type='BN', **kwargs):
+ """
+ Args:
+ inplanes (int): input channel
+ planes (int): channels of the middle feature
+ stride (int): stride of the convolution
+ downsample (int): type of the down_sample
+ norm_type (str): type of the normalization
+ **kwargs (None): backup parameter
+ """
+ super(BasicBlock, self).__init__()
+ self.conv1 = self._conv3x3(inplanes, planes)
+ self.bn1 = nn.BatchNorm2D(planes)
+ self.conv2 = self._conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2D(planes)
+ self.relu = nn.ReLU()
+ self.downsample = downsample
+ self.stride = stride
+
+ def _conv3x3(self, in_planes, out_planes, stride=1):
+ """
+
+ Args:
+ in_planes (int): input channel
+ out_planes (int): channels of the middle feature
+ stride (int): stride of the convolution
+ Returns:
+ nn.Layer: Conv2D with kernel = 3
+
+ """
+
+ return nn.Conv2D(in_planes, out_planes,
+ kernel_size=3, stride=stride,
+ padding=1, weight_attr=conv_weight_attr,
+ bias_attr=False)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+class ResNet(nn.Layer):
+ """Res-Net network structure"""
+ def __init__(self, input_channel,
+ output_channel, block, layers):
+ """
+
+ Args:
+ input_channel (int): input channel
+ output_channel (int): output channel
+ block (BasicBlock): convolution block
+ layers (list): layers of the block
+ """
+ super(ResNet, self).__init__()
+
+ self.output_channel_block = [int(output_channel / 4),
+ int(output_channel / 2),
+ output_channel,
+ output_channel]
+
+ self.inplanes = int(output_channel / 8)
+ self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
+ self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn0_2 = nn.BatchNorm2D(self.inplanes)
+ self.relu = nn.ReLU()
+
+ self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer1 = self._make_layer(block,
+ self.output_channel_block[0],
+ layers[0])
+ self.conv1 = nn.Conv2D(self.output_channel_block[0],
+ self.output_channel_block[0],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
+
+ self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer2 = self._make_layer(block,
+ self.output_channel_block[1],
+ layers[1], stride=1)
+ self.conv2 = nn.Conv2D(self.output_channel_block[1],
+ self.output_channel_block[1],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,)
+ self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
+
+ self.maxpool3 = nn.MaxPool2D(kernel_size=2,
+ stride=(2, 1),
+ padding=(0, 1))
+ self.layer3 = self._make_layer(block, self.output_channel_block[2],
+ layers[2], stride=1)
+ self.conv3 = nn.Conv2D(self.output_channel_block[2],
+ self.output_channel_block[2],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
+
+ self.layer4 = self._make_layer(block, self.output_channel_block[3],
+ layers[3], stride=1)
+ self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2, stride=(2, 1),
+ padding=(0, 1),
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
+ self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2, stride=1,
+ padding=0,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ """
+
+ Args:
+ block (block): convolution block
+ planes (int): input channels
+ blocks (list): layers of the block
+ stride (int): stride of the convolution
+
+ Returns:
+ nn.Sequential: the combination of the convolution block
+
+ """
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride,
+ weight_attr=conv_weight_attr,
+ bias_attr=False),
+ nn.BatchNorm2D(planes * block.expansion),
+ )
+
+ layers = list()
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv0_1(x)
+ x = self.bn0_1(x)
+ x = self.relu(x)
+ x = self.conv0_2(x)
+ x = self.bn0_2(x)
+ x = self.relu(x)
+
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.maxpool2(x)
+ x = self.layer2(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ x = self.maxpool3(x)
+ x = self.layer3(x)
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.relu(x)
+
+ x = self.layer4(x)
+ x = self.conv4_1(x)
+ x = self.bn4_1(x)
+ x = self.relu(x)
+ x = self.conv4_2(x)
+ x = self.bn4_2(x)
+ x = self.relu(x)
+ return x
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index fcd146efbc378faeebd42534a994836789974c32..b4f18b372058c539ae5949ced333ec7be122211f 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,6 +33,7 @@ def build_head(config):
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
+ from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
# cls head
@@ -48,7 +49,7 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'ABINetHead', 'TableMasterHead'
+ 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e35e4339d8e1006cfe43d6cf4f2f7d231082c4
--- /dev/null
+++ b/ppocr/modeling/heads/rec_spin_att_head.py
@@ -0,0 +1,115 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SPINAttentionHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(SPINAttentionHead, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = paddle.shape(inputs)[0]
+ num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)),
+ paddle.zeros((batch_size, self.hidden_size)))
+ output_hiddens = []
+ if self.training: # for train
+ targets = targets[0]
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+ char_onehots = None
+ outputs = None
+ alpha = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(outputs)
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+ next_input = probs_step.argmax(axis=1)
+ targets = next_input
+ if not self.training:
+ probs = paddle.nn.functional.softmax(probs, axis=2)
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index c8a774b8c543b9ccc14223c52f1b79ce690592f6..33be9400b34cb535d260881748e179c3df106caa 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -47,6 +47,56 @@ class EncoderWithRNN(nn.Layer):
x, _ = self.lstm(x)
return x
+class BidirectionalLSTM(nn.Layer):
+ def __init__(self, input_size,
+ hidden_size,
+ output_size=None,
+ num_layers=1,
+ dropout=0,
+ direction=False,
+ time_major=False,
+ with_linear=False):
+ super(BidirectionalLSTM, self).__init__()
+ self.with_linear = with_linear
+ self.rnn = nn.LSTM(input_size,
+ hidden_size,
+ num_layers=num_layers,
+ dropout=dropout,
+ direction=direction,
+ time_major=time_major)
+
+ # text recognition the specified structure LSTM with linear
+ if self.with_linear:
+ self.linear = nn.Linear(hidden_size * 2, output_size)
+
+ def forward(self, input_feature):
+ recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+ if self.with_linear:
+ output = self.linear(recurrent) # batch_size x T x output_size
+ return output
+ return recurrent
+
+class EncoderWithCascadeRNN(nn.Layer):
+ def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
+ super(EncoderWithCascadeRNN, self).__init__()
+ self.out_channels = out_channels[-1]
+ self.encoder = nn.LayerList(
+ [BidirectionalLSTM(
+ in_channels if i == 0 else out_channels[i - 1],
+ hidden_size,
+ output_size=out_channels[i],
+ num_layers=1,
+ direction='bidirectional',
+ with_linear=with_linear)
+ for i in range(num_layers)]
+ )
+
+
+ def forward(self, x):
+ for i, l in enumerate(self.encoder):
+ x = l(x)
+ return x
+
class EncoderWithFC(nn.Layer):
def __init__(self, in_channels, hidden_size):
@@ -166,13 +216,17 @@ class SequenceEncoder(nn.Layer):
'reshape': Im2Seq,
'fc': EncoderWithFC,
'rnn': EncoderWithRNN,
- 'svtr': EncoderWithSVTR
+ 'svtr': EncoderWithSVTR,
+ 'cascadernn': EncoderWithCascadeRNN
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, **kwargs)
+ elif encoder_type == 'cascadernn':
+ self.encoder = support_encoder_dict[encoder_type](
+ self.encoder_reshape.out_channels, hidden_size, **kwargs)
else:
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 405ab3cc6c0380654f61e42e523ddc85839139b3..7e4ffdf46854416f71e1c8f4e131d1f0283bb725 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -18,8 +18,10 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
from .stn import STN_ON
+ from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
- support_dict = ['TPS', 'STN_ON']
+
+ support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4719eb2162a02141620586bcb6a849ae16f3b62
--- /dev/null
+++ b/ppocr/modeling/transforms/gaspin_transformer.py
@@ -0,0 +1,284 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import functools
+from .tps import GridGenerator
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
+'''
+
+class SP_TransformerNetwork(nn.Layer):
+ """
+ Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1]
+ Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+ """
+
+ def __init__(self, nc=1, default_type=5):
+ """ Based on SPIN
+ Args:
+ nc (int): number of input channels (usually in 1 or 3)
+ default_type (int): the complexity of transformation intensities (by default set to 6 as the paper)
+ """
+ super(SP_TransformerNetwork, self).__init__()
+ self.power_list = self.cal_K(default_type)
+ self.sigmoid = nn.Sigmoid()
+ self.bn = nn.InstanceNorm2D(nc)
+
+ def cal_K(self, k=5):
+ """
+
+ Args:
+ k (int): the complexity of transformation intensities (by default set to 6 as the paper)
+
+ Returns:
+ List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)]
+
+ """
+ from math import log
+ x = []
+ if k != 0:
+ for i in range(1, k+1):
+ lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2)
+ upper = round(1/lower, 2)
+ x.append(lower)
+ x.append(upper)
+ x.append(1.00)
+ return x
+
+ def forward(self, batch_I, weights, offsets, lambda_color=None):
+ """
+
+ Args:
+ batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width]
+ weights:
+ offsets: the predicted offset by AIN, a scalar
+ lambda_color: the learnable update gate \alpha in Equa. (5) as
+ g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}
+
+ Returns:
+ Tensor: transformed images by SPN as Equa. (4) in Ref. [1]
+ [batch_size x I_channel_num x I_r_height x I_r_width]
+
+ """
+ batch_I = (batch_I + 1) * 0.5
+ if offsets is not None:
+ batch_I = batch_I*(1-lambda_color) + offsets*lambda_color
+ batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
+ batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
+
+ batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1)
+ batch_weight_sum = self.bn(batch_weight_sum)
+ batch_weight_sum = self.sigmoid(batch_weight_sum)
+ batch_weight_sum = batch_weight_sum * 2 - 1
+ return batch_weight_sum
+
+class GA_SPIN_Transformer(nn.Layer):
+ """
+ Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
+
+
+ Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+ """
+
+ def __init__(self, in_channels=1,
+ I_r_size=(32, 100),
+ offsets=False,
+ norm_type='BN',
+ default_type=6,
+ loc_lr=1,
+ stn=True):
+ """
+ Args:
+ in_channels (int): channel of input features,
+ set it to 1 if the grayscale images and 3 if RGB input
+ I_r_size (tuple): size of rectified images (used in STN transformations)
+ offsets (bool): set it to False if use SPN w.o. AIN,
+ and set it to True if use SPIN (both with SPN and AIN)
+ norm_type (str): the normalization type of the module,
+ set it to 'BN' by default, 'IN' optionally
+ default_type (int): the K chromatic space,
+ set it to 3/5/6 depend on the complexity of transformation intensities
+ loc_lr (float): learning rate of location network
+ stn (bool): whther to use stn.
+
+ """
+ super(GA_SPIN_Transformer, self).__init__()
+ self.nc = in_channels
+ self.spt = True
+ self.offsets = offsets
+ self.stn = stn # set to True in GA-SPIN, while set it to False in SPIN
+ self.I_r_size = I_r_size
+ self.out_channels = in_channels
+ if norm_type == 'BN':
+ norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
+ elif norm_type == 'IN':
+ norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
+ use_global_stats=False)
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+ if self.spt:
+ self.sp_net = SP_TransformerNetwork(in_channels,
+ default_type)
+ self.spt_convnet = nn.Sequential(
+ # 32*100
+ nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
+ norm_layer(32), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 16*50
+ nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
+ norm_layer(64), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 8*25
+ nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
+ norm_layer(128), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 4*12
+ )
+ self.stucture_fc1 = nn.Sequential(
+ nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256), nn.ReLU(), # 2*6
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
+ norm_layer(512), nn.ReLU(), # 1*3
+ nn.AdaptiveAvgPool2D(1),
+ nn.Flatten(1, -1), # batch_size x 512
+ nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
+ nn.BatchNorm1D(256), nn.ReLU()
+ )
+ self.out_weight = 2*default_type+1
+ self.spt_length = 2*default_type+1
+ if offsets:
+ self.out_weight += 1
+ if self.stn:
+ self.F = 20
+ self.out_weight += self.F * 2
+ self.GridGenerator = GridGenerator(self.F*2, self.F)
+
+ # self.out_weight*=nc
+ # Init structure_fc2 in LocalizationNetwork
+ initial_bias = self.init_spin(default_type*2)
+ initial_bias = initial_bias.reshape(-1)
+ param_attr = ParamAttr(
+ learning_rate=loc_lr,
+ initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
+ bias_attr = ParamAttr(
+ learning_rate=loc_lr,
+ initializer=nn.initializer.Assign(initial_bias))
+ self.stucture_fc2 = nn.Linear(256, self.out_weight,
+ weight_attr=param_attr,
+ bias_attr=bias_attr)
+ self.sigmoid = nn.Sigmoid()
+
+ if offsets:
+ self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16,
+ 3, 1, 1,
+ bias_attr=False),
+ norm_layer(16),
+ nn.ReLU(),)
+ self.offset_fc2 = nn.Conv2D(16, in_channels,
+ 3, 1, 1)
+ self.pool = nn.MaxPool2D(2, 2)
+
+ def init_spin(self, nz):
+ """
+ Args:
+ nz (int): number of paired \betas exponents, which means the value of K x 2
+
+ """
+ init_id = [0.00]*nz+[5.00]
+ if self.offsets:
+ init_id += [-5.00]
+ # init_id *=3
+ init = np.array(init_id)
+
+ if self.stn:
+ F = self.F
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+ ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ initial_bias = initial_bias.reshape(-1)
+ init = np.concatenate([init, initial_bias], axis=0)
+ return init
+
+ def forward(self, x, return_weight=False):
+ """
+ Args:
+ x (Tensor): input image batch
+ return_weight (bool): set to False by default,
+ if set to True return the predicted offsets of AIN, denoted as x_{offsets}
+
+ Returns:
+ Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size
+ """
+
+ if self.spt:
+ feat = self.spt_convnet(x)
+ fc1 = self.stucture_fc1(feat)
+ sp_weight_fusion = self.stucture_fc2(fc1)
+ sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1])
+ if self.offsets: # SPIN w. AIN
+ lambda_color = sp_weight_fusion[:, self.spt_length, 0]
+ lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
+
+ assert offsets.shape[2] == 2 # 2
+ assert offsets.shape[3] == 6 # 16
+ offsets = self.sigmoid(offsets) # v12
+
+ if return_weight:
+ return offsets
+ offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
+
+ if self.stn:
+ batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2])
+ build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+ build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+ self.I_r_size[0],
+ self.I_r_size[1],
+ 2])
+
+ else: # SPIN w.o. AIN
+ sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ lambda_color, offsets = None, None
+
+ if self.stn:
+ batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2])
+ build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+ build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+ self.I_r_size[0],
+ self.I_r_size[1],
+ 2])
+
+ x = self.sp_net(x, sp_weight, offsets, lambda_color)
+ if self.stn:
+ x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
+ return x
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 1d414eb2e8562925f461b0c6f6ce15774b81bb8f..eeebc5803f321df0d6709bb57a009692659bfe77 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,8 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
+ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
+ SPINLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -44,7 +45,7 @@ def build_post_process(config, global_config=None):
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
- 'TableMasterLabelDecode'
+ 'TableMasterLabelDecode', 'SPINLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index cc7c2cb379cc476943152507569f0b0066189c46..3fe29aabe58f42faa02d1b25b4255ba8a19b3ea3 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -667,3 +667,18 @@ class ABINetLabelDecode(NRTRLabelDecode):
def add_special_char(self, dict_character):
dict_character = [''] + dict_character
return dict_character
+
+class SPINLabelDecode(AttnLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SPINLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = dict_character
+ dict_character = [self.beg_str] + [self.end_str] + dict_character
+ return dict_character
\ No newline at end of file
diff --git a/ppocr/utils/dict/spin_dict.txt b/ppocr/utils/dict/spin_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8ee8347fd9c85228a3cf46c810d4fc28ab05c492
--- /dev/null
+++ b/ppocr/utils/dict/spin_dict.txt
@@ -0,0 +1,68 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+:
+(
+'
+-
+,
+%
+>
+.
+[
+?
+)
+"
+=
+_
+*
+]
+;
+&
++
+$
+@
+/
+|
+!
+<
+#
+`
+{
+~
+\
+}
+^
\ No newline at end of file
diff --git a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
index cab0cb0aa390c7cb1efa6e5d3bc636e9c974acba..7a20df7f6c1e2eeee19f55cefa3320d34a62a701 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
index 6c63af6ae8d62e098dec35d5918291291f654e32..a1497ba8fa4790a53cd602829edf3240ff8dc51a 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
@@ -6,7 +6,7 @@ Global:
print_batch_step: 10
save_model_dir: ./output/rec_pp-OCRv2_distillation
save_epoch_step: 3
- eval_batch_step: [0, 2000]
+ eval_batch_step: [0, 200000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
@@ -114,7 +114,7 @@ Train:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list:
- - ./train_data/ic15_data/rec_gt_train.txt
+ - ./train_data/ic15_data/rec_gt_train4w.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
index df42b342ba5fa3947a69c2bde5548975ca92d857..a96b87dede1e1b4c7b3ed59c4bd9c0470402e7e2 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,320]}]
+===========================train_benchmark_params==========================
+batch_size:64
+fp_items:fp32|fp16
+epoch:1
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt
index a69e0ab81ec6963228e9ab2e39c5bb1d730b6323..bf10aebe3e9aa67e30ce7a20cb07f376825e39ae 100644
--- a/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
+norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false Global.eval_batch_step=[4000,400]
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
index f704a1dfb5dc2335c353a495dfbc0ce42cf35bf4..ee884f668767ea1c96782072c729bbcc700674d1 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
@@ -153,7 +153,7 @@ Train:
data_dir: ./train_data/ic15_data/
ext_op_transform_idx: 1
label_file_list:
- - ./train_data/ic15_data/rec_gt_train_lite.txt
+ - ./train_data/ic15_data/rec_gt_train4w.txt
transforms:
- DecodeImage:
img_mode: BGR
@@ -183,7 +183,7 @@ Eval:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list:
- - ./train_data/ic15_data/rec_gt_test_lite.txt
+ - ./train_data/ic15_data/rec_gt_test.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
index 1feb9d49fce69d92ef141c3a942f858fc68cfaab..420c6592d71653377c740c703bedeb8e048cfc03 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,320]}]
+===========================train_benchmark_params==========================
+batch_size:128
+fp_items:fp32|fp16
+epoch:1
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
index 789ed4d23d9c1fa3997daceee0627218aecd4c73..3db816cc0887eb2efc195965498174e868bfc6ec 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
+norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -50,4 +50,10 @@ null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
index f02b93926cac8116844142fe3ecb03959abb0530..36fdb1b91eceede0d692ff4c2680d1403ec86024 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}]
+===========================train_benchmark_params==========================
+batch_size:256
+fp_items:fp32|fp16
+epoch:3
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml b/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml
index f512de808141a0a0e815f9477de80b893ae3c946..6728703a10c2eb4e19b3bbf2225f089324e7d5cd 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml
@@ -2,13 +2,13 @@ Global:
use_gpu: false
epoch_num: 5
log_smooth_window: 20
- print_batch_step: 1
+ print_batch_step: 2
save_model_dir: ./output/db_mv3/
save_epoch_step: 1200
# evaluation is run every 2000 iterations
- eval_batch_step: [0, 400]
+ eval_batch_step: [0, 30000]
cal_metric_during_train: False
- pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+ pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
index c16ca150029d03052396ca28a6396520e63b3f84..7b90a4078a0c30f9d5ecab60c82acbd4052821ea 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
index 64c0cf455cdd058d1840a9ad1f86954293d2e219..9fc117d67c6c2c048b2c8797bc07be8c93b0d519 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}]
+===========================train_benchmark_params==========================
+batch_size:256
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
index 2c8aa953449c4b97790842bb90256280b8b20d9a..62303b7e511bc42444e3610fb17eaf74ccf0848a 100644
--- a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
+norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -52,8 +52,8 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
-batch_size:8|16
+batch_size:16
fp_items:fp32|fp16
-epoch:15
+epoch:4
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
-flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
diff --git a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
index 151f2769cc2d97d6a3546f338383dd811aa06ace..11af0ad18e948d9fa1f325745988877125583658 100644
--- a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
quant_export:null
fpgm_export:null
distill_train:null
@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3a513b8f38cd5abf800c86f8fbeda789cb3d056a
--- /dev/null
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml
@@ -0,0 +1,139 @@
+Global:
+ use_gpu: true
+ epoch_num: 1500
+ log_smooth_window: 20
+ print_batch_step: 20
+ save_model_dir: ./output/det_r50_dcn_fce_ctw/
+ save_epoch_step: 100
+ # evaluation is run every 835 iterations
+ eval_batch_step: [0, 4000]
+ cal_metric_during_train: False
+ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_en/img_10.jpg
+ save_res_path: ./output/det_fce/predicts_fce.txt
+
+
+Architecture:
+ model_type: det
+ algorithm: FCE
+ Transform:
+ Backbone:
+ name: ResNet_vd
+ layers: 50
+ dcn_stage: [False, True, True, True]
+ out_indices: [1,2,3]
+ Neck:
+ name: FCEFPN
+ out_channels: 256
+ has_extra_convs: False
+ extra_stage: 0
+ Head:
+ name: FCEHead
+ fourier_degree: 5
+Loss:
+ name: FCELoss
+ fourier_degree: 5
+ num_sample: 50
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ learning_rate: 0.0001
+ regularizer:
+ name: 'L2'
+ factor: 0
+
+PostProcess:
+ name: FCEPostProcess
+ scales: [8, 16, 32]
+ alpha: 1.0
+ beta: 1.0
+ fourier_degree: 5
+ box_type: 'poly'
+
+Metric:
+ name: DetFCEMetric
+ main_indicator: hmean
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ ignore_orientation: True
+ - DetLabelEncode: # Class handling label
+ - ColorJitter:
+ brightness: 0.142
+ saturation: 0.5
+ contrast: 0.5
+ - RandomScaling:
+ - RandomCropFlip:
+ crop_ratio: 0.5
+ - RandomCropPolyInstances:
+ crop_ratio: 0.8
+ min_side_ratio: 0.3
+ - RandomRotatePolyInstances:
+ rotate_ratio: 0.5
+ max_angle: 30
+ pad_with_fixed_color: False
+ - SquareResizePad:
+ target_size: 800
+ pad_ratio: 0.6
+ - IaaAugment:
+ augmenter_args:
+ - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+ - FCENetTargets:
+ fourier_degree: 5
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'p3_maps', 'p4_maps', 'p5_maps'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ drop_last: False
+ batch_size_per_card: 6
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ ignore_orientation: True
+ - DetLabelEncode: # Class handling label
+ - DetResizeForTest:
+ limit_type: 'min'
+ limit_side_len: 736
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - Pad:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1 # must be 1
+ num_workers: 2
\ No newline at end of file
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2d294fd3038f5506a28d637dbe1aba44b5da237b
--- /dev/null
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:det_r50_dcn_fce_ctw_v2.0
+python:python3.7
+gpu_list:0
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+--save_log_path:null
+--benchmark:True
+--det_algorithm:FCE
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:6
+fp_items:fp32|fp16
+epoch:1
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml b/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml
index c6b6fc3ed79d0d717fe3dbd4cb9c8559ff8f07c4..844f42e9ad2b4eafbbd829de5711132f8119671f 100644
--- a/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml
+++ b/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: EAST
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
Neck:
name: EASTFPN
diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
index 24e4d760c37828c213741b9ff127d55df2f9335a..5ee445a6cd03dfb888d1bc73eb51481a014cbb36 100644
--- a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o Global.pretrained_model=pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy.pdparams Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -55,4 +55,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
batch_size:8
fp_items:fp32|fp16
epoch:2
---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
\ No newline at end of file
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
diff --git a/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml b/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml
index f7e60fd1968820ef093455473346a6b8f0f8d34e..c069d1f132db5fc512011ead0107cf4082c144be 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml
+++ b/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 125 iterations
eval_batch_step: [ 0,1000 ]
cal_metric_during_train: False
- pretrained_model:
+ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy
save_inference_dir:
use_visualdl: False
@@ -20,7 +20,7 @@ Architecture:
algorithm: PSE
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
Neck:
name: FPN
diff --git a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
index 53511e6ae21003cb9df6a92d3931577fbbef5b18..78d25f6b17e30d6b7b12ae3acc1b264febfa97da 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -54,5 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
-epoch:10
+epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml
index 281038b968a5bf829483882117d779ec7de1976d..5d8e84c95c477a639130a342c6c72345e97701da 100755
--- a/test_tipc/configs/en_table_structure/table_mv3.yml
+++ b/test_tipc/configs/en_table_structure/table_mv3.yml
@@ -6,7 +6,7 @@ Global:
save_model_dir: ./output/table_mv3/
save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration
- eval_batch_step: [0, 400]
+ eval_batch_step: [0, 40000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
diff --git a/test_tipc/configs/en_table_structure/train_infer_python.txt b/test_tipc/configs/en_table_structure/train_infer_python.txt
index d9f3b30e16c75281a929130d877b947a23c16190..633b6185d976ac61408283025bd4ba305187317d 100644
--- a/test_tipc/configs/en_table_structure/train_infer_python.txt
+++ b/test_tipc/configs/en_table_structure/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
quant_export:
fpgm_export:
distill_export:null
@@ -51,3 +51,9 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,488,488]}]
+===========================train_benchmark_params==========================
+batch_size:32
+fp_items:fp32|fp16
+epoch:1
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
index 463e8d1d7d7f43ba7b48810e2a2e8552eb5e4fe3..b0ba615293153cc3bbeb5b47d053596306ee45e2 100644
--- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
+++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
@@ -6,7 +6,7 @@ Global:
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 3
# evaluation is run every 2000 iterations
- eval_batch_step: [0, 2000]
+ eval_batch_step: [0, 20000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
index 39bf9227902480ffe4ed37d454c21d6a163c41bd..4e34a6a525fb8104407d04c617db39934b84e140 100644
--- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -50,4 +50,10 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,32,100]}]
\ No newline at end of file
+random_infer_input:[{float32,[3,32,100]}]
+===========================train_benchmark_params==========================
+batch_size:256
+fp_items:fp32|fp16
+epoch:4
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d0cb20481f56a093f96c3d13f5fa2c2d13ae0c69
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,117 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 50
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+ save_epoch_step: 3
+ # evaluation is run every 5000 iterations after the 4000th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs: [3, 4, 5]
+ values: [0.001, 0.0003, 0.00009, 0.000027]
+
+ clip_norm: 5
+
+Architecture:
+ model_type: rec
+ algorithm: SPIN
+ in_channels: 1
+ Transform:
+ name: GA_SPIN
+ offsets: True
+ default_type: 6
+ loc_lr: 0.1
+ stn: True
+ Backbone:
+ name: ResNet32
+ out_channels: 512
+ Neck:
+ name: SequenceEncoder
+ encoder_type: cascadernn
+ hidden_size: 256
+ out_channels: [256, 512]
+ with_linear: True
+ Head:
+ name: SPINAttentionHead
+ hidden_size: 256
+
+
+Loss:
+ name: SPINAttentionLoss
+ ignore_index: 0
+
+PostProcess:
+ name: SPINLabelDecode
+ use_space_char: False
+
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 1
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4915055a576f0a5c1f7b0935a31d1d3c266903a5
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_r32_gaspin_bilstm_att
+python:python
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False|False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,100]}]
diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md
index ad2524c165da3079d24b2b1570a5111d152f8373..a7f95eb6c530e1c451bb400cdb193694e2aee5f6 100644
--- a/test_tipc/docs/benchmark_train.md
+++ b/test_tipc/docs/benchmark_train.md
@@ -51,3 +51,25 @@ train_log/
├── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C1_log
└── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C4_log
```
+## 3. 各模型单卡性能数据一览
+
+*注:本节中的速度指标均使用单卡(1块Nvidia V100 16G GPU)测得。通常情况下。
+
+
+|模型名称|配置文件|大数据集 float32 fps |小数据集 float32 fps |diff |大数据集 float16 fps|小数据集 float16 fps| diff | 大数据集大小 | 小数据集大小 |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+| ch_ppocr_mobile_v2.0_det |[config](../configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt) | 53.836 | 53.343 / 53.914 / 52.785 |0.020940758 | 45.574 | 45.57 / 46.292 / 46.213 | 0.015596647 | 10,000| 2,000|
+| ch_ppocr_mobile_v2.0_rec |[config](../configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt) | 2083.311 | 2043.194 / 2066.372 / 2093.317 |0.023944295 | 2153.261 | 2167.561 / 2165.726 / 2155.614| 0.005511725 | 600,000| 160,000|
+| ch_ppocr_server_v2.0_det |[config](../configs/ch_ppocr_server_v2.0_det/train_infer_python.txt) | 20.716 | 20.739 / 20.807 / 20.755 |0.003268131 | 20.592 | 20.498 / 20.993 / 20.75| 0.023579288 | 10,000| 2,000|
+| ch_ppocr_server_v2.0_rec |[config](../configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt) | 528.56 | 528.386 / 528.991 / 528.391 |0.001143687 | 1189.788 | 1190.007 / 1176.332 / 1192.084| 0.013213834 | 600,000| 160,000|
+| ch_PP-OCRv2_det |[config](../configs/ch_PP-OCRv2_det/train_infer_python.txt) | 13.87 | 13.386 / 13.529 / 13.428 |0.010569887 | 17.847 | 17.746 / 17.908 / 17.96| 0.011915367 | 10,000| 2,000|
+| ch_PP-OCRv2_rec |[config](../configs/ch_PP-OCRv2_rec/train_infer_python.txt) | 109.248 | 106.32 / 106.318 / 108.587 |0.020895687 | 117.491 | 117.62 / 117.757 / 117.726| 0.001163413 | 140,000| 40,000|
+| det_mv3_db_v2.0 |[config](../configs/det_mv3_db_v2_0/train_infer_python.txt) | 61.802 | 62.078 / 61.802 / 62.008 |0.00444602 | 82.947 | 84.294 / 84.457 / 84.005| 0.005351836 | 10,000| 2,000|
+| det_r50_vd_db_v2.0 |[config](../configs/det_r50_vd_db_v2.0/train_infer_python.txt) | 29.955 | 29.092 / 29.31 / 28.844 |0.015899011 | 51.097 |50.367 / 50.879 / 50.227| 0.012814717 | 10,000| 2,000|
+| det_r50_vd_east_v2.0 |[config](../configs/det_r50_vd_east_v2.0/train_infer_python.txt) | 42.485 | 42.624 / 42.663 / 42.561 |0.00239083 | 67.61 |67.825/ 68.299/ 68.51| 0.00999854 | 10,000| 2,000|
+| det_r50_vd_pse_v2.0 |[config](../configs/det_r50_vd_pse_v2.0/train_infer_python.txt) | 16.455 | 16.517 / 16.555 / 16.353 |0.012201752 | 27.02 |27.288 / 27.152 / 27.408| 0.009340339 | 10,000| 2,000|
+| rec_mv3_none_bilstm_ctc_v2.0 |[config](../configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt) | 2288.358 | 2291.906 / 2293.725 / 2290.05 |0.001602197 | 2336.17 |2327.042 / 2328.093 / 2344.915| 0.007622025 | 600,000| 160,000|
+| PP-Structure-table |[config](../configs/en_table_structure/train_infer_python.txt) | 14.151 | 14.077 / 14.23 / 14.25 |0.012140351 | 16.285 | 16.595 / 16.878 / 16.531 | 0.020559308 | 20,000| 5,000|
+| det_r50_dcn_fce_ctw_v2.0 |[config](../configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt) | 14.057 | 14.029 / 14.02 / 14.014 |0.001069214 | 18.298 |18.411 / 18.376 / 18.331| 0.004345228 | 10,000| 2,000|
+| ch_PP-OCRv3_det |[config](../configs/ch_PP-OCRv3_det/train_infer_python.txt) | 8.622 | 8.431 / 8.423 / 8.479|0.006604552 | 14.203 |14.346 14.468 14.23| 0.016450097 | 10,000| 2,000|
+| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 73.627 | 72.46 / 73.575 / 73.704|0.016878324 | | | | 160,000| 40,000|
\ No newline at end of file
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index ec6dece42a0126e6d05405b3262c1c1d24f0a376..cb3fa2440d9672ba113904bd1548d458491d1d8c 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -22,27 +22,79 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
- if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
- rm -rf ./train_data/icdar2015
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
- cd ./train_data/ && tar xf icdar2015.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
+ fi
+ if [[ ${model_name} =~ "ch_ppocr_server_v2.0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
+ fi
+ if [[ ${model_name} =~ "ch_PP-OCRv2_det" ]];then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
fi
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
- cd ./train_data/ && tar xf icdar2015.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
fi
- if [[ ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
+ if [[ ${model_name} =~ "det_r50_db_v2.0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
- cd ./train_data/ && tar xf icdar2015.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
fi
if [[ ${model_name} =~ "det_r18_db_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
- cd ./train_data/ && tar xf icdar2015.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
+ fi
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_rec" || ${model_name} =~ "ch_ppocr_server_v2.0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2.0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
+ rm -rf ./train_data/ic15_data_benckmark
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf ic15_data_benckmark.tar
+ ln -s ./ic15_data_benckmark ./ic15_data
+ cd ../
+ fi
+ if [[ ${model_name} == "en_table_structure" ]];then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
+ rm -rf ./train_data/pubtabnet
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf pubtabnet_benckmark.tar
+ ln -s ./pubtabnet_benckmark ./pubtabnet
+ cd ../
+ fi
+ if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015_benckmark.tar
+ ln -s ./icdar2015_benckmark ./icdar2015
+ cd ../
fi
fi
@@ -137,6 +189,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -363,6 +419,10 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../
fi
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
+ fi
if [[ ${model_name} =~ "en_table_structure" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
diff --git a/tools/export_model.py b/tools/export_model.py
index afecbff8cbb834a5aa5ef3ea1448cf04fbd8c3bb..69ac904c661fad77255c70563fdf1f16c5c29875 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -91,7 +91,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
- elif arch_config["algorithm"] == "NRTR":
+ elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index a95f55596647acc0eaca9616b5630917d7ebdf3a..e6ba4d04ec32e0971e3c44f15b800c2cd2bc6c51 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -81,6 +81,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == "SPIN":
+ postprocess_params = {
+ 'name': 'SARLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -258,6 +264,22 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio
+ def resize_norm_img_spin(self, img):
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # return padding_im
+ img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
+ img = np.array(img, np.float32)
+ img = np.expand_dims(img, -1)
+ img = img.transpose((2, 0, 1))
+ mean = [127.5]
+ std = [127.5]
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ mean = np.float32(mean.reshape(1, -1))
+ stdinv = 1 / np.float32(std.reshape(1, -1))
+ img -= mean
+ img *= stdinv
+ return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -337,6 +359,10 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == 'SPIN':
+ norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape)
diff --git a/tools/program.py b/tools/program.py
index 1d83b46216ad62d59e7123c1b2d590d2a1aae5ac..0fa0e609bd14d07cc593786b3a3f760cb9b98500 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -154,6 +154,24 @@ def check_xpu(use_xpu):
except Exception as e:
pass
+def to_float32(preds):
+ if isinstance(preds, dict):
+ for k in preds:
+ if isinstance(preds[k], dict) or isinstance(preds[k], list):
+ preds[k] = to_float32(preds[k])
+ else:
+ preds[k] = preds[k].astype(paddle.float32)
+ elif isinstance(preds, list):
+ for k in range(len(preds)):
+ if isinstance(preds[k], dict):
+ preds[k] = to_float32(preds[k])
+ elif isinstance(preds[k], list):
+ preds[k] = to_float32(preds[k])
+ else:
+ preds[k] = preds[k].astype(paddle.float32)
+ else:
+ preds = preds.astype(paddle.float32)
+ return preds
def train(config,
train_dataloader,
@@ -207,7 +225,7 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
@@ -252,13 +270,19 @@ def train(config,
# use amp
if scaler:
- with paddle.amp.auto_cast():
+ with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
+ preds = to_float32(preds)
+ loss = loss_class(preds, batch)
+ avg_loss = loss['loss']
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
@@ -266,15 +290,8 @@ def train(config,
preds = model(batch)
else:
preds = model(images)
-
- loss = loss_class(preds, batch)
- avg_loss = loss['loss']
-
- if scaler:
- scaled_avg_loss = scaler.scale(avg_loss)
- scaled_avg_loss.backward()
- scaler.minimize(optimizer, scaled_avg_loss)
- else:
+ loss = loss_class(preds, batch)
+ avg_loss = loss['loss']
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
@@ -579,7 +596,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
]
if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index b7c25e34231fb650fd2c7c89dc17320f561962f9..309d4bb9e6b0fbcc9dd93545877662d746ada086 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -157,6 +157,7 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None