diff --git a/.gitignore b/.gitignore
index caf886a2b581b5976ea391aca5d7a56041bdbaa8..3300be325f1f6c8b2b58301fc87a4f9d241afb84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,8 @@ __pycache__/
 inference/
 inference_results/
 output/
-
+train_data/
+log/
 *.DS_Store
 *.vs
 *.user
diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 440c2d8c3eeef7379a5807db69107edfd2efdbb4..9e67f0cc54818aa5f6e6e7163277c05422f0754f 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin
 from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
 from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
     QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
-    QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView
+    QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView, QSizePolicy
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 
@@ -227,6 +227,21 @@ class MainWindow(QMainWindow):
         listLayout.addWidget(leftTopToolBoxContainer)
 
         # ================== Label List  ==================
+        labelIndexListBox = QHBoxLayout()
+
+        # Create and add a widget for showing the current label item index
+        self.indexList = QListWidget()
+        self.indexList.setMaximumSize(40, 16777215)  # limit max width
+        self.indexList.setEditTriggers(QAbstractItemView.NoEditTriggers)  # not editable
+        self.indexList.itemSelectionChanged.connect(self.indexSelectionChanged)
+        self.indexList.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)  # no scroll bar
+        self.indexListDock = QDockWidget('No.', self)
+        self.indexListDock.setWidget(self.indexList)
+        self.indexListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
+        labelIndexListBox.addWidget(self.indexListDock, 1)
+        # no margin between the two boxes
+        labelIndexListBox.setSpacing(0)
+
         # Create and add a widget for showing current label items
         self.labelList = EditInList()
         labelListContainer = QWidget()
@@ -240,8 +255,8 @@ class MainWindow(QMainWindow):
         self.labelListDock = QDockWidget(self.labelListDockName, self)
         self.labelListDock.setWidget(self.labelList)
         self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
-        listLayout.addWidget(self.labelListDock)
-
+        labelIndexListBox.addWidget(self.labelListDock, 10)  # label list is wider than index list
+
         # enable labelList drag_drop to adjust bbox order
         # 设置选择模式为单选
         self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
@@ -256,6 +271,17 @@ class MainWindow(QMainWindow):
         # 触发放置
         self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
 
+        labelIndexListContainer = QWidget()
+        labelIndexListContainer.setLayout(labelIndexListBox)
+        listLayout.addWidget(labelIndexListContainer)
+
+        # keep labelList and indexList scrolling in sync
+        self.labelListBar = self.labelList.verticalScrollBar()
+        self.indexListBar = self.indexList.verticalScrollBar()
+
+        self.labelListBar.valueChanged.connect(self.move_scrollbar)
+        self.indexListBar.valueChanged.connect(self.move_scrollbar)
+
         # ================== Detection Box  ==================
         self.BoxList = QListWidget()
 
@@ -766,6 +792,7 @@ class MainWindow(QMainWindow):
         self.shapesToItemsbox.clear()
         self.labelList.clear()
         self.BoxList.clear()
+        self.indexList.clear()
         self.filePath = None
         self.imageData = None
         self.labelFile = None
@@ -1027,13 +1054,19 @@ class MainWindow(QMainWindow):
         for shape in self.canvas.selectedShapes:
             shape.selected = False
         self.labelList.clearSelection()
+        self.indexList.clearSelection()
         self.canvas.selectedShapes = selected_shapes
         for shape in self.canvas.selectedShapes:
             shape.selected = True
             self.shapesToItems[shape].setSelected(True)
             self.shapesToItemsbox[shape].setSelected(True)
+            index = self.labelList.indexFromItem(self.shapesToItems[shape]).row()
+            self.indexList.item(index).setSelected(True)
 
         self.labelList.scrollToItem(self.currentItem())  # QAbstractItemView.EnsureVisible
+        # map the current label item to its index item and scroll to it
+        index = self.labelList.indexFromItem(self.currentItem()).row()
+        self.indexList.scrollToItem(self.indexList.item(index))
         self.BoxList.scrollToItem(self.currentBox())
 
         if self.kie_mode:
@@ -1066,12 +1099,18 @@
             shape.paintIdx = self.displayIndexOption.isChecked()
 
         item = HashableQListWidgetItem(shape.label)
-        item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
-        item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
+        # the difficult checkbox is currently disabled
+        # item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+        # item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
+        # Checked means difficult is False
         # item.setBackground(generateColorByText(shape.label))
         self.itemsToShapes[item] = shape
        self.shapesToItems[shape] = item
+        # add current label item index before label string
+        current_index = QListWidgetItem(str(self.labelList.count()))
+        current_index.setTextAlignment(Qt.AlignHCenter)
+        self.indexList.addItem(current_index)
         self.labelList.addItem(item)
         # print('item in add label is ',[(p.x(), p.y()) for p in shape.points], shape.label)
 
@@ -1105,6 +1144,7 @@
         del self.shapesToItemsbox[shape]
         del self.itemsToShapesbox[item]
         self.updateComboBox()
+        self.updateIndexList()
 
     def loadLabels(self, shapes):
         s = []
@@ -1156,6 +1196,13 @@
 
         # self.comboBox.update_items(uniqueTextList)
 
+    def updateIndexList(self):
+        self.indexList.clear()
+        for i in range(self.labelList.count()):
+            string = QListWidgetItem(str(i))
+            string.setTextAlignment(Qt.AlignHCenter)
+            self.indexList.addItem(string)
+
     def saveLabels(self, annotationFilePath, mode='Auto'):
         # Mode is Auto means that labels will be loaded from self.result_dic totally, which is the output of ocr model
         annotationFilePath = ustr(annotationFilePath)
@@ -1211,6 +1258,10 @@
         # fix copy and delete
         # self.shapeSelectionChanged(True)
 
+    def move_scrollbar(self, value):
+        self.labelListBar.setValue(value)
+        self.indexListBar.setValue(value)
+
     def labelSelectionChanged(self):
         if self._noSelectionSlot:
             return
@@ -1223,6 +1274,21 @@
             else:
                 self.canvas.deSelectShape()
 
+    def indexSelectionChanged(self):
+        if self._noSelectionSlot:
+            return
+        if self.canvas.editing():
+            selected_shapes = []
+            for item in self.indexList.selectedItems():
+                # map the index item to its label item
+                index = self.indexList.indexFromItem(item).row()
+                item = self.labelList.item(index)
+                selected_shapes.append(self.itemsToShapes[item])
+            if selected_shapes:
+                self.canvas.selectShapes(selected_shapes)
+            else:
+                self.canvas.deSelectShape()
+
     def boxSelectionChanged(self):
         if self._noSelectionSlot:
             # self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
@@ -1517,6 +1583,7 @@
         if self.labelList.count():
             self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
self.labelList.item(self.labelList.count() - 1).setSelected(True) + self.indexList.item(self.labelList.count() - 1).setSelected(True) # show file list image count select_indexes = self.fileListWidget.selectedIndexes() @@ -2015,12 +2082,14 @@ class MainWindow(QMainWindow): for shape in self.canvas.shapes: shape.paintLabel = self.displayLabelOption.isChecked() shape.paintIdx = self.displayIndexOption.isChecked() + self.canvas.repaint() def togglePaintIndexOption(self): self.displayLabelOption.setChecked(False) for shape in self.canvas.shapes: shape.paintLabel = self.displayLabelOption.isChecked() shape.paintIdx = self.displayIndexOption.isChecked() + self.canvas.repaint() def toogleDrawSquare(self): self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked()) @@ -2115,7 +2184,7 @@ class MainWindow(QMainWindow): self.init_key_list(self.Cachelabel) def reRecognition(self): - img = cv2.imread(self.filePath) + img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1) # org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]] if self.canvas.shapes: self.result_dic = [] @@ -2184,7 +2253,7 @@ class MainWindow(QMainWindow): QMessageBox.information(self, "Information", "Draw a box!") def singleRerecognition(self): - img = cv2.imread(self.filePath) + img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1) for shape in self.canvas.selectedShapes: box = [[int(p.x()), int(p.y())] for p in shape.points] if len(box) > 4: @@ -2254,6 +2323,7 @@ class MainWindow(QMainWindow): self.itemsToShapesbox.clear() # ADD self.shapesToItemsbox.clear() self.labelList.clear() + self.indexList.clear() self.BoxList.clear() self.result_dic = [] self.result_dic_locked = [] @@ -2665,6 +2735,7 @@ class MainWindow(QMainWindow): def undoShapeEdit(self): self.canvas.restoreShape() self.labelList.clear() + self.indexList.clear() self.BoxList.clear() self.loadShapes(self.canvas.shapes) self.actions.undo.setEnabled(self.canvas.isShapeRestorable) @@ -2674,6 +2745,7 @@ class MainWindow(QMainWindow): for shape in shapes: self.addLabel(shape) self.labelList.clearSelection() + self.indexList.clearSelection() self._noSelectionSlot = False self.canvas.loadShapes(shapes, replace=replace) print("loadShapes") # 1 diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 6edf0b9194ee59143e287394f505b60010ec6644..2f39fbd232fa4bcab4cd30622d21c56d11a72d31 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -101,7 +101,7 @@ Train: drop_last: False batch_size_per_card: 16 num_workers: 8 - use_shared_memory: False + use_shared_memory: True Eval: dataset: @@ -129,4 +129,4 @@ Eval: drop_last: False batch_size_per_card: 1 # must be 1 num_workers: 8 - use_shared_memory: False + use_shared_memory: True diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml new file mode 100644 index 0000000000000000000000000000000000000000..aea71388f703376120af4d0caf2fa8ccd4d92cce --- /dev/null +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -0,0 +1,116 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 50 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 4000th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: 
./ppocr/utils/dict/spin_dict.txt + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4, 5] + values: [0.001, 0.0003, 0.00009, 0.000027] + clip_norm: 5 + +Architecture: + model_type: rec + algorithm: SPIN + in_channels: 1 + Transform: + name: GA_SPIN + offsets: True + default_type: 6 + loc_lr: 0.1 + stn: True + Backbone: + name: ResNet32 + out_channels: 512 + Neck: + name: SequenceEncoder + encoder_type: cascadernn + hidden_size: 256 + out_channels: [256, 512] + with_linear: True + Head: + name: SPINAttentionHead + hidden_size: 256 + + +Loss: + name: SPINAttentionLoss + ignore_index: 0 + +PostProcess: + name: SPINLabelDecode + use_space_char: False + + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 8 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 2 diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml index 1e6efe32dbc077e910a417a7616db7bcc7f7c825..b8daf3630755e61322665b6fc5f830e4a45875b8 100755 --- a/configs/table/table_master.yml +++ b/configs/table/table_master.yml @@ -104,7 +104,7 @@ Train: Eval: dataset: name: PubTabDataSet - data_dir: train_data/table/pubtabnet/train/ + data_dir: train_data/table/pubtabnet/val/ label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl] transforms: - DecodeImage: diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 5c7adc715f1a5e728d9320c62dc15c578d9f18bf..fbd3ce9ebccec0b1c2133b52e4aeb9d4d5e21114 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -69,6 +69,7 @@ - [x] [SVTR](./algorithm_rec_svtr.md) - [x] [ViTSTR](./algorithm_rec_vitstr.md) - [x] [ABINet](./algorithm_rec_abinet.md) +- [x] [SPIN](./algorithm_rec_spin.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -89,6 +90,7 @@ |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | 
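A note on the optimizer section of `configs/rec/rec_r32_gaspin_bilstm_att.yml` above: the `Piecewise` schedule pairs four `values` with three `decay_epochs` boundaries, so the learning rate drops at epochs 3, 4 and 5 over the 6-epoch run. A minimal sketch of that mapping in plain Python (illustrative only, not Paddle's `PiecewiseDecay` implementation):

```python
import bisect

def piecewise_lr(epoch,
                 decay_epochs=(3, 4, 5),
                 values=(0.001, 0.0003, 0.00009, 0.000027)):
    # values[i] applies while epoch < decay_epochs[i]; the last value holds
    # from the final boundary onward (len(values) == len(decay_epochs) + 1).
    return values[bisect.bisect_right(decay_epochs, epoch)]

assert piecewise_lr(0) == 0.001      # epochs 0-2
assert piecewise_lr(3) == 0.0003     # epoch 3
assert piecewise_lr(5) == 0.000027   # epoch 5 onward
```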
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
new file mode 100644
index 0000000000000000000000000000000000000000..c996992d2fa6297e6086ffae4bc36ad3e880873d
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+  - [3.1 训练](#3-1)
+  - [3.2 评估](#3-2)
+  - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+  - [4.1 Python推理](#4-1)
+  - [4.2 C++推理](#4-2)
+  - [4.3 Serving服务化部署](#4-3)
+  - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+SPIN收录于AAAI 2020,主要用于OCR识别任务。在任意形状文本识别中,矫正网络是较为常见的前置处理模块,但RARE、ASTER、ESIR等方法只考虑了空间变换,没有考虑色度变换。本文提出了可以在色彩空间上进行变换的结构保持内部偏移网络(Structure-Preserving Inner Offset Network, SPIN);该模块可微分,可以插入到任意识别器中。
+使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+## 3. 模型训练、评估、预测
+
+请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+训练:
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+# 单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+# 多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+评估:
+
+```
+# GPU评估,Global.pretrained_model为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+预测:
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将SPIN文本识别训练过程中保存的模型转换成inference model,可以使用如下命令进行转换:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+
+SPIN文本识别模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
+```
+
+
+### 4.2 C++推理
+
+由于C++预处理、后处理尚未支持SPIN,暂不支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@article{2020SPIN,
+  title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+  author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+  journal={AAAI2020},
+  year={2020},
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index f3c96b620c94c3b5f795b6117a7c6bcfcfa43b7a..a579d2447c52067e05d16af5e9d6cf50defc2b1c 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
 - [x] [SVTR](./algorithm_rec_svtr_en.md)
 - [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
 - [x] [ABINet](./algorithm_rec_abinet_en.md)
+- [x] [SPIN](./algorithm_rec_spin_en.md)
 
 Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
 
@@ -88,6 +89,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
 |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
 |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
 |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..43ab30ce7d96cbb64ddf87156fee3012d666b2bf
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+The model is trained on two synthetic text recognition datasets, MJSynth and SynthText, and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets. The reproduced results are as follows:
+
+|Model|Backbone|Config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+# Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+# Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the model saved during SPIN text recognition training into an inference model. You can use the following command for the conversion:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+
+For SPIN text recognition model inference, the following command can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@article{2020SPIN,
+  title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+  author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+  journal={AAAI2020},
+  year={2020},
+}
+```
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index d82176282839bf76d34ed8a60d5e2e13ac6bbce6..d41eed9dfbd2980242e76fa8d8aae380a6594cd4 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -26,7 +26,7 @@ from .make_pse_gt import MakePseGt
 
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
     SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
-    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
+    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index a4087d53287fcd57f9c4992ba712c700f33b9981..97539faf232ec157340d3136d2efc0daca8deda8 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1216,3 +1216,36 @@ class ABINetLabelEncode(BaseRecLabelEncode):
     def add_special_char(self, dict_character):
         dict_character = [''] + dict_character
         return dict_character
+
+class SPINLabelEncode(AttnLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 lower=True,
+                 **kwargs):
+        super(SPINLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+        self.lower = lower
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + [self.end_str] + dict_character
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) > self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        target = [0] + text + [1]
+        padded_text = [0 for _ in range(self.max_text_len + 2)]
+
+        padded_text[:len(target)] = target
+        data['label'] = np.array(padded_text)
+        return data
\ No newline at end of file
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 26773d0a516dfb0877453c7a5c8c8a2b5da92045..c5d8a3b2fd773a1877a788401a926d7fbca07adf 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -259,6 +259,49 @@ class PRENResizeImg(object):
         data['image'] = resized_img.astype(np.float32)
         return data
 
+class SPINRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 interpolation=2,
+                 mean=(127.5, 127.5, 127.5),
+                 std=(127.5, 127.5, 127.5),
+                 **kwargs):
+        self.image_shape = image_shape
+
+        self.mean = np.array(mean, dtype=np.float32)
+        self.std = np.array(std, dtype=np.float32)
+        self.interpolation = interpolation
+
+    def __call__(self, data):
+        img = data['image']
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # map the configured interpolation type to the OpenCV flag
+        if self.interpolation == 0:
+            interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type !!!")
+        # deal with images that failed to load
+        if img is None:
+            return None
+
+        img = cv2.resize(img, tuple(self.image_shape), interpolation=interpolation)
+        img = np.array(img, np.float32)
+        img = np.expand_dims(img, -1)
+        img = img.transpose((2, 0, 1))
+        # normalize the image
+        img = img.copy().astype(np.float32)
+        mean = np.float64(self.mean.reshape(1, -1))
+        stdinv = 1 / np.float64(self.std.reshape(1, -1))
+        img -= mean
+        img *= stdinv
+        data['image'] = img
+        return data
 
 class GrayRecResizeImg(object):
     def __init__(self,
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 62e0544ea94daaaff7d019e6a48e65a2d508aca0..30120ac56756edd38676c40c39f0130f1b07c3ef 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
+from .rec_spin_att_loss import SPINAttentionLoss
 
 # cls loss
 from .cls_loss import ClsLoss
@@ -62,7 +63,7 @@ def build_loss(config):
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
-        'TableMasterLoss'
+        'TableMasterLoss', 'SPINAttentionLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..195780c7bfaf4aae5dd23bd72ace268bed9c1d4f
--- /dev/null
+++ b/ppocr/losses/rec_spin_att_loss.py
@@ -0,0 +1,45 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
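+# Note: SPINAttentionLoss below is plain cross-entropy over the attention
+# decoder's per-step logits. SPINLabelEncode builds each label as
+# [sos] + text + [eos] with index 0 = [sos] and 0-padding, so the leading
+# [sos] column is dropped from the targets and ignore_index=0 (set in the
+# config) masks the padded positions.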
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR
+'''
+
+class SPINAttentionLoss(nn.Layer):
+    def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
+        super(SPINAttentionLoss, self).__init__()
+        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
+
+    def forward(self, predicts, batch):
+        targets = batch[1].astype("int64")
+        targets = targets[:, 1:]  # remove the leading [sos] token from the label
+
+        label_lengths = batch[2].astype('int64')
+        batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
+            1], predicts.shape[2]
+        assert len(targets.shape) == len(list(predicts.shape)) - 1, \
+            "predicts should be [N, num_steps, num_classes] and targets [N, num_steps]"
+
+        inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
+        targets = paddle.reshape(targets, [-1])
+
+        return {'loss': self.loss_func(inputs, targets)}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f4094d796b1f14c955e5962936e86bd6b3f5ec78..d4f5b15f56d34a9f6a6501058179a643ac7e8318 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -32,6 +32,7 @@ def build_backbone(config, model_type):
         from .rec_mv1_enhance import MobileNetV1Enhance
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
+        from .rec_resnet_32 import ResNet32
         from .rec_resnet_45 import ResNet45
         from .rec_resnet_aster import ResNet_ASTER
         from .rec_micronet import MicroNet
@@ -41,7 +42,7 @@ def build_backbone(config, model_type):
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
-            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
+            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
         ]
     elif model_type == 'e2e':
         from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd19251a3ed43a472d49f03743ead1491aa86ac
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_32.py
@@ -0,0 +1,269 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.nn as nn + +__all__ = ["ResNet32"] + +conv_weight_attr = nn.initializer.KaimingNormal() + +class ResNet32(nn.Layer): + """ + Feature Extractor is proposed in FAN Ref [1] + + Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017 + """ + + def __init__(self, in_channels, out_channels=512): + """ + + Args: + in_channels (int): input channel + output_channel (int): output channel + """ + super(ResNet32, self).__init__() + self.out_channels = out_channels + self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3]) + + def forward(self, inputs): + """ + Args: + inputs: input feature + + Returns: + output feature + + """ + return self.ConvNet(inputs) + +class BasicBlock(nn.Layer): + """Res-net Basic Block""" + expansion = 1 + + def __init__(self, inplanes, planes, + stride=1, downsample=None, + norm_type='BN', **kwargs): + """ + Args: + inplanes (int): input channel + planes (int): channels of the middle feature + stride (int): stride of the convolution + downsample (int): type of the down_sample + norm_type (str): type of the normalization + **kwargs (None): backup parameter + """ + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2D(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + """ + + Args: + in_planes (int): input channel + out_planes (int): channels of the middle feature + stride (int): stride of the convolution + Returns: + nn.Layer: Conv2D with kernel = 3 + + """ + + return nn.Conv2D(in_planes, out_planes, + kernel_size=3, stride=stride, + padding=1, weight_attr=conv_weight_attr, + bias_attr=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + + return out + +class ResNet(nn.Layer): + """Res-Net network structure""" + def __init__(self, input_channel, + output_channel, block, layers): + """ + + Args: + input_channel (int): input channel + output_channel (int): output channel + block (BasicBlock): convolution block + layers (list): layers of the block + """ + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), + int(output_channel / 2), + output_channel, + output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16)) + self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn0_2 = nn.BatchNorm2D(self.inplanes) + self.relu = nn.ReLU() + + self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, + self.output_channel_block[0], + layers[0]) + self.conv1 = nn.Conv2D(self.output_channel_block[0], + self.output_channel_block[0], + 
kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn1 = nn.BatchNorm2D(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, + self.output_channel_block[1], + layers[1], stride=1) + self.conv2 = nn.Conv2D(self.output_channel_block[1], + self.output_channel_block[1], + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False,) + self.bn2 = nn.BatchNorm2D(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2D(kernel_size=2, + stride=(2, 1), + padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], + layers[2], stride=1) + self.conv3 = nn.Conv2D(self.output_channel_block[2], + self.output_channel_block[2], + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn3 = nn.BatchNorm2D(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], + layers[3], stride=1) + self.conv4_1 = nn.Conv2D(self.output_channel_block[3], + self.output_channel_block[3], + kernel_size=2, stride=(2, 1), + padding=(0, 1), + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2D(self.output_channel_block[3], + self.output_channel_block[3], + kernel_size=2, stride=1, + padding=0, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + """ + + Args: + block (block): convolution block + planes (int): input channels + blocks (list): layers of the block + stride (int): stride of the convolution + + Returns: + nn.Sequential: the combination of the convolution block + + """ + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, + weight_attr=conv_weight_attr, + bias_attr=False), + nn.BatchNorm2D(planes * block.expansion), + ) + + layers = list() + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index fcd146efbc378faeebd42534a994836789974c32..b4f18b372058c539ae5949ced333ec7be122211f 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -33,6 +33,7 @@ def build_head(config): from .rec_aster_head import AsterHead from .rec_pren_head import PRENHead from .rec_multi_head import MultiHead + from .rec_spin_att_head import SPINAttentionHead from .rec_abinet_head import ABINetHead # cls head @@ -48,7 +49,7 @@ def build_head(config): 
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', - 'MultiHead', 'ABINetHead', 'TableMasterHead' + 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead' ] #table head diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py new file mode 100644 index 0000000000000000000000000000000000000000..86e35e4339d8e1006cfe43d6cf4f2f7d231082c4 --- /dev/null +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -0,0 +1,115 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class SPINAttentionHead(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(SPINAttentionHead, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = paddle.shape(inputs)[0] + num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence + + hidden = (paddle.zeros((batch_size, self.hidden_size)), + paddle.zeros((batch_size, self.hidden_size))) + output_hiddens = [] + if self.training: # for train + targets = targets[0] + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + char_onehots = None + outputs = None + alpha = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(outputs) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + next_input = probs_step.argmax(axis=1) + targets = next_input + if not self.training: + probs = paddle.nn.functional.softmax(probs, axis=2) + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, 
input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index c8a774b8c543b9ccc14223c52f1b79ce690592f6..33be9400b34cb535d260881748e179c3df106caa 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -47,6 +47,56 @@ class EncoderWithRNN(nn.Layer): x, _ = self.lstm(x) return x +class BidirectionalLSTM(nn.Layer): + def __init__(self, input_size, + hidden_size, + output_size=None, + num_layers=1, + dropout=0, + direction=False, + time_major=False, + with_linear=False): + super(BidirectionalLSTM, self).__init__() + self.with_linear = with_linear + self.rnn = nn.LSTM(input_size, + hidden_size, + num_layers=num_layers, + dropout=dropout, + direction=direction, + time_major=time_major) + + # text recognition the specified structure LSTM with linear + if self.with_linear: + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input_feature): + recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + if self.with_linear: + output = self.linear(recurrent) # batch_size x T x output_size + return output + return recurrent + +class EncoderWithCascadeRNN(nn.Layer): + def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False): + super(EncoderWithCascadeRNN, self).__init__() + self.out_channels = out_channels[-1] + self.encoder = nn.LayerList( + [BidirectionalLSTM( + in_channels if i == 0 else out_channels[i - 1], + hidden_size, + output_size=out_channels[i], + num_layers=1, + direction='bidirectional', + with_linear=with_linear) + for i in range(num_layers)] + ) + + + def forward(self, x): + for i, l in enumerate(self.encoder): + x = l(x) + return x + class EncoderWithFC(nn.Layer): def __init__(self, in_channels, hidden_size): @@ -166,13 +216,17 @@ class SequenceEncoder(nn.Layer): 'reshape': Im2Seq, 'fc': EncoderWithFC, 'rnn': EncoderWithRNN, - 'svtr': EncoderWithSVTR + 'svtr': EncoderWithSVTR, + 'cascadernn': EncoderWithCascadeRNN } assert encoder_type in support_encoder_dict, '{} must in {}'.format( encoder_type, support_encoder_dict.keys()) if encoder_type == "svtr": self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, **kwargs) + elif encoder_type == 'cascadernn': + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, hidden_size, **kwargs) else: self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size) 
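The `cascadernn` encoder registered above is what the SPIN config selects (`encoder_type: cascadernn`, `hidden_size: 256`, `out_channels: [256, 512]`, `with_linear: True`): stacked BiLSTM layers, each followed by a linear projection when `with_linear` is set. A rough usage sketch, assuming a PaddleOCR checkout with this patch applied (batch and sequence sizes are illustrative):

```python
import paddle

from ppocr.modeling.necks.rnn import SequenceEncoder

# SPIN's neck: two stacked BiLSTMs (hidden size 256), projected per layer
# to 256 and then 512 channels.
encoder = SequenceEncoder(in_channels=512,
                          encoder_type='cascadernn',
                          hidden_size=256,
                          out_channels=[256, 512],
                          with_linear=True)

feats = paddle.randn([8, 512, 1, 25])  # NCHW backbone output, height already 1
seq = encoder(feats)                   # Im2Seq squeezes H, then both BiLSTM layers run
print(seq.shape)                       # [8, 25, 512], i.e. out_channels[-1]
```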
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py index 405ab3cc6c0380654f61e42e523ddc85839139b3..7e4ffdf46854416f71e1c8f4e131d1f0283bb725 100755 --- a/ppocr/modeling/transforms/__init__.py +++ b/ppocr/modeling/transforms/__init__.py @@ -18,8 +18,10 @@ __all__ = ['build_transform'] def build_transform(config): from .tps import TPS from .stn import STN_ON + from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN - support_dict = ['TPS', 'STN_ON'] + + support_dict = ['TPS', 'STN_ON', 'GA_SPIN'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4719eb2162a02141620586bcb6a849ae16f3b62 --- /dev/null +++ b/ppocr/modeling/transforms/gaspin_transformer.py @@ -0,0 +1,284 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +import functools +from .tps import GridGenerator + +'''This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py +''' + +class SP_TransformerNetwork(nn.Layer): + """ + Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1] + Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021. + """ + + def __init__(self, nc=1, default_type=5): + """ Based on SPIN + Args: + nc (int): number of input channels (usually in 1 or 3) + default_type (int): the complexity of transformation intensities (by default set to 6 as the paper) + """ + super(SP_TransformerNetwork, self).__init__() + self.power_list = self.cal_K(default_type) + self.sigmoid = nn.Sigmoid() + self.bn = nn.InstanceNorm2D(nc) + + def cal_K(self, k=5): + """ + + Args: + k (int): the complexity of transformation intensities (by default set to 6 as the paper) + + Returns: + List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)] + + """ + from math import log + x = [] + if k != 0: + for i in range(1, k+1): + lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2) + upper = round(1/lower, 2) + x.append(lower) + x.append(upper) + x.append(1.00) + return x + + def forward(self, batch_I, weights, offsets, lambda_color=None): + """ + + Args: + batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width] + weights: + offsets: the predicted offset by AIN, a scalar + lambda_color: the learnable update gate \alpha in Equa. (5) as + g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets} + + Returns: + Tensor: transformed images by SPN as Equa. (4) in Ref. 
[1] + [batch_size x I_channel_num x I_r_height x I_r_width] + + """ + batch_I = (batch_I + 1) * 0.5 + if offsets is not None: + batch_I = batch_I*(1-lambda_color) + offsets*lambda_color + batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1) + batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1) + + batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1) + batch_weight_sum = self.bn(batch_weight_sum) + batch_weight_sum = self.sigmoid(batch_weight_sum) + batch_weight_sum = batch_weight_sum * 2 - 1 + return batch_weight_sum + +class GA_SPIN_Transformer(nn.Layer): + """ + Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1] + + + Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021. + """ + + def __init__(self, in_channels=1, + I_r_size=(32, 100), + offsets=False, + norm_type='BN', + default_type=6, + loc_lr=1, + stn=True): + """ + Args: + in_channels (int): channel of input features, + set it to 1 if the grayscale images and 3 if RGB input + I_r_size (tuple): size of rectified images (used in STN transformations) + offsets (bool): set it to False if use SPN w.o. AIN, + and set it to True if use SPIN (both with SPN and AIN) + norm_type (str): the normalization type of the module, + set it to 'BN' by default, 'IN' optionally + default_type (int): the K chromatic space, + set it to 3/5/6 depend on the complexity of transformation intensities + loc_lr (float): learning rate of location network + stn (bool): whther to use stn. + + """ + super(GA_SPIN_Transformer, self).__init__() + self.nc = in_channels + self.spt = True + self.offsets = offsets + self.stn = stn # set to True in GA-SPIN, while set it to False in SPIN + self.I_r_size = I_r_size + self.out_channels = in_channels + if norm_type == 'BN': + norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True) + elif norm_type == 'IN': + norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False, + use_global_stats=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + + if self.spt: + self.sp_net = SP_TransformerNetwork(in_channels, + default_type) + self.spt_convnet = nn.Sequential( + # 32*100 + nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False), + norm_layer(32), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 16*50 + nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False), + norm_layer(64), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 8*25 + nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False), + norm_layer(128), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 4*12 + ) + self.stucture_fc1 = nn.Sequential( + nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False), + norm_layer(256), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False), + norm_layer(256), nn.ReLU(), # 2*6 + nn.MaxPool2D(kernel_size=2, stride=2), + nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False), + norm_layer(512), nn.ReLU(), # 1*3 + nn.AdaptiveAvgPool2D(1), + nn.Flatten(1, -1), # batch_size x 512 + nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)), + nn.BatchNorm1D(256), nn.ReLU() + ) + self.out_weight = 2*default_type+1 + self.spt_length = 2*default_type+1 + if offsets: + self.out_weight += 1 + if self.stn: + self.F = 20 + self.out_weight += self.F * 2 + self.GridGenerator = GridGenerator(self.F*2, self.F) + + # self.out_weight*=nc + # Init structure_fc2 in LocalizationNetwork + initial_bias = self.init_spin(default_type*2) + 
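+            # init_spin lays out the fc2 bias as: zeros for the first
+            # 2 * default_type SPT weights, 5.0 for the last one, -5.0 for the
+            # AIN update gate when offsets is True (so the gate starts near 0),
+            # then the 2 * F fiducial control points when stn is True.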
initial_bias = initial_bias.reshape(-1) + param_attr = ParamAttr( + learning_rate=loc_lr, + initializer=nn.initializer.Assign(np.zeros([256, self.out_weight]))) + bias_attr = ParamAttr( + learning_rate=loc_lr, + initializer=nn.initializer.Assign(initial_bias)) + self.stucture_fc2 = nn.Linear(256, self.out_weight, + weight_attr=param_attr, + bias_attr=bias_attr) + self.sigmoid = nn.Sigmoid() + + if offsets: + self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16, + 3, 1, 1, + bias_attr=False), + norm_layer(16), + nn.ReLU(),) + self.offset_fc2 = nn.Conv2D(16, in_channels, + 3, 1, 1) + self.pool = nn.MaxPool2D(2, 2) + + def init_spin(self, nz): + """ + Args: + nz (int): number of paired \betas exponents, which means the value of K x 2 + + """ + init_id = [0.00]*nz+[5.00] + if self.offsets: + init_id += [-5.00] + # init_id *=3 + init = np.array(init_id) + + if self.stn: + F = self.F + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + initial_bias = initial_bias.reshape(-1) + init = np.concatenate([init, initial_bias], axis=0) + return init + + def forward(self, x, return_weight=False): + """ + Args: + x (Tensor): input image batch + return_weight (bool): set to False by default, + if set to True return the predicted offsets of AIN, denoted as x_{offsets} + + Returns: + Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size + """ + + if self.spt: + feat = self.spt_convnet(x) + fc1 = self.stucture_fc1(feat) + sp_weight_fusion = self.stucture_fc2(fc1) + sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1]) + if self.offsets: # SPIN w. AIN + lambda_color = sp_weight_fusion[:, self.spt_length, 0] + lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + sp_weight = sp_weight_fusion[:, :self.spt_length, :] + offsets = self.pool(self.offset_fc2(self.offset_fc1(feat))) + + assert offsets.shape[2] == 2 # 2 + assert offsets.shape[3] == 6 # 16 + offsets = self.sigmoid(offsets) # v12 + + if return_weight: + return offsets + offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear') + + if self.stn: + batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2]) + build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0], + self.I_r_size[0], + self.I_r_size[1], + 2]) + + else: # SPIN w.o. 
AIN + sp_weight = sp_weight_fusion[:, :self.spt_length, :] + lambda_color, offsets = None, None + + if self.stn: + batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2]) + build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0], + self.I_r_size[0], + self.I_r_size[1], + 2]) + + x = self.sp_net(x, sp_weight, offsets, lambda_color) + if self.stn: + x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border') + return x diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 1d414eb2e8562925f461b0c6f6ce15774b81bb8f..eeebc5803f321df0d6709bb57a009692659bfe77 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -27,7 +27,8 @@ from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ - SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode + SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \ + SPINLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -44,7 +45,7 @@ def build_post_process(config, global_config=None): 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', - 'TableMasterLabelDecode' + 'TableMasterLabelDecode', 'SPINLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index cc7c2cb379cc476943152507569f0b0066189c46..3fe29aabe58f42faa02d1b25b4255ba8a19b3ea3 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -667,3 +667,18 @@ class ABINetLabelDecode(NRTRLabelDecode): def add_special_char(self, dict_character): dict_character = [''] + dict_character return dict_character + +class SPINLabelDecode(AttnLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SPINLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + [self.end_str] + dict_character + return dict_character \ No newline at end of file diff --git a/ppocr/utils/dict/spin_dict.txt b/ppocr/utils/dict/spin_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..8ee8347fd9c85228a3cf46c810d4fc28ab05c492 --- /dev/null +++ b/ppocr/utils/dict/spin_dict.txt @@ -0,0 +1,68 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +: +( +' +- +, +% +> +. +[ +? +) +" += +_ +* +] +; +& ++ +$ +@ +/ +| +! 
+< +# +` +{ +~ +\ +} +^ \ No newline at end of file diff --git a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt index cab0cb0aa390c7cb1efa6e5d3bc636e9c974acba..7a20df7f6c1e2eeee19f55cefa3320d34a62a701 100644 --- a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o +norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml index 6c63af6ae8d62e098dec35d5918291291f654e32..a1497ba8fa4790a53cd602829edf3240ff8dc51a 100644 --- a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml +++ b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml @@ -6,7 +6,7 @@ Global: print_batch_step: 10 save_model_dir: ./output/rec_pp-OCRv2_distillation save_epoch_step: 3 - eval_batch_step: [0, 2000] + eval_batch_step: [0, 200000] cal_metric_during_train: true pretrained_model: checkpoints: @@ -114,7 +114,7 @@ Train: name: SimpleDataSet data_dir: ./train_data/ic15_data/ label_file_list: - - ./train_data/ic15_data/rec_gt_train.txt + - ./train_data/ic15_data/rec_gt_train4w.txt transforms: - DecodeImage: img_mode: BGR diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt index df42b342ba5fa3947a69c2bde5548975ca92d857..a96b87dede1e1b4c7b3ed59c4bd9c0470402e7e2 100644 --- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o +norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o Global.print_batch_step=4 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,32,320]}] +===========================train_benchmark_params========================== +batch_size:64 +fp_items:fp32|fp16 +epoch:1 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt index 
a69e0ab81ec6963228e9ab2e39c5bb1d730b6323..bf10aebe3e9aa67e30ce7a20cb07f376825e39ae 100644 --- a/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o +norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false Global.eval_batch_step=[4000,400] pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml index f704a1dfb5dc2335c353a495dfbc0ce42cf35bf4..ee884f668767ea1c96782072c729bbcc700674d1 100644 --- a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml +++ b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml @@ -153,7 +153,7 @@ Train: data_dir: ./train_data/ic15_data/ ext_op_transform_idx: 1 label_file_list: - - ./train_data/ic15_data/rec_gt_train_lite.txt + - ./train_data/ic15_data/rec_gt_train4w.txt transforms: - DecodeImage: img_mode: BGR @@ -183,7 +183,7 @@ Eval: name: SimpleDataSet data_dir: ./train_data/ic15_data label_file_list: - - ./train_data/ic15_data/rec_gt_test_lite.txt + - ./train_data/ic15_data/rec_gt_test.txt transforms: - DecodeImage: img_mode: BGR diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt index 1feb9d49fce69d92ef141c3a942f858fc68cfaab..420c6592d71653377c740c703bedeb8e048cfc03 100644 --- a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o +norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,48,320]}] +===========================train_benchmark_params========================== +batch_size:128 +fp_items:fp32|fp16 +epoch:1 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt index 789ed4d23d9c1fa3997daceee0627218aecd4c73..3db816cc0887eb2efc195965498174e868bfc6ec 100644 --- 
a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained +norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -50,4 +50,10 @@ null:null --benchmark:True null:null ===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] \ No newline at end of file +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt index f02b93926cac8116844142fe3ecb03959abb0530..36fdb1b91eceede0d692ff4c2680d1403ec86024 100644 --- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o +norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,32,100]}] +===========================train_benchmark_params========================== +batch_size:256 +fp_items:fp32|fp16 +epoch:3 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml b/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml index f512de808141a0a0e815f9477de80b893ae3c946..6728703a10c2eb4e19b3bbf2225f089324e7d5cd 100644 --- a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml +++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml @@ -2,13 +2,13 @@ Global: use_gpu: false epoch_num: 5 log_smooth_window: 20 - print_batch_step: 1 + print_batch_step: 2 save_model_dir: ./output/db_mv3/ save_epoch_step: 1200 # evaluation is run every 2000 iterations - eval_batch_step: [0, 400] + eval_batch_step: [0, 30000] cal_metric_during_train: False - pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + pretrained_model: checkpoints: save_inference_dir: use_visualdl: False diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt index 
c16ca150029d03052396ca28a6396520e63b3f84..7b90a4078a0c30f9d5ecab60c82acbd4052821ea 100644 --- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o +norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o quant_train:null fpgm_train:null distill_train:null @@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py --benchmark:True null:null ===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] \ No newline at end of file +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt index 64c0cf455cdd058d1840a9ad1f86954293d2e219..9fc117d67c6c2c048b2c8797bc07be8c93b0d519 100644 --- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o +norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,32,100]}] +===========================train_benchmark_params========================== +batch_size:256 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt index 2c8aa953449c4b97790842bb90256280b8b20d9a..62303b7e511bc42444e3610fb17eaf74ccf0848a 100644 --- a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt +++ b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained +norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -52,8 +52,8 @@ null:null ===========================infer_benchmark_params========================== 
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ===========================train_benchmark_params========================== -batch_size:8|16 +batch_size:16 fp_items:fp32|fp16 -epoch:15 +epoch:4 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile -flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt index 151f2769cc2d97d6a3546f338383dd811aa06ace..11af0ad18e948d9fa1f325745988877125583658 100644 --- a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt +++ b/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o +norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o Global.print_batch_step=1 Train.loader.shuffle=false quant_export:null fpgm_export:null distill_train:null @@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py --benchmark:True null:null ===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] \ No newline at end of file +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml new file mode 100644 index 0000000000000000000000000000000000000000..3a513b8f38cd5abf800c86f8fbeda789cb3d056a --- /dev/null +++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml @@ -0,0 +1,139 @@ +Global: + use_gpu: true + epoch_num: 1500 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/det_r50_dcn_fce_ctw/ + save_epoch_step: 100 + # evaluation is run every 4000 iterations + eval_batch_step: [0, 4000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_fce/predicts_fce.txt + + +Architecture: + model_type: det + algorithm: FCE + Transform: + Backbone: + name: ResNet_vd + layers: 50 + dcn_stage: [False, True, True, True] + out_indices: [1,2,3] + Neck: + name: FCEFPN + out_channels: 256 + has_extra_convs: False + extra_stage: 0 + Head: + name: FCEHead + fourier_degree: 5 +Loss: + name: FCELoss + fourier_degree: 5 + num_sample: 50 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0001 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: FCEPostProcess + scales: [8, 16, 32] + alpha: 1.0 + beta: 1.0 + fourier_degree: 5 + box_type: 'poly' + +Metric: + name: DetFCEMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir:
./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + ignore_orientation: True + - DetLabelEncode: # Class handling label + - ColorJitter: + brightness: 0.142 + saturation: 0.5 + contrast: 0.5 + - RandomScaling: + - RandomCropFlip: + crop_ratio: 0.5 + - RandomCropPolyInstances: + crop_ratio: 0.8 + min_side_ratio: 0.3 + - RandomRotatePolyInstances: + rotate_ratio: 0.5 + max_angle: 30 + pad_with_fixed_color: False + - SquareResizePad: + target_size: 800 + pad_ratio: 0.6 + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - FCENetTargets: + fourier_degree: 5 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'p3_maps', 'p4_maps', 'p5_maps'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 6 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + ignore_orientation: True + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_type: 'min' + limit_side_len: 736 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - Pad: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d294fd3038f5506a28d637dbe1aba44b5da237b --- /dev/null +++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt @@ -0,0 +1,59 @@ +===========================train_params=========================== +model_name:det_r50_dcn_fce_ctw_v2.0 +python:python3.7 +gpu_list:0 +Global.use_gpu:True|True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml 
-o +infer_quant:False +inference:tools/infer/predict_det.py +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +--save_log_path:null +--benchmark:True +--det_algorithm:FCE +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:6 +fp_items:fp32|fp16 +epoch:1 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml b/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml index c6b6fc3ed79d0d717fe3dbd4cb9c8559ff8f07c4..844f42e9ad2b4eafbbd829de5711132f8119671f 100644 --- a/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml +++ b/test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml @@ -20,7 +20,7 @@ Architecture: algorithm: EAST Transform: Backbone: - name: ResNet + name: ResNet_vd layers: 50 Neck: name: EASTFPN diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt index 24e4d760c37828c213741b9ff127d55df2f9335a..5ee445a6cd03dfb888d1bc73eb51481a014cbb36 100644 --- a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt +++ b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o +norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o Global.pretrained_model=pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy.pdparams Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -55,4 +55,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] batch_size:8 fp_items:fp32|fp16 epoch:2 ---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile \ No newline at end of file +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml b/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml index f7e60fd1968820ef093455473346a6b8f0f8d34e..c069d1f132db5fc512011ead0107cf4082c144be 100644 --- a/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml +++ b/test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml @@ -8,7 +8,7 @@ Global: # evaluation is run every 125 iterations eval_batch_step: [ 0,1000 ] cal_metric_during_train: False - pretrained_model: + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy save_inference_dir: use_visualdl: False @@ -20,7 +20,7 @@ Architecture: algorithm: PSE Transform: Backbone: - name: ResNet + name: ResNet_vd layers: 50 Neck: name: FPN diff --git 
a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt index 53511e6ae21003cb9df6a92d3931577fbbef5b18..78d25f6b17e30d6b7b12ae3acc1b264febfa97da 100644 --- a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt +++ b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o +norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -54,5 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ===========================train_benchmark_params========================== batch_size:8 fp_items:fp32|fp16 -epoch:10 +epoch:2 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml index 281038b968a5bf829483882117d779ec7de1976d..5d8e84c95c477a639130a342c6c72345e97701da 100755 --- a/test_tipc/configs/en_table_structure/table_mv3.yml +++ b/test_tipc/configs/en_table_structure/table_mv3.yml @@ -6,7 +6,7 @@ Global: save_model_dir: ./output/table_mv3/ save_epoch_step: 3 # evaluation is run every 400 iterations after the 0th iteration - eval_batch_step: [0, 400] + eval_batch_step: [0, 40000] cal_metric_during_train: True pretrained_model: checkpoints: diff --git a/test_tipc/configs/en_table_structure/train_infer_python.txt b/test_tipc/configs/en_table_structure/train_infer_python.txt index d9f3b30e16c75281a929130d877b947a23c16190..633b6185d976ac61408283025bd4ba305187317d 100644 --- a/test_tipc/configs/en_table_structure/train_infer_python.txt +++ b/test_tipc/configs/en_table_structure/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./ppstructure/docs/table/table.jpg null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o +norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o Global.print_batch_step=1 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -27,7 +27,7 @@ null:null ===========================infer_params=========================== Global.save_inference_dir:./output/ Global.checkpoints: -norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o +norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o quant_export: fpgm_export: distill_export:null @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,488,488]}] +===========================train_benchmark_params========================== +batch_size:32 +fp_items:fp32|fp16 +epoch:1 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml 
b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml index 463e8d1d7d7f43ba7b48810e2a2e8552eb5e4fe3..b0ba615293153cc3bbeb5b47d053596306ee45e2 100644 --- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml +++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml @@ -6,7 +6,7 @@ Global: save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_epoch_step: 3 # evaluation is run every 2000 iterations - eval_batch_step: [0, 2000] + eval_batch_step: [0, 20000] cal_metric_during_train: True pretrained_model: checkpoints: diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt index 39bf9227902480ffe4ed37d454c21d6a163c41bd..4e34a6a525fb8104407d04c617db39934b84e140 100644 --- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt +++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o +norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -50,4 +50,10 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --benchmark:True null:null ===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,32,100]}] \ No newline at end of file +random_infer_input:[{float32,[3,32,100]}] +===========================train_benchmark_params========================== +batch_size:256 +fp_items:fp32|fp16 +epoch:4 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml new file mode 100644 index 0000000000000000000000000000000000000000..d0cb20481f56a093f96c3d13f5fa2c2d13ae0c69 --- /dev/null +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 50 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ./ppocr/utils/dict/spin_dict.txt + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4, 5] + values: [0.001, 0.0003, 0.00009, 0.000027] + + clip_norm: 5 + +Architecture: + model_type: rec + algorithm: SPIN + in_channels: 1 + Transform: + name: GA_SPIN + offsets: True + default_type: 6 + loc_lr: 0.1 + stn: True + Backbone: + name: ResNet32 + out_channels: 512 + Neck: + name: SequenceEncoder + encoder_type: cascadernn + hidden_size: 256 + out_channels: [256, 512] + with_linear:
True + Head: + name: SPINAttentionHead + hidden_size: 256 + + +Loss: + name: SPINAttentionLoss + ignore_index: 0 + +PostProcess: + name: SPINLabelDecode + use_space_char: False + + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 1 diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..4915055a576f0a5c1f7b0935a31d1d3c266903a5 --- /dev/null +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:rec_r32_gaspin_bilstm_att +python:python +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN" +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1|6 +--use_tensorrt:False|False +--precision:fp32|int8 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True 
+null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,32,100]}] diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md index ad2524c165da3079d24b2b1570a5111d152f8373..a7f95eb6c530e1c451bb400cdb193694e2aee5f6 100644 --- a/test_tipc/docs/benchmark_train.md +++ b/test_tipc/docs/benchmark_train.md @@ -51,3 +51,25 @@ train_log/ ├── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C1_log └── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C4_log ``` +## 3. Single-GPU training performance of each model + +*Note: all speed metrics in this section were measured on a single GPU (one Nvidia V100 16G).* + + +|Model|Config|Large-dataset float32 fps|Small-dataset float32 fps|diff|Large-dataset float16 fps|Small-dataset float16 fps|diff|Large-dataset size|Small-dataset size| +|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +| ch_ppocr_mobile_v2.0_det |[config](../configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt) | 53.836 | 53.343 / 53.914 / 52.785 |0.020940758 | 45.574 | 45.57 / 46.292 / 46.213 | 0.015596647 | 10,000| 2,000| +| ch_ppocr_mobile_v2.0_rec |[config](../configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt) | 2083.311 | 2043.194 / 2066.372 / 2093.317 |0.023944295 | 2153.261 | 2167.561 / 2165.726 / 2155.614| 0.005511725 | 600,000| 160,000| +| ch_ppocr_server_v2.0_det |[config](../configs/ch_ppocr_server_v2.0_det/train_infer_python.txt) | 20.716 | 20.739 / 20.807 / 20.755 |0.003268131 | 20.592 | 20.498 / 20.993 / 20.75| 0.023579288 | 10,000| 2,000| +| ch_ppocr_server_v2.0_rec |[config](../configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt) | 528.56 | 528.386 / 528.991 / 528.391 |0.001143687 | 1189.788 | 1190.007 / 1176.332 / 1192.084| 0.013213834 | 600,000| 160,000| +| ch_PP-OCRv2_det |[config](../configs/ch_PP-OCRv2_det/train_infer_python.txt) | 13.87 | 13.386 / 13.529 / 13.428 |0.010569887 | 17.847 | 17.746 / 17.908 / 17.96| 0.011915367 | 10,000| 2,000| +| ch_PP-OCRv2_rec |[config](../configs/ch_PP-OCRv2_rec/train_infer_python.txt) | 109.248 | 106.32 / 106.318 / 108.587 |0.020895687 | 117.491 | 117.62 / 117.757 / 117.726| 0.001163413 | 140,000| 40,000| +| det_mv3_db_v2.0 |[config](../configs/det_mv3_db_v2_0/train_infer_python.txt) | 61.802 | 62.078 / 61.802 / 62.008 |0.00444602 | 82.947 | 84.294 / 84.457 / 84.005| 0.005351836 | 10,000| 2,000| +| det_r50_vd_db_v2.0 |[config](../configs/det_r50_vd_db_v2.0/train_infer_python.txt) | 29.955 | 29.092 / 29.31 / 28.844 |0.015899011 | 51.097 |50.367 / 50.879 / 50.227| 0.012814717 | 10,000| 2,000| +| det_r50_vd_east_v2.0 |[config](../configs/det_r50_vd_east_v2.0/train_infer_python.txt) | 42.485 | 42.624 / 42.663 / 42.561 |0.00239083 | 67.61 |67.825 / 68.299 / 68.51| 0.00999854 | 10,000| 2,000| +| det_r50_vd_pse_v2.0 |[config](../configs/det_r50_vd_pse_v2.0/train_infer_python.txt) | 16.455 | 16.517 / 16.555 / 16.353 |0.012201752 | 27.02 |27.288 / 27.152 / 27.408| 0.009340339 | 10,000| 2,000| +| rec_mv3_none_bilstm_ctc_v2.0 |[config](../configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt) | 2288.358 | 2291.906 / 2293.725 / 2290.05 |0.001602197 | 2336.17 |2327.042 / 2328.093 / 2344.915| 0.007622025 | 600,000| 160,000| +| PP-Structure-table |[config](../configs/en_table_structure/train_infer_python.txt) | 14.151 | 14.077 / 14.23 / 14.25 |0.012140351 | 16.285 | 16.595 / 16.878 / 16.531 | 0.020559308 | 20,000| 5,000| +| det_r50_dcn_fce_ctw_v2.0 |[config](../configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt) | 14.057 | 14.029 / 14.02 / 14.014 |0.001069214 | 18.298 |18.411 / 18.376 / 18.331| 0.004345228 | 10,000| 2,000| +| ch_PP-OCRv3_det
|[config](../configs/ch_PP-OCRv3_det/train_infer_python.txt) | 8.622 | 8.431 / 8.423 / 8.479|0.006604552 | 14.203 |14.346 / 14.468 / 14.23| 0.016450097 | 10,000| 2,000| +| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 73.627 | 72.46 / 73.575 / 73.704|0.016878324 | | | | 160,000| 40,000| \ No newline at end of file diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index ec6dece42a0126e6d05405b3262c1c1d24f0a376..cb3fa2440d9672ba113904bd1548d458491d1d8c 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -22,27 +22,79 @@ trainer_list=$(func_parser_value "${lines[14]}") if [ ${MODE} = "benchmark_train" ];then pip install -r requirements.txt - if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then - rm -rf ./train_data/icdar2015 + if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate - wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate - cd ./train_data/ && tar xf icdar2015.tar && cd ../ + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ + fi + if [[ ${model_name} =~ "ch_ppocr_server_v2.0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ + fi + if [[ ${model_name} =~ "ch_PP-OCRv2_det" ]];then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../ + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ fi if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../ - wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate - cd ./train_data/ && tar xf icdar2015.tar && cd ../ + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ fi - if [[ ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then + if [[ ${model_name} =~ "det_r50_db_v2.0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate - wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate - cd ./train_data/ && tar xf icdar2015.tar && cd ../ + rm -rf
./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ fi if [[ ${model_name} =~ "det_r18_db_v2_0" ]];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate - wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate - cd ./train_data/ && tar xf icdar2015.tar && cd ../ + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ + fi + if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_rec" || ${model_name} =~ "ch_ppocr_server_v2.0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2.0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then + rm -rf ./train_data/ic15_data_benckmark + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf ic15_data_benckmark.tar + ln -s ./ic15_data_benckmark ./ic15_data + cd ../ + fi + if [[ ${model_name} == "en_table_structure" ]];then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../ + rm -rf ./train_data/pubtabnet + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf pubtabnet_benckmark.tar + ln -s ./pubtabnet_benckmark ./pubtabnet + cd ../ + fi + if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]]; then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../ + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf icdar2015_benckmark.tar + ln -s ./icdar2015_benckmark ./icdar2015 + cd ../ fi fi @@ -137,6 +189,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../ fi + if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../ + fi elif [ ${MODE} = "whole_train_whole_infer" ];then wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate @@ -363,6 +419,10 @@ elif [ ${MODE} = "whole_infer" ];then wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../ fi + if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then + wget -nc -P ./inference/
https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate + cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../ + fi if [[ ${model_name} =~ "en_table_structure" ]];then wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate diff --git a/tools/export_model.py b/tools/export_model.py index afecbff8cbb834a5aa5ef3ea1448cf04fbd8c3bb..69ac904c661fad77255c70563fdf1f16c5c29875 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR": + elif arch_config["algorithm"] in ["NRTR", "SPIN"]: other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index a95f55596647acc0eaca9616b5630917d7ebdf3a..e6ba4d04ec32e0971e3c44f15b800c2cd2bc6c51 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -81,6 +81,12 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "SPIN": + postprocess_params = { + 'name': 'SPINLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -258,6 +264,22 @@ class TextRecognizer(object): return padding_im, resize_shape, pad_shape, valid_ratio + def resize_norm_img_spin(self, img): + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # resize to a fixed 100x32, then normalize to roughly [-1, 1] + img = cv2.resize(img, (100, 32), interpolation=cv2.INTER_CUBIC) + img = np.array(img, np.float32) + img = np.expand_dims(img, -1) + img = img.transpose((2, 0, 1)) + mean = [127.5] + std = [127.5] + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + mean = np.float32(mean.reshape(1, -1)) + stdinv = 1 / np.float32(std.reshape(1, -1)) + img -= mean + img *= stdinv + return img def resize_norm_img_svtr(self, img, image_shape): imgC, imgH, imgW = image_shape @@ -337,6 +359,10 @@ class TextRecognizer(object): self.rec_image_shape) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == 'SPIN': + norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) elif self.rec_algorithm == "ABINet": norm_img = self.resize_norm_img_abinet( img_list[indices[ino]], self.rec_image_shape) diff --git a/tools/program.py b/tools/program.py index 1d83b46216ad62d59e7123c1b2d590d2a1aae5ac..0fa0e609bd14d07cc593786b3a3f760cb9b98500 100755 --- a/tools/program.py +++ b/tools/program.py @@ -154,6 +154,24 @@ def check_xpu(use_xpu): except Exception as e: pass +def to_float32(preds): + if isinstance(preds, dict): + for k in preds: + if isinstance(preds[k], dict) or isinstance(preds[k], list): + preds[k] = to_float32(preds[k]) + else: + preds[k] = preds[k].astype(paddle.float32) + elif isinstance(preds, list): + for k in range(len(preds)): + if isinstance(preds[k], dict): + preds[k] = to_float32(preds[k]) + elif isinstance(preds[k], list): +
preds[k] = to_float32(preds[k]) + else: + preds[k] = preds[k].astype(paddle.float32) + else: + preds = preds.astype(paddle.float32) + return preds def train(config, train_dataloader, @@ -207,7 +225,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: @@ -252,13 +270,19 @@ def train(config, # use amp if scaler: - with paddle.amp.auto_cast(): + with paddle.amp.auto_cast(level='O2'): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: preds = model(batch) else: preds = model(images) + preds = to_float32(preds) + loss = loss_class(preds, batch) + avg_loss = loss['loss'] + scaled_avg_loss = scaler.scale(avg_loss) + scaled_avg_loss.backward() + scaler.minimize(optimizer, scaled_avg_loss) else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) @@ -266,15 +290,8 @@ def train(config, preds = model(batch) else: preds = model(images) - - loss = loss_class(preds, batch) - avg_loss = loss['loss'] - - if scaler: - scaled_avg_loss = scaler.scale(avg_loss) - scaled_avg_loss.backward() - scaler.minimize(optimizer, scaled_avg_loss) - else: + loss = loss_class(preds, batch) + avg_loss = loss['loss'] avg_loss.backward() optimizer.step() optimizer.clear_grad() @@ -579,7 +596,7 @@ def preprocess(is_train=False): 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', - 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster' + 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN' ] if use_xpu: diff --git a/tools/train.py b/tools/train.py index b7c25e34231fb650fd2c7c89dc17320f561962f9..309d4bb9e6b0fbcc9dd93545877662d746ada086 100755 --- a/tools/train.py +++ b/tools/train.py @@ -157,6 +157,7 @@ def main(config, device, logger, vdl_writer): scaler = paddle.amp.GradScaler( init_loss_scaling=scale_loss, use_dynamic_loss_scaling=use_dynamic_loss_scaling) + model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True) else: scaler = None
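A few supplementary notes on the changes above, each with a small runnable sketch. First, the decoder added in ppocr/postprocess/rec_postprocess.py: SPINLabelDecode prepends the "sos" and "eos" markers to the character list, so they land at indices 0 and 1 of the dictionary, which lines up with SPINAttentionLoss using ignore_index: 0 in rec_r32_gaspin_bilstm_att.yml. A minimal sketch of that index layout (the three-character dictionary is a stand-in, not the real spin_dict.txt):

# Sketch of SPINLabelDecode.add_special_char's dictionary ordering.
beg_str, end_str = "sos", "eos"
dict_character = ["0", "1", "a"]  # stand-in for the spin_dict.txt contents
dict_character = [beg_str] + [end_str] + dict_character
assert dict_character.index("sos") == 0  # matches ignore_index: 0 in the loss config
assert dict_character.index("eos") == 1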
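Because SPINLabelDecode is registered in the support_dict of ppocr/postprocess/__init__.py, it is built through the same factory as every other decoder; a sketch of constructing it directly (the config keys mirror the PostProcess section of the new yml):

from ppocr.postprocess import build_post_process

postprocess_params = {
    'name': 'SPINLabelDecode',
    'character_dict_path': './ppocr/utils/dict/spin_dict.txt',
    'use_space_char': False,
}
spin_decoder = build_post_process(postprocess_params)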
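Both SPINRecResizeImg in the configs and resize_norm_img_spin in tools/infer/predict_rec.py normalize the grayscale input with mean = std = 127.5, which maps the [0, 255] pixel range onto roughly [-1, 1]; a quick sanity check:

import numpy as np

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
print((pixels - 127.5) / 127.5)  # [-1.  0.  1.]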
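Finally, the AMP changes in tools/program.py and tools/train.py follow Paddle's standard O2 recipe: decorate the model and optimizer so float32 master weights are kept, run the forward pass under auto_cast, cast the (possibly float16) predictions back to float32 before the loss via to_float32, then scale and minimize. A condensed sketch of that flow using a stand-in model and loss, not the repo's classes:

import paddle

model = paddle.nn.Linear(10, 10)  # stand-in model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', master_weight=True)

x = paddle.randn([4, 10])
label = paddle.randn([4, 10])
with paddle.amp.auto_cast(level='O2'):
    preds = model(x)  # runs in float16 where safe under O2
preds = preds.astype(paddle.float32)  # mirrors to_float32() in tools/program.py
loss = paddle.nn.functional.mse_loss(preds, label)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscales gradients and applies the update
optimizer.clear_grad()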