diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 4d9c52740a5ca5bcdd891bb55ff769f23e7a2499..8babac6558fa7c3b15779298630af03e47c97cb1 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -1031,7 +1031,7 @@ class MainWindow(QMainWindow, WindowMixin):
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
- if trans_dic["label"] is "" and mode == 'Auto':
+ if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)
@@ -1450,7 +1450,7 @@ class MainWindow(QMainWindow, WindowMixin):
item = QListWidgetItem(closeicon, filename)
self.fileListWidget.addItem(item)
- print('dirPath in importDirImages is', dirpath)
+ print('DirPath in importDirImages is', dirpath)
self.iconlist.clear()
self.additems5(dirpath)
self.changeFileFolder = True
@@ -1459,7 +1459,6 @@ class MainWindow(QMainWindow, WindowMixin):
self.reRecogButton.setEnabled(True)
self.actions.AutoRec.setEnabled(True)
self.actions.reRec.setEnabled(True)
- self.actions.saveLabel.setEnabled(True)
def openPrevImg(self, _value=False):
@@ -1764,7 +1763,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result)
@@ -1795,7 +1794,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
if result[1][0] == shape.label:
@@ -1862,6 +1861,8 @@ class MainWindow(QMainWindow, WindowMixin):
for each in states:
file, state = each.split('\t')
self.fileStatedict[file] = 1
+ self.actions.saveLabel.setEnabled(True)
+ self.actions.saveRec.setEnabled(True)
def saveFilestate(self):
@@ -1919,22 +1920,29 @@ class MainWindow(QMainWindow, WindowMixin):
rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
+ ques_img = []
if not os.path.exists(crop_img_dir):
os.mkdir(crop_img_dir)
with open(rec_gt_dir, 'w', encoding='utf-8') as f:
for key in self.fileStatedict:
idx = self.getImglabelidx(key)
- for i, label in enumerate(self.PPlabel[idx]):
- if label['difficult']: continue
+ try:
img = cv2.imread(key)
- img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
- img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
- cv2.imwrite(crop_img_dir+img_name, img_crop)
- f.write('crop_img/'+ img_name + '\t')
- f.write(label['transcription'] + '\n')
-
- QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir))
+ for i, label in enumerate(self.PPlabel[idx]):
+ if label['difficult']: continue
+ img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
+ img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
+ cv2.imwrite(crop_img_dir+img_name, img_crop)
+ f.write('crop_img/'+ img_name + '\t')
+ f.write(label['transcription'] + '\n')
+ except Exception as e:
+ ques_img.append(key)
+ print("Can not read image ",e)
+ if ques_img:
+ QMessageBox.information(self, "Information", "The following images can not be saved, "
+ "please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img))
+ QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir))
def speedChoose(self):
if self.labelDialogOption.isChecked():
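Reviewer note on the reworked crop export above: each non-difficult box now yields one cropped image plus one tab-separated line in rec_gt.txt, and unreadable images are collected into ques_img instead of aborting the whole export. A minimal sketch of the resulting line format (file name and transcription are hypothetical):

```python
# Sketch of one rec_gt.txt entry produced by the export loop above (illustrative values only).
label = {"transcription": "hello", "points": [[10, 10], [110, 10], [110, 40], [10, 40]], "difficult": False}
img_name = "demo_crop_0.jpg"                       # os.path.splitext(basename)[0] + '_crop_' + str(i) + '.jpg'
line = "crop_img/" + img_name + "\t" + label["transcription"] + "\n"
print(repr(line))                                  # 'crop_img/demo_crop_0.jpg\thello\n'
```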
@@ -1991,7 +1999,7 @@ if __name__ == '__main__':
resource_file = './libs/resources.py'
if not os.path.exists(resource_file):
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
- assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
+ assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
"directory resources.py "
import libs.resources
sys.exit(main())
diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml
index 38f1e8691e6056ada01a2d5c19f70955e8117498..00c1db885e000d80ed3c3f42c2afbaa11c452ab5 100644
--- a/configs/rec/rec_mv3_none_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml
@@ -1,5 +1,5 @@
Global:
- use_gpu: true
+ use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
@@ -59,7 +59,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -78,7 +78,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml
index 33079ad48c94c217ef86ef3f245492a540559350..6711b1d23f843551d72e1dffc003637734727754 100644
--- a/configs/rec/rec_mv3_none_none_ctc.yml
+++ b/configs/rec/rec_mv3_none_none_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
index 08f68939d4f1e6de1c3688652bd86f6556a43384..1b9fb0a08db7cfd68bf2deb35d9216d68b58a12e 100644
--- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
@@ -63,7 +63,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
index 4ad2ff89ef1e72c58c426670742bc2ada27cfc4a..e4d301a6a173ea772898c0528c4b3082670870ff 100644
--- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml
index 9c1eeb304f41d46e49cee350e5d659dd1e0c8b0e..4a17a004228185db7e52dd71aadcff36d407d2cf 100644
--- a/configs/rec/rec_r34_vd_none_none_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_none_ctc.yml
@@ -56,7 +56,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -75,7 +75,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
index aeded4926a6d09cf30210f2d348d2933461a06b1..62edf84379ec1be9ef5f7155b240099f5fbb7b00 100644
--- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
@@ -62,7 +62,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -81,7 +81,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ec7f170560f5309818d537953a93c180b9de0bb7
--- /dev/null
+++ b/configs/rec/rec_r50_fpn_srn.yml
@@ -0,0 +1,107 @@
+Global:
+ use_gpu: True
+ epoch_num: 72
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/rec/srn_new
+ save_epoch_step: 3
+ # evaluation is run every 5000 iterations
+ eval_batch_step: [0, 5000]
+ # if pretrained_model is saved in static mode, load_static_weights must be set to True
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ num_heads: 8
+ infer_mode: False
+ use_space_char: False
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 10.0
+ lr:
+ learning_rate: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: SRN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ResNetFPN
+ Head:
+ name: SRNHead
+ max_text_length: 25
+ num_heads: 8
+ num_encoder_TUs: 2
+ num_decoder_TUs: 4
+ hidden_dims: 512
+
+Loss:
+ name: SRNLoss
+
+PostProcess:
+ name: SRNLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/srn_train_data_duiqi
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ batch_size_per_card: 64
+ drop_last: False
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 32
+ num_workers: 4
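For orientation, this new config wires the SRN-specific pieces together: the ResNetFPN backbone, SRNHead, SRNLoss, the SRNLabelEncode/SRNRecResizeImg transforms, and a KeepKeys list with the seven tensors the head and loss expect. A small sketch of inspecting and tweaking it before training (a sketch, not the official entry point; assumes PyYAML is available, field names are exactly those defined above):

```python
# Load the new SRN config and check/override a few fields.
import yaml

with open("configs/rec/rec_r50_fpn_srn.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Architecture"]["algorithm"])            # 'SRN'
print(cfg["Train"]["dataset"]["transforms"][-1])   # KeepKeys with the 7 SRN data fields
cfg["Train"]["loader"]["batch_size_per_card"] = 32 # e.g. shrink the batch for a single GPU
```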
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 59d1bc8c444e3a70bbea83f87afcbd2f5cf44191..abbc5da4c21cf89466a5faef6cf6cb0c1eb14d13 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -41,7 +41,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
@@ -53,5 +53,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
+|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index c4601e1526d29e0a8c62030a4b47d2b2cc193d5d..0daddd9bb02d41c139f1f16b1fcd81c03f43f6ac 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -22,8 +22,9 @@ inference 模型(`paddle.jit.save`保存的模型)
- [三、文本识别模型推理](#文本识别模型推理)
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
- - [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- - [4. 多语言模型的推理](#多语言模型的推理)
+ - [3. 基于SRN损失的识别模型推理](#基于SRN损失的识别模型推理)
+ - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
+ - [5. 多语言模型的推理](#多语言模型的推理)
- [四、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
@@ -295,8 +296,20 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
+
+### 3. 基于SRN损失的识别模型推理
+基于SRN损失的识别模型,需要额外设置识别算法参数 --rec_algorithm="SRN"。
+同时需要保证预测shape与训练时一致,如: --rec_image_shape="1, 64, 256"
-### 3. 自定义文本识别字典的推理
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
+ --rec_model_dir="./inference/srn/" \
+ --rec_image_shape="1, 64, 256" \
+ --rec_char_type="en" \
+ --rec_algorithm="SRN"
+```
+
+### 4. 自定义文本识别字典的推理
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
```
@@ -304,7 +317,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 4. 多语言模型的推理
+### 5. 多语言模型的推理
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index c5f459bdb88558b1cdea93b9b85eed0e4bb8433b..bc877ab78c583f04dd0bf740712457094325d00e 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -36,6 +36,7 @@ ln -sf /train_data/dataset
* 数据下载
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
+如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
* 使用自己数据集
@@ -200,6 +201,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
+| rec_r50_fpn_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 68bfd529972183208220b1c87227639d683fea62..7d7896e7144c9d2db28b29a6f16b23677925d67a 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -43,7 +43,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -55,5 +55,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
+|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index ccbb71847d5946e854b88817a162957af0e6ed00..c8ce1424f5451ca9ee22b9b49ac9b702be72826f 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -25,6 +25,7 @@ Next, we first introduce how to convert a trained model into an inference model,
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
+ - [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION)
- - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
- - [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
+ - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+ - [5. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
@@ -304,8 +305,23 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
+
+### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE
+
+For SRN-based recognition models, the recognition algorithm parameter --rec_algorithm="SRN" must be set
+explicitly. The prediction image shape must also match the shape used during training, for example:
+--rec_image_shape="1, 64, 256"
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
+ --rec_model_dir="./inference/srn/" \
+ --rec_image_shape="1, 64, 256" \
+ --rec_char_type="en" \
+ --rec_algorithm="SRN"
+```
+
-### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
```
@@ -313,7 +329,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 4. MULTILINGAUL MODEL INFERENCE
+### 5. MULTILINGUAL MODEL INFERENCE
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 22f89cdef080afe0b119d08d1e88f02ede5932c1..f29703d14454ce979ad4f7d8cda0d2768721b53d 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -195,6 +195,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
+| rec_r50_fpn_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 4809886b7bf820f47a4f234a345ea1fac8fa5d49..7cb50d7a62aa3f24811e517768e0635ac7b7321a 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
-from ppocr.data.lmdb_dataset import LMDBDateSet
+from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDateSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 6ea4dd8ed6d0f58fbee3362e8eb82a0eda65e812..250ac75e7683df2353d9fad02ef42b9e133681d3 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 6d9ea1902fcbe1c902a1136861d6bdaa8954ba23..191bda92ca4efe8866af4c4d6e76e4fb5ffac38d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -102,6 +102,8 @@ class BaseRecLabelEncode(object):
support_character_type, character_type)
self.max_text_len = max_text_length
+ self.beg_str = "sos"
+ self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -231,3 +233,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length=25,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ character_type, use_space_char)
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ char_num = len(self.character_str)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [char_num] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
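A quick illustration of what SRNLabelEncode.__call__ produces: the encoded text is right-padded to max_text_len with the index char_num = len(self.character_str) (36 for the default English dictionary), so every label has a fixed length of 25. Values below are illustrative, not a real encoding:

```python
# Sketch of the fixed-length SRN label built by SRNLabelEncode above.
import numpy as np

max_text_len, char_num = 25, 36                          # defaults for character_type='en'
encoded = [12, 24, 29]                                   # hypothetical indices of a 3-char word
label = encoded + [char_num] * (max_text_len - len(encoded))
data = {"length": np.array(len(encoded)), "label": np.array(label)}
assert data["label"].shape == (max_text_len,)            # consumed as batch[1] by SRNLoss
```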
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2ccb2d1d2b6780138098f08c78cce3be3e3b9ceb..28e6bd0bce768c45dbc334c15ace601fd6403f5d 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
import math
import cv2
import numpy as np
@@ -77,6 +63,26 @@ class RecResizeImg(object):
return data
+class SRNRecResizeImg(object):
+ def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
+ self.image_shape = image_shape
+ self.num_heads = num_heads
+ self.max_text_length = max_text_length
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img = resize_norm_img_srn(img, self.image_shape)
+ data['image'] = norm_img
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+ data['encoder_word_pos'] = encoder_word_pos
+ data['gsrm_word_pos'] = gsrm_word_pos
+ data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
+ data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+ return data
+
+
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
- max_wh_ratio = 0
+ max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im
+def resize_norm_img_srn(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+
+def srn_other_inputs(image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+ [num_heads, 1, 1]) * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+ [num_heads, 1, 1]) * [-1e9]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+
def flag():
"""
flag
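For reference, the shapes of the four auxiliary inputs built by srn_other_inputs, with the defaults used in rec_r50_fpn_srn.yml (image_shape=[1, 64, 256], num_heads=8, max_text_length=25), work out as below; this is a standalone check mirroring the function above, not part of the patch:

```python
# Shape check for the auxiliary SRN inputs (mirrors srn_other_inputs above).
import numpy as np

imgC, imgH, imgW = 1, 64, 256
num_heads, max_text_length = 8, 25
feature_dim = int((imgH / 8) * (imgW / 8))                       # 8 * 32 = 256 visual positions

encoder_word_pos = np.arange(feature_dim).reshape((feature_dim, 1)).astype('int64')
gsrm_word_pos = np.arange(max_text_length).reshape((max_text_length, 1)).astype('int64')
bias = np.ones((1, max_text_length, max_text_length))
fwd_mask = np.tile(np.triu(bias, 1), [num_heads, 1, 1]) * -1e9   # masks future positions
bwd_mask = np.tile(np.tril(bias, -1), [num_heads, 1, 1]) * -1e9  # masks past positions

print(encoder_word_pos.shape, gsrm_word_pos.shape)               # (256, 1) (25, 1)
print(fwd_mask.shape, bwd_mask.shape)                            # (8, 25, 25) (8, 25, 25)
```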
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index bd0630f6351d4e9e860f21b18f6503777a4d8679..e2d6dc9327bf3725d2fb6c32d18c0b71bd6ac408 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -20,9 +20,9 @@ import cv2
from .imaug import transform, create_operators
-class LMDBDateSet(Dataset):
+class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
- super(LMDBDateSet, self).__init__()
+ super(LMDBDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 94314235cf00353e687bb9bf27c1e26f3efd7c33..3881abf7741b8be78306bd070afb11df15606327 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -24,12 +24,14 @@ def build_loss(config):
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
+ from .rec_srn_loss import SRNLoss
# cls loss
from .cls_loss import ClsLoss
support_dict = [
- 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss'
+ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
+ 'SRNLoss'
]
config = copy.deepcopy(config)
diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d5b65ebaf1ee135d1fefe8d93ddc3f77985b132
--- /dev/null
+++ b/ppocr/losses/rec_srn_loss.py
@@ -0,0 +1,47 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SRNLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(SRNLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
+
+ def forward(self, predicts, batch):
+ predict = predicts['predict']
+ word_predict = predicts['word_out']
+ gsrm_predict = predicts['gsrm_out']
+ label = batch[1]
+
+ casted_label = paddle.cast(x=label, dtype='int64')
+ casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
+
+ cost_word = self.loss_func(word_predict, label=casted_label)
+ cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
+ cost_vsfd = self.loss_func(predict, label=casted_label)
+
+ cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
+ cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
+ cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
+
+ sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
+
+ return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
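The three branches are weighted word : vsfd : gsrm = 3.0 : 1.0 : 0.15 before summation. A small sanity sketch with random logits standing in for the real prediction branches (char_num=38 matches the default VSFD output size; not part of the patch):

```python
# Sanity sketch of the SRNLoss weighting with dummy logits.
import paddle

loss_func = paddle.nn.CrossEntropyLoss(reduction="sum")
n, char_num = 2 * 25, 38                                   # batch_size * max_text_length rows
label = paddle.randint(0, char_num, shape=[n, 1], dtype="int64")

cost_word = loss_func(paddle.randn([n, char_num]), label)
cost_gsrm = loss_func(paddle.randn([n, char_num]), label)
cost_vsfd = loss_func(paddle.randn([n, char_num]), label)
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15  # same weighting as forward() above
print(float(sum_cost))
```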
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index a86fc8382f40b5b73edc7ec8e9d4dbe3e5822283..b3aa9f38f8378eee7104d7e3696b86bede0de903 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target:
correct_num += 1
all_num += 1
- # if all_num < 10 and kwargs.get('show_str', False):
- # print('{} -> {}'.format(pred, target))
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
@@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
- acc = self.correct_num / self.all_num
+ acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index ab44b53a2bf8214b63954aa867b67a3ed1e05fab..09b6e0346d998e3b90762e6163e8a34b48daff36 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
- def forward(self, x):
+ def forward(self, x, data=None):
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
if self.use_neck:
x = self.neck(x)
- x = self.head(x)
+ if data is None:
+ x = self.head(x)
+ else:
+ x = self.head(x, data)
return x
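The new optional `data` argument is how the SRN head receives the four auxiliary tensors kept by KeepKeys; when `data` is None the old single-argument head call is preserved, so CTC-style heads are unaffected. A runnable sketch of that dispatch with a stand-in head (DummyHead is hypothetical; only paddle is assumed):

```python
# Sketch of the head(x) vs head(x, data) dispatch introduced in BaseModel.forward.
import paddle
from paddle import nn

class DummyHead(nn.Layer):
    def forward(self, x, data=None):
        # a real SRNHead unpacks encoder_word_pos, gsrm_word_pos and the two attention biases from data
        extra = 0 if data is None else len(data)
        return {"predict": x.mean(axis=[2, 3]), "num_extra_inputs": extra}

head = DummyHead()
feats = paddle.randn([4, 512, 8, 32])                       # backbone output for a 64x256 input
aux = [paddle.zeros([256, 1], dtype="int64"), paddle.zeros([25, 1], dtype="int64"),
       paddle.zeros([8, 25, 25]), paddle.zeros([8, 25, 25])]
print(head(feats)["num_extra_inputs"])                      # 0 -> old single-argument path
print(head(feats, data=aux)["num_extra_inputs"])            # 4 -> SRN path
```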
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 43103e53d2413eb63c2831cfd54e91f357b3b496..03c15508a58313b234a72bb3ef47ac27dc3ebb7e 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
- support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
+ from .rec_resnet_fpn import ResNetFPN
+ support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e876a2bd52a0ea70479c2009a291e4e2f8ce1f
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_fpn.py
@@ -0,0 +1,307 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import paddle
+import numpy as np
+
+__all__ = ["ResNetFPN"]
+
+
+class ResNetFPN(nn.Layer):
+ def __init__(self, in_channels=1, layers=50, **kwargs):
+ super(ResNetFPN, self).__init__()
+ supported_layers = {
+ 18: {
+ 'depth': [2, 2, 2, 2],
+ 'block_class': BasicBlock
+ },
+ 34: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BasicBlock
+ },
+ 50: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BottleneckBlock
+ },
+ 101: {
+ 'depth': [3, 4, 23, 3],
+ 'block_class': BottleneckBlock
+ },
+ 152: {
+ 'depth': [3, 8, 36, 3],
+ 'block_class': BottleneckBlock
+ }
+ }
+ stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
+ num_filters = [64, 128, 256, 512]
+ self.depth = supported_layers[layers]['depth']
+ self.F = []
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act="relu",
+ name="conv1")
+ self.block_list = []
+ in_ch = 64
+ if layers >= 50:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ block_list = self.add_sublayer(
+ "bottleneckBlock_{}_{}".format(block, i),
+ BottleneckBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ name=conv_name))
+ in_ch = num_filters[block] * 4
+ self.block_list.append(block_list)
+ self.F.append(block_list)
+ else:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+ basic_block = self.add_sublayer(
+ conv_name,
+ BasicBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ is_first=block == i == 0,
+ name=conv_name))
+ in_ch = basic_block.out_channels
+ self.block_list.append(basic_block)
+ out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
+ self.base_block = []
+ self.conv_trans = []
+ self.bn_block = []
+ for i in [-2, -3]:
+ in_channels = out_ch_list[i + 1] + out_ch_list[i]
+
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_0".format(i),
+ nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_ch_list[i],
+ kernel_size=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_1".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=out_ch_list[i],
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_2".format(i),
+ nn.BatchNorm(
+ num_channels=out_ch_list[i],
+ act="relu",
+ param_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_3".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=512,
+ kernel_size=1,
+ bias_attr=ParamAttr(trainable=True),
+ weight_attr=ParamAttr(trainable=True))))
+ self.out_channels = 512
+
+ def __call__(self, x):
+ x = self.conv(x)
+ fpn_list = []
+ F = []
+ for i in range(len(self.depth)):
+ fpn_list.append(np.sum(self.depth[:i + 1]))
+
+ for i, block in enumerate(self.block_list):
+ x = block(x)
+ for number in fpn_list:
+ if i + 1 == number:
+ F.append(x)
+ base = F[-1]
+
+ j = 0
+ for i, block in enumerate(self.base_block):
+ if i % 3 == 0 and i < 6:
+ j = j + 1
+ b, c, w, h = F[-j - 1].shape
+ if [w, h] == list(base.shape[2:]):
+ base = base
+ else:
+ base = self.conv_trans[j - 1](base)
+ base = self.bn_block[j - 1](base)
+ base = paddle.concat([base, F[-j - 1]], axis=1)
+ base = block(base)
+ return base
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 if stride == (1, 1) else kernel_size,
+ dilation=2 if stride == (1, 1) else 1,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
+ bias_attr=False, )
+
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name=name + '.output.1.w_0'),
+ bias_attr=ParamAttr(name=name + '.output.1.b_0'),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance")
+
+ def __call__(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class ShortCut(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first=False):
+ super(ShortCut, self).__init__()
+ self.use_conv = True
+
+ if in_channels != out_channels or stride != 1 or is_first == True:
+ if stride == (1, 1):
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, 1, name=name)
+ else: # stride==(2,2)
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, stride, name=name)
+ else:
+ self.use_conv = False
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name):
+ super(BottleneckBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ stride=stride,
+ is_first=False,
+ name=name + "_branch1")
+ self.out_channels = out_channels * 4
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = y + self.short(x)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first):
+ super(BasicBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act='relu',
+ stride=stride,
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ is_first=is_first,
+ name=name + "_branch1")
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = y + self.short(x)
+ return F.relu(y)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 29d0ba8005d802fc48f22a84373e32d8f0eaa225..efe05718506e94a5ae8ad5ff47bcff26d44c1473 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -24,11 +24,13 @@ def build_head(config):
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
+ from .rec_srn_head import SRNHead
# cls head
from .cls_head import ClsHead
support_dict = [
- 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead'
+ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
+ 'SRNHead'
]
module_name = config.pop('name')
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aaf65e1ae018dd410bbc05d0d7dcac821d062a3
--- /dev/null
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -0,0 +1,279 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+from .self_attention import WrapEncoderForFeature
+from .self_attention import WrapEncoder
+from paddle.static import Program
+from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
+import paddle.fluid.framework as framework
+
+from collections import OrderedDict
+gradient_clip = 10
+
+
+class PVAM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, hidden_dims):
+ super(PVAM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.hidden_dims = hidden_dims
+ # Transformer encoder
+ t = 256
+ c = 512
+ self.wrap_encoder_for_feature = WrapEncoderForFeature(
+ src_vocab_size=1,
+ max_length=t,
+ n_layer=self.num_encoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ # PVAM
+ self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels,
+ out_features=in_channels, )
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.max_length, embedding_dim=in_channels)
+ self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
+ self.fc1 = paddle.nn.Linear(
+ in_features=in_channels, out_features=1, bias_attr=False)
+
+ def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
+ b, c, h, w = inputs.shape
+ conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
+ conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
+ # transformer encoder
+ b, t, c = conv_features.shape
+
+ enc_inputs = [conv_features, encoder_word_pos, None]
+ word_features = self.wrap_encoder_for_feature(enc_inputs)
+
+ # pvam
+ b, t, c = word_features.shape
+ word_features = self.fc0(word_features)
+ word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
+ word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
+ word_pos_feature = self.emb(gsrm_word_pos)
+ word_pos_feature_ = paddle.reshape(word_pos_feature,
+ [-1, self.max_length, 1, c])
+ word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
+ y = word_pos_feature_ + word_features_
+ y = F.tanh(y)
+ attention_weight = self.fc1(y)
+ attention_weight = paddle.reshape(
+ attention_weight, shape=[-1, self.max_length, t])
+ attention_weight = F.softmax(attention_weight, axis=-1)
+ pvam_features = paddle.matmul(attention_weight,
+ word_features) #[b, max_length, c]
+ return pvam_features
+
+
+class GSRM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, num_decoder_tus, hidden_dims):
+ super(GSRM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.num_decoder_TUs = num_decoder_tus
+ self.hidden_dims = hidden_dims
+
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels, out_features=self.char_num)
+ self.wrap_encoder0 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.wrap_encoder1 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.mul = lambda x: paddle.matmul(x=x,
+ y=self.wrap_encoder0.prepare_decoder.emb0.weight,
+ transpose_y=True)
+
+ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2):
+ # ===== GSRM Visual-to-semantic embedding block =====
+ b, t, c = inputs.shape
+ pvam_features = paddle.reshape(inputs, [-1, c])
+ word_out = self.fc0(pvam_features)
+ word_ids = paddle.argmax(F.softmax(word_out), axis=1)
+ word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
+
+ #===== GSRM Semantic reasoning block =====
+ """
+ This module is implemented with bi-transformers,
+ ngram_feature1 is the forward one, ngram_feature2 is the backward one
+ """
+ pad_idx = self.char_num
+
+ word1 = paddle.cast(word_ids, "float32")
+ word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
+ word1 = paddle.cast(word1, "int64")
+ word1 = word1[:, :-1, :]
+ word2 = word_ids
+
+ enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
+ enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
+
+ gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
+ gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
+
+ gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
+ value=0.,
+ data_format="NLC")
+ gsrm_feature2 = gsrm_feature2[:, 1:, ]
+ gsrm_features = gsrm_feature1 + gsrm_feature2
+
+ gsrm_out = self.mul(gsrm_features)
+
+ b, t, c = gsrm_out.shape
+ gsrm_out = paddle.reshape(gsrm_out, [-1, c])
+
+ return gsrm_features, word_out, gsrm_out
+
+
+class VSFD(nn.Layer):
+ def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
+ super(VSFD, self).__init__()
+ self.char_num = char_num
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels * 2, out_features=pvam_ch)
+ self.fc1 = paddle.nn.Linear(
+ in_features=pvam_ch, out_features=self.char_num)
+
+ def forward(self, pvam_feature, gsrm_feature):
+ b, t, c1 = pvam_feature.shape
+ b, t, c2 = gsrm_feature.shape
+ combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
+ img_comb_feature_ = paddle.reshape(
+ combine_feature_, shape=[-1, c1 + c2])
+ img_comb_feature_map = self.fc0(img_comb_feature_)
+ img_comb_feature_map = F.sigmoid(img_comb_feature_map)
+ img_comb_feature_map = paddle.reshape(
+ img_comb_feature_map, shape=[-1, t, c1])
+ combine_feature = img_comb_feature_map * pvam_feature + (
+ 1.0 - img_comb_feature_map) * gsrm_feature
+ img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
+
+ out = self.fc1(img_comb_feature)
+ return out
+
+
+class SRNHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
+ num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
+ super(SRNHead, self).__init__()
+ self.char_num = out_channels
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_TUs
+ self.num_decoder_TUs = num_decoder_TUs
+ self.hidden_dims = hidden_dims
+
+ self.pvam = PVAM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ hidden_dims=self.hidden_dims)
+
+ self.gsrm = GSRM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ num_decoder_tus=self.num_decoder_TUs,
+ hidden_dims=self.hidden_dims)
+ self.vsfd = VSFD(in_channels=in_channels)
+
+ self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
+
+ def forward(self, inputs, others):
+ encoder_word_pos = others[0]
+ gsrm_word_pos = others[1]
+ gsrm_slf_attn_bias1 = others[2]
+ gsrm_slf_attn_bias2 = others[3]
+
+ pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
+
+ gsrm_feature, word_out, gsrm_out = self.gsrm(
+ pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
+ final_out = self.vsfd(pvam_feature, gsrm_feature)
+ if not self.training:
+ final_out = F.softmax(final_out, axis=1)
+
+ _, decoded_out = paddle.topk(final_out, k=1)
+
+ predicts = OrderedDict([
+ ('predict', final_out),
+ ('pvam_feature', pvam_feature),
+ ('decoded_out', decoded_out),
+ ('word_out', word_out),
+ ('gsrm_out', gsrm_out),
+ ])
+
+ return predicts
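Among the pieces above, the VSFD fusion is the easiest to inspect in isolation: a learned per-channel sigmoid gate mixes the PVAM (visual) and GSRM (semantic) features. A standalone sketch with random stand-ins instead of the real module outputs (not wired to the patch code):

```python
# Sketch of the VSFD gating in isolation, with random PVAM/GSRM features.
import paddle
import paddle.nn.functional as F

b, t, c = 2, 25, 512
pvam_feature = paddle.randn([b, t, c])
gsrm_feature = paddle.randn([b, t, c])
fc0 = paddle.nn.Linear(in_features=c * 2, out_features=c)

combined = paddle.concat([pvam_feature, gsrm_feature], axis=2)        # [b, t, 2c]
gate = F.sigmoid(fc0(paddle.reshape(combined, [-1, c * 2])))          # [b*t, c]
gate = paddle.reshape(gate, [-1, t, c])
fused = gate * pvam_feature + (1.0 - gate) * gsrm_feature             # per-channel mixture
print(fused.shape)                                                    # [2, 25, 512]
```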
diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d5198f558dcb7e0351f04b3a884b71707104d4
--- /dev/null
+++ b/ppocr/modeling/heads/self_attention.py
@@ -0,0 +1,409 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+gradient_clip = 10
+
+
+class WrapEncoderForFeature(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoderForFeature, self).__init__()
+
+ self.prepare_encoder = PrepareEncoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx,
+ word_emb_param_name="src_word_emb_table")
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ conv_features, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_encoder(conv_features, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class WrapEncoder(nn.Layer):
+ """
+ embedder + encoder
+ """
+
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoder, self).__init__()
+
+ self.prepare_decoder = PrepareDecoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx)
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ src_word, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_decoder(src_word, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class Encoder(nn.Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(Encoder, self).__init__()
+
+ self.encoder_layers = list()
+ for i in range(n_layer):
+ self.encoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd,
+ postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ for encoder_layer in self.encoder_layers:
+ enc_output = encoder_layer(enc_input, attn_bias)
+ enc_input = enc_output
+ enc_output = self.processer(enc_output)
+ return enc_output
+
+
+class EncoderLayer(nn.Layer):
+ """
+ EncoderLayer
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderLayer, self).__init__()
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ attn_output = self.self_attn(
+ self.preprocesser1(enc_input), None, None, attn_bias)
+ attn_output = self.postprocesser1(attn_output, enc_input)
+ ffn_output = self.ffn(self.preprocesser2(attn_output))
+ ffn_output = self.postprocesser2(ffn_output, attn_output)
+ return ffn_output
+
+
+class MultiHeadAttention(nn.Layer):
+ """
+ Multi-Head Attention
+ """
+
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+ self.d_model = d_model
+ self.dropout_rate = dropout_rate
+ self.q_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.k_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.v_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_value * n_head, bias_attr=False)
+ self.proj_fc = paddle.nn.Linear(
+ in_features=d_value * n_head, out_features=d_model, bias_attr=False)
+
+ def _prepare_qkv(self, queries, keys, values, cache=None):
+ if keys is None: # self-attention
+ keys, values = queries, queries
+ static_kv = False
+ else: # cross-attention
+ static_kv = True
+
+ q = self.q_fc(queries)
+ q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+ q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
+
+ if cache is not None and static_kv and "static_k" in cache:
+            # encoder-decoder attention at inference time, with cached static k/v
+ k = cache["static_k"]
+ v = cache["static_v"]
+ else:
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
+ v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
+
+ if cache is not None:
+            if static_kv and "static_k" not in cache:
+                # encoder-decoder attention at inference time, k/v not cached yet
+ cache["static_k"], cache["static_v"] = k, v
+ elif not static_kv:
+ # for decoder self-attention in inference
+ cache_k, cache_v = cache["k"], cache["v"]
+ k = paddle.concat([cache_k, k], axis=2)
+ v = paddle.concat([cache_v, v], axis=2)
+ cache["k"], cache["v"] = k, v
+
+ return q, k, v
+
+ def forward(self, queries, keys, values, attn_bias, cache=None):
+        # compute q, k, v
+ keys = queries if keys is None else keys
+ values = keys if values is None else values
+ q, k, v = self._prepare_qkv(queries, keys, values, cache)
+
+        # scaled dot-product attention
+ product = paddle.matmul(x=q, y=k, transpose_y=True)
+ product = product * self.d_model**-0.5
+ if attn_bias is not None:
+ product += attn_bias
+ weights = F.softmax(product)
+ if self.dropout_rate:
+ weights = F.dropout(
+ weights, p=self.dropout_rate, mode="downscale_in_infer")
+ out = paddle.matmul(weights, v)
+
+ # combine heads
+ out = paddle.transpose(out, perm=[0, 2, 1, 3])
+ out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+ # project to output
+ out = self.proj_fc(out)
+
+ return out
+
+
+class PrePostProcessLayer(nn.Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, process_cmd, d_model, dropout_rate):
+ super(PrePostProcessLayer, self).__init__()
+ self.process_cmd = process_cmd
+ self.functors = []
+ for cmd in self.process_cmd:
+ if cmd == "a": # add residual connection
+ self.functors.append(lambda x, y: x + y if y is not None else x)
+ elif cmd == "n": # add layer normalization
+ self.functors.append(
+ self.add_sublayer(
+ "layer_norm_%d" % len(
+ self.sublayers(include_sublayers=False)),
+ paddle.nn.LayerNorm(
+ normalized_shape=d_model,
+ weight_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(1.)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.)))))
+ elif cmd == "d": # add dropout
+ self.functors.append(lambda x: F.dropout(
+ x, p=dropout_rate, mode="downscale_in_infer")
+ if dropout_rate else x)
+
+ def forward(self, x, residual=None):
+ for i, cmd in enumerate(self.process_cmd):
+ if cmd == "a":
+ x = self.functors[i](x, residual)
+ else:
+ x = self.functors[i](x)
+ return x
+
+
+class PrepareEncoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareEncoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ self.src_max_len = src_max_len
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.src_max_len,
+ embedding_dim=self.src_emb_dim,
+ sparse=True)
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
+ src_word_emb = src_word
+ src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class PrepareDecoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareDecoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ self.emb0 = paddle.nn.Embedding(
+ num_embeddings=src_vocab_size,
+ embedding_dim=self.src_emb_dim,
+ padding_idx=bos_idx,
+ weight_attr=paddle.ParamAttr(
+ name=word_emb_param_name,
+ initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
+ self.emb1 = paddle.nn.Embedding(
+ num_embeddings=src_max_len,
+ embedding_dim=self.src_emb_dim,
+ weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
+ src_word = fluid.layers.cast(src_word, 'int64')
+ src_word = paddle.squeeze(src_word, axis=-1)
+ src_word_emb = self.emb0(src_word)
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb1(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class FFN(nn.Layer):
+ """
+ Feed-Forward Network
+ """
+
+ def __init__(self, d_inner_hid, d_model, dropout_rate):
+ super(FFN, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.fc1 = paddle.nn.Linear(
+ in_features=d_model, out_features=d_inner_hid)
+ self.fc2 = paddle.nn.Linear(
+ in_features=d_inner_hid, out_features=d_model)
+
+ def forward(self, x):
+ hidden = self.fc1(x)
+ hidden = F.relu(hidden)
+ if self.dropout_rate:
+ hidden = F.dropout(
+ hidden, p=self.dropout_rate, mode="downscale_in_infer")
+ out = self.fc2(hidden)
+ return out
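A note on the `preprocess_cmd` / `postprocess_cmd` convention used throughout the encoder above: each command string is interpreted character by character by PrePostProcessLayer, where "a" adds the residual, "n" applies layer normalization and "d" applies dropout, so the defaults ("n" before a sub-layer, "da" after it) give the usual pre-norm Transformer block. A minimal NumPy sketch of that composition (helper names are illustrative only and not part of this patch):

    import numpy as np

    def layer_norm(x, eps=1e-5):
        # normalize over the last (d_model) axis, as paddle.nn.LayerNorm does
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps)

    def pre_post_process(x, residual=None, cmd="da", dropout_rate=0.0):
        # mirror PrePostProcessLayer: apply each command character in order
        for c in cmd:
            if c == "a":          # add residual connection
                x = x + residual if residual is not None else x
            elif c == "n":        # layer normalization
                x = layer_norm(x)
            elif c == "d" and dropout_rate:   # simplified Bernoulli dropout mask
                x = x * (np.random.rand(*x.shape) >= dropout_rate)
        return x

    # one encoder sub-layer in the pre-norm arrangement used by EncoderLayer
    x = np.random.rand(2, 25, 512).astype("float32")        # (batch, seq, d_model)
    normed = pre_post_process(x, cmd="n")                    # preprocesser: layer norm
    attn_out = normed                                        # self-attention would run here
    out = pre_post_process(attn_out, residual=x, cmd="da")   # postprocesser: dropout + add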
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 2b8d00a9e76911eddedd9ae30cda0c20820eaea4..0156e438e9e24820943c9e48b04565710ea2fd4b 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -26,12 +26,12 @@ def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
- from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
+ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
- 'AttnLabelDecode', 'ClsPostProcess', 'AttnLabelDecode'
+ 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 1ac352466a497916f913dccfca32e294d361ecd2..2b82750fcb2135526e0bbbbe61de91e627a93153 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -33,6 +33,9 @@ class BaseRecLabelDecode(object):
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
+ self.beg_str = "sos"
+ self.end_str = "eos"
+
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -109,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
-
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
@@ -194,3 +196,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelDecode, self).__init__(character_dict_path,
+ character_type, use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred = preds['predict']
+ char_num = len(self.character_str) + 2
+ if isinstance(pred, paddle.Tensor):
+ pred = pred.numpy()
+ pred = np.reshape(pred, [-1, char_num])
+
+ preds_idx = np.argmax(pred, axis=1)
+ preds_prob = np.max(pred, axis=1)
+
+ preds_idx = np.reshape(preds_idx, [-1, 25])
+
+ preds_prob = np.reshape(preds_prob, [-1, 25])
+
+ text = self.decode(preds_idx, preds_prob)
+
+ if label is None:
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ return text
+ label = self.decode(label)
+ return text, label
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list)))
+ return result_list
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
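SRNLabelDecode assumes the head returns, under the 'predict' key, a flat probability matrix of shape [batch * max_text_length, char_num], where char_num is the character set plus the sos/eos markers and max_text_length is fixed at 25. A small NumPy sketch of the reshaping step (all shapes illustrative, no real model involved):

    import numpy as np

    batch, max_text_length, char_num = 2, 25, 38   # e.g. 36 characters + sos + eos
    pred = np.random.rand(batch * max_text_length, char_num).astype("float32")

    preds_idx = np.argmax(pred, axis=1)                    # best character id per time step
    preds_prob = np.max(pred, axis=1)                      # its score
    preds_idx = preds_idx.reshape(-1, max_text_length)     # -> (batch, 25)
    preds_prob = preds_prob.reshape(-1, max_text_length)
    # decode() then maps ids to characters per row, skipping the sos/eos ids
    # returned by get_ignored_tokens(), and averages the per-step scores.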
diff --git a/tools/export_model.py b/tools/export_model.py
index a9b9e7dd5145e46eb4094da8e0c65e4678f0818a..1e9526e03d6b9001249d5891c37bee071c1f36a3 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", "--config", help="configuration file to use")
+ parser.add_argument(
+ "-o", "--output_path", type=str, default='./output/infer/')
+ return parser.parse_args()
+
+
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@@ -52,23 +60,39 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
- infer_shape = [3, -1, -1]
- if config['Architecture']['model_type'] == "rec":
- infer_shape = [3, 32, -1] # for rec model, H must be 32
- if 'Transform' in config['Architecture'] and config['Architecture'][
- 'Transform'] is not None and config['Architecture'][
- 'Transform']['name'] == 'TPS':
- logger.info(
- 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
- )
- infer_shape[-1] = 100
-
- model = to_static(
- model,
- input_spec=[
+ if config['Architecture']['algorithm'] == "SRN":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 1, 64, 256], dtype='float32'),
+            [
+                paddle.static.InputSpec(
+                    shape=[None, 256, 1], dtype="int64"),
+                paddle.static.InputSpec(
+                    shape=[None, 25, 1], dtype="int64"),
+                paddle.static.InputSpec(
+                    shape=[None, 8, 25, 25], dtype="int64"),
+                paddle.static.InputSpec(
+                    shape=[None, 8, 25, 25], dtype="int64")
+            ]
+        ]
+ model = to_static(model, input_spec=other_shape)
+ else:
+ infer_shape = [3, -1, -1]
+ if config['Architecture']['model_type'] == "rec":
+ infer_shape = [3, 32, -1] # for rec model, H must be 32
+ if 'Transform' in config['Architecture'] and config['Architecture'][
+ 'Transform'] is not None and config['Architecture'][
+ 'Transform']['name'] == 'TPS':
+ logger.info(
+ 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+ )
+ infer_shape[-1] = 100
+ model = to_static(
+ model,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None] + infer_shape, dtype='float32')
+ ])
+
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
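For SRN the exporter pins every input to a fixed shape because the model consumes a 1x64x256 grayscale image plus four auxiliary tensors rather than a single variable-width image; the nested input_spec mirrors the forward signature model(images, others). A hedged sketch of dummy inputs matching those specs (shapes and dtypes taken from the InputSpec list above, variable names illustrative):

    import paddle

    images = paddle.randn([1, 1, 64, 256], dtype="float32")   # grayscale 64x256 crop
    others = [
        paddle.zeros([1, 256, 1], dtype="int64"),              # encoder_word_pos
        paddle.zeros([1, 25, 1], dtype="int64"),                # gsrm_word_pos
        paddle.zeros([1, 8, 25, 25], dtype="int64"),            # gsrm_slf_attn_bias1
        paddle.zeros([1, 8, 25, 25], dtype="int64"),            # gsrm_slf_attn_bias2
    ]
    # preds = model(images, others)   # what the exported static graph expects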
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 974fdbb6c7f4d33bd39e818945be480d858c0d09..fd895e50719941877fd620cab929a20c7d88b8e5 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -25,6 +25,7 @@ import numpy as np
import math
import time
import traceback
+import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ if self.rec_algorithm == "SRN":
+ postprocess_params = {
+ 'name': 'SRNLabelDecode',
+ "character_type": args.rec_char_type,
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger)
@@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_srn(self, img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+ def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(
+ gsrm_slf_attn_bias1,
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(
+ gsrm_slf_attn_bias2,
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+ encoder_word_pos = encoder_word_pos[np.newaxis, :]
+ gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+ def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+ norm_img = self.resize_norm_img_srn(img, image_shape)
+ norm_img = norm_img[np.newaxis, :]
+
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ self.srn_other_inputs(image_shape, num_heads, max_text_length)
+
+ gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+ gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+ encoder_word_pos = encoder_word_pos.astype(np.int64)
+ gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+
+ return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
- # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
+ if self.rec_algorithm != "SRN":
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ else:
+ norm_img = self.process_image_srn(
+ img_list[indices[ino]], self.rec_image_shape, 8, 25)
+                    if ino == beg_img_no:
+                        # create the SRN auxiliary input lists once per mini-batch
+                        encoder_word_pos_list = []
+                        gsrm_word_pos_list = []
+                        gsrm_slf_attn_bias1_list = []
+                        gsrm_slf_attn_bias2_list = []
+ encoder_word_pos_list.append(norm_img[1])
+ gsrm_word_pos_list.append(norm_img[2])
+ gsrm_slf_attn_bias1_list.append(norm_img[3])
+ gsrm_slf_attn_bias2_list.append(norm_img[4])
+ norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
- starttime = time.time()
- self.input_tensor.copy_from_cpu(norm_img_batch)
- self.predictor.run()
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
- preds = outputs[0]
+
+ if self.rec_algorithm == "SRN":
+ starttime = time.time()
+ encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+ gsrm_slf_attn_bias1_list = np.concatenate(
+ gsrm_slf_attn_bias1_list)
+ gsrm_slf_attn_bias2_list = np.concatenate(
+ gsrm_slf_attn_bias2_list)
+
+ inputs = [
+ norm_img_batch,
+ encoder_word_pos_list,
+ gsrm_word_pos_list,
+ gsrm_slf_attn_bias1_list,
+ gsrm_slf_attn_bias2_list,
+ ]
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(input_names[
+ i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = {"predict": outputs[2]}
+ else:
+ starttime = time.time()
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
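The two gsrm_slf_attn_bias tensors built in srn_other_inputs are additive attention masks: entries set to -1e9 are effectively zeroed out by the softmax, so one mask exposes only left-to-right context and the other only right-to-left context for the two GSRM streams. A standalone sketch with an illustrative max_text_length of 4 and a single head:

    import numpy as np

    L = 4
    ones = np.ones((1, L, L))
    bias1 = np.triu(ones, 1).reshape([-1, 1, L, L]) * -1e9   # block positions after i
    bias2 = np.tril(ones, -1).reshape([-1, 1, L, L]) * -1e9  # block positions before i
    print(bias1[0, 0])   # row i attends only to j <= i  (left-to-right stream)
    print(bias2[0, 0])   # row i attends only to j >= i  (right-to-left stream)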
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 7e4b081140c37ff1eb8c5e0085185b8961198a0b..075ec261e492cf21c668364ae6119fb4903f823b 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -62,7 +62,13 @@ def main():
elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image']
+ if config['Architecture']['algorithm'] == "SRN":
+ op[op_name]['keep_keys'] = [
+ 'image', 'encoder_word_pos', 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
+ ]
+ else:
+ op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
@@ -74,10 +80,25 @@ def main():
img = f.read()
data = {'image': img}
batch = transform(data, ops)
+ if config['Architecture']['algorithm'] == "SRN":
+ encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+ gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+ gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+ gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+
+ others = [
+ paddle.to_tensor(encoder_word_pos_list),
+ paddle.to_tensor(gsrm_word_pos_list),
+ paddle.to_tensor(gsrm_slf_attn_bias1_list),
+ paddle.to_tensor(gsrm_slf_attn_bias2_list)
+ ]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
- preds = model(images)
+ if config['Architecture']['algorithm'] == "SRN":
+ preds = model(images, others)
+ else:
+ preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
diff --git a/tools/program.py b/tools/program.py
index fb9e3802a0818b2ee92117d10bda6b70261abace..f3ba49450a21f600589b6888710a2420ccdaa321 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -174,6 +174,7 @@ def train(config,
best_model_dict = {main_indicator: 0}
best_model_dict.update(pre_best_model_dict)
train_stats = TrainingStats(log_smooth_window, ['lr'])
+ model_average = False
model.train()
if 'start_epoch' in best_model_dict:
@@ -194,7 +195,12 @@ def train(config,
break
lr = optimizer.get_lr()
images = batch[0]
- preds = model(images)
+ if config['Architecture']['algorithm'] == "SRN":
+ others = batch[-4:]
+ preds = model(images, others)
+ model_average = True
+ else:
+ preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
@@ -238,7 +244,14 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
- cur_metric = eval(model, valid_dataloader, post_process_class,
+ if model_average:
+ Model_Average = paddle.incubate.optimizer.ModelAverage(
+ 0.15,
+ parameters=model.parameters(),
+ min_average_window=10000,
+ max_average_window=15625)
+ Model_Average.apply()
+            cur_metric = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
@@ -273,6 +286,7 @@ def train(config,
best_model_dict[main_indicator],
global_step)
global_step += 1
+ optimizer.clear_grad()
batch_start = time.time()
if dist.get_rank() == 0:
save_model(
@@ -313,7 +327,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
break
images = batch[0]
start = time.time()
- preds = model(images)
+ if "SRN" in str(model.head):
+ others = batch[-4:]
+ preds = model(images, others)
+ else:
+ preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods