diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
index 3833cb0bad9915a2169b116e7406e01cadd0ef62..df429314cd0ec058aa6779a0ff55656f1b211bbf 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
@@ -28,7 +28,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 18
Neck:
name: DBFPN
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
index b4644244f849abb05d5501c7d2ed2bf92efe916e..d24ee11f3dee9d17a587d6d4d0765abfe459fbc9 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
@@ -45,7 +45,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 18
Neck:
name: DBFPN
diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
index 3e77577c17abe2111c501d96ce6b1087ac44f8d6..ef58befd694e26704c734d7fd072ebc3370c8554 100644
--- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
+++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
@@ -61,7 +61,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
- name: ResNet
+ name: ResNet_vd
in_channels: 3
layers: 50
Neck:
diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
index 3f30ada13f2f1c3c01ba8886bbfba006da516f17..e1831f26396cffde28a2ca881a0b03312a68f801 100644
--- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
+++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
@@ -25,7 +25,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
- name: ResNet
+ name: ResNet_vd
in_channels: 3
layers: 50
Neck:
@@ -40,7 +40,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
- name: ResNet
+ name: ResNet_vd
in_channels: 3
layers: 50
Neck:
diff --git a/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml b/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
index 7b07ef99648956a70b5a71f1e61f09b592226f90..e983c221e2cea87e35b1c5f0a67daed0a55c4257 100644
--- a/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
+++ b/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 18
disable_se: True
Neck:
diff --git a/configs/det/det_r50_db++_ic15.yml b/configs/det/det_r50_db++_ic15.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e0cd6012b660573a79ff013a1b6e2309074a3d86
--- /dev/null
+++ b/configs/det/det_r50_db++_ic15.yml
@@ -0,0 +1,163 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 1000
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/det_r50_icdar15/
+ save_epoch_step: 200
+ eval_batch_step:
+ - 0
+ - 2000
+ cal_metric_during_train: false
+ pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
+ checkpoints: null
+ save_inference_dir: null
+ use_visualdl: false
+ infer_img: doc/imgs_en/img_10.jpg
+ save_res_path: ./checkpoints/det_db/predicts_db.txt
+Architecture:
+ model_type: det
+ algorithm: DB++
+ Transform: null
+ Backbone:
+ name: ResNet
+ layers: 50
+ dcn_stage: [False, True, True, True]
+ Neck:
+ name: DBFPN
+ out_channels: 256
+ use_asf: True
+ Head:
+ name: DBHead
+ k: 50
+Loss:
+ name: DBLoss
+ balance_loss: true
+ main_loss_type: BCELoss
+ alpha: 5
+ beta: 10
+ ohem_ratio: 3
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: DecayLearningRate
+ learning_rate: 0.007
+ epochs: 1000
+ factor: 0.9
+ end_lr: 0
+ weight_decay: 0.0001
+PostProcess:
+ name: DBPostProcess
+ thresh: 0.3
+ box_thresh: 0.6
+ max_candidates: 1000
+ unclip_ratio: 1.5
+Metric:
+ name: DetMetric
+ main_indicator: hmean
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+ ratio_list:
+ - 1.0
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - IaaAugment:
+ augmenter_args:
+ - type: Fliplr
+ args:
+ p: 0.5
+ - type: Affine
+ args:
+ rotate:
+ - -10
+ - 10
+ - type: Resize
+ args:
+ size:
+ - 0.5
+ - 3
+ - EastRandomCropData:
+ size:
+ - 640
+ - 640
+ max_tries: 10
+ keep_ratio: true
+ - MakeShrinkMap:
+ shrink_ratio: 0.4
+ min_text_size: 8
+ - MakeBorderMap:
+ shrink_ratio: 0.4
+ thresh_min: 0.3
+ thresh_max: 0.7
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.48109378172549
+ - 0.45752457890196
+ - 0.40787054090196
+ std:
+ - 1.0
+ - 1.0
+ - 1.0
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - threshold_map
+ - threshold_mask
+ - shrink_map
+ - shrink_mask
+ loader:
+ shuffle: true
+ drop_last: false
+ batch_size_per_card: 4
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - DetResizeForTest:
+ image_shape:
+ - 1152
+ - 2048
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.48109378172549
+ - 0.45752457890196
+ - 0.40787054090196
+ std:
+ - 1.0
+ - 1.0
+ - 1.0
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - shape
+ - polys
+ - ignore_tags
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 1
+ num_workers: 2
+profiler_options: null
diff --git a/configs/det/det_r50_db++_td_tr.yml b/configs/det/det_r50_db++_td_tr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..65021bb66184381ba732980ac1b7a65d7bd3a355
--- /dev/null
+++ b/configs/det/det_r50_db++_td_tr.yml
@@ -0,0 +1,166 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 1000
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/det_r50_td_tr/
+ save_epoch_step: 200
+ eval_batch_step:
+ - 0
+ - 2000
+ cal_metric_during_train: false
+ pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
+ checkpoints: null
+ save_inference_dir: null
+ use_visualdl: false
+ infer_img: doc/imgs_en/img_10.jpg
+ save_res_path: ./checkpoints/det_db/predicts_db.txt
+Architecture:
+ model_type: det
+ algorithm: DB++
+ Transform: null
+ Backbone:
+ name: ResNet
+ layers: 50
+ dcn_stage: [False, True, True, True]
+ Neck:
+ name: DBFPN
+ out_channels: 256
+ use_asf: True
+ Head:
+ name: DBHead
+ k: 50
+Loss:
+ name: DBLoss
+ balance_loss: true
+ main_loss_type: BCELoss
+ alpha: 5
+ beta: 10
+ ohem_ratio: 3
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: DecayLearningRate
+ learning_rate: 0.007
+ epochs: 1000
+ factor: 0.9
+ end_lr: 0
+ weight_decay: 0.0001
+PostProcess:
+ name: DBPostProcess
+ thresh: 0.3
+ box_thresh: 0.5
+ max_candidates: 1000
+ unclip_ratio: 1.5
+Metric:
+ name: DetMetric
+ main_indicator: hmean
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/
+ label_file_list:
+ - ./train_data/TD_TR/TD500/train_gt_labels.txt
+ - ./train_data/TD_TR/TR400/gt_labels.txt
+ ratio_list:
+ - 1.0
+ - 1.0
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - IaaAugment:
+ augmenter_args:
+ - type: Fliplr
+ args:
+ p: 0.5
+ - type: Affine
+ args:
+ rotate:
+ - -10
+ - 10
+ - type: Resize
+ args:
+ size:
+ - 0.5
+ - 3
+ - EastRandomCropData:
+ size:
+ - 640
+ - 640
+ max_tries: 10
+ keep_ratio: true
+ - MakeShrinkMap:
+ shrink_ratio: 0.4
+ min_text_size: 8
+ - MakeBorderMap:
+ shrink_ratio: 0.4
+ thresh_min: 0.3
+ thresh_max: 0.7
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.48109378172549
+ - 0.45752457890196
+ - 0.40787054090196
+ std:
+ - 1.0
+ - 1.0
+ - 1.0
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - threshold_map
+ - threshold_mask
+ - shrink_map
+ - shrink_mask
+ loader:
+ shuffle: true
+ drop_last: false
+ batch_size_per_card: 4
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/
+ label_file_list:
+ - ./train_data/TD_TR/TD500/test_gt_labels.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - DetResizeForTest:
+ image_shape:
+ - 736
+ - 736
+ keep_ratio: True
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.48109378172549
+ - 0.45752457890196
+ - 0.40787054090196
+ std:
+ - 1.0
+ - 1.0
+ - 1.0
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - shape
+ - polys
+ - ignore_tags
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 1
+ num_workers: 2
+profiler_options: null
diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml
index ab67786ece2db9c082ad0484e9dd9a71a795c2d7..288dcc8c1a5934b35d329acc38fa85451ebaeb19 100644
--- a/configs/det/det_r50_vd_db.yml
+++ b/configs/det/det_r50_vd_db.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
Neck:
name: DBFPN
diff --git a/configs/det/det_r50_vd_dcn_fce_ctw.yml b/configs/det/det_r50_vd_dcn_fce_ctw.yml
index a9f7c4143d4e9380c819f8cbc39d69f0149111b2..3a4075b322a173796f26ebdbe5b83ba98e98e72d 100755
--- a/configs/det/det_r50_vd_dcn_fce_ctw.yml
+++ b/configs/det/det_r50_vd_dcn_fce_ctw.yml
@@ -21,7 +21,7 @@ Architecture:
algorithm: FCE
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
dcn_stage: [False, True, True, True]
out_indices: [1,2,3]
diff --git a/configs/det/det_r50_vd_east.yml b/configs/det/det_r50_vd_east.yml
index e84a5fa7a7af34bde5e0abc6fed2e01f6ce42e6b..af90ef0adb929955141ef31779e917e4d73057ee 100644
--- a/configs/det/det_r50_vd_east.yml
+++ b/configs/det/det_r50_vd_east.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: EAST
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
Neck:
name: EASTFPN
diff --git a/configs/det/det_r50_vd_pse.yml b/configs/det/det_r50_vd_pse.yml
index 8e77506c410af5397a04f73674b414cb28a87c4d..1a971564fda2ce89bf091808dedb361f1caeddc3 100644
--- a/configs/det/det_r50_vd_pse.yml
+++ b/configs/det/det_r50_vd_pse.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: PSE
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 50
Neck:
name: FPN
diff --git a/configs/det/det_res18_db_v2.0.yml b/configs/det/det_res18_db_v2.0.yml
index 7b07ef99648956a70b5a71f1e61f09b592226f90..e983c221e2cea87e35b1c5f0a67daed0a55c4257 100644
--- a/configs/det/det_res18_db_v2.0.yml
+++ b/configs/det/det_res18_db_v2.0.yml
@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
- name: ResNet
+ name: ResNet_vd
layers: 18
disable_se: True
Neck:
diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml
index a6968aaa3aa7a717a848416efc5ccc567f774b4d..da2e4fda504fcbff280788b0baf6d803cf75fe4b 100644
--- a/configs/kie/kie_unet_sdmgr.yml
+++ b/configs/kie/kie_unet_sdmgr.yml
@@ -17,7 +17,7 @@ Global:
checkpoints:
save_inference_dir:
use_visualdl: False
- class_path: ./train_data/wildreceipt/class_list.txt
+ class_path: &class_path ./train_data/wildreceipt/class_list.txt
infer_img: ./train_data/wildreceipt/1.txt
save_res_path: ./output/sdmgr_kie/predicts_kie.txt
img_scale: [ 1024, 512 ]
@@ -72,6 +72,7 @@ Train:
order: 'hwc'
- KieLabelEncode: # Class handling label
character_dict_path: ./train_data/wildreceipt/dict.txt
+ class_path: *class_path
- KieResize:
- ToCHWImage:
- KeepKeys:
@@ -88,7 +89,6 @@ Eval:
data_dir: ./train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
- # - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
diff --git a/configs/vqa/re/layoutlmv2.yml b/configs/vqa/re/layoutlmv2.yml
index 2fa5fd1165c20bbfa8d8505bbb53d48744daebef..737dbf6b600b1b414a7f66f422e59f46154d91a9 100644
--- a/configs/vqa/re/layoutlmv2.yml
+++ b/configs/vqa/re/layoutlmv2.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2048
- infer_img: doc/vqa/input/zh_val_21.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
@@ -21,7 +21,7 @@ Architecture:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
- checkpoints:
+ checkpoints:
Loss:
name: LossFromOutput
@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml
index ff16120ac1be92e989ebfda6af3ccf346dde89cd..d8585bb72593d55578ff3c6cd1401b5a843bb683 100644
--- a/configs/vqa/re/layoutxlm.yml
+++ b/configs/vqa/re/layoutxlm.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_21.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml
index 87131170c9daabd8553269b900893ac26fb32bc8..53e114defd4cdfa427ae27b647603744302eb0e8 100644
--- a/configs/vqa/ser/layoutlm.yml
+++ b/configs/vqa/ser/layoutlm.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_0.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/
Architecture:
@@ -43,7 +43,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -89,7 +89,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -112,7 +112,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutlmv2.yml b/configs/vqa/ser/layoutlmv2.yml
index 33406252b31adf4175d7ea2f57772b0faf33cdab..e48c7469567a740ca74240f0ca9f782ed5bb3c6d 100644
--- a/configs/vqa/ser/layoutlmv2.yml
+++ b/configs/vqa/ser/layoutlmv2.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_0.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/
Architecture:
@@ -44,7 +44,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
@@ -55,7 +55,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml
index eb1cca5a215dd65ef9c302441d05b482f2622a79..fa9df192afbc1d638c220cba3ef3640715585b37 100644
--- a/configs/vqa/ser/layoutxlm.yml
+++ b/configs/vqa/ser/layoutxlm.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_42.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser
Architecture:
@@ -43,7 +43,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/doc/doc_ch/algorithm_det_db.md b/doc/doc_ch/algorithm_det_db.md
index 90837c2ac1ebbc04ee47cbb74ed6466352710e88..afdddb1a73a495cbb3186348704b235f8076c7d1 100644
--- a/doc/doc_ch/algorithm_det_db.md
+++ b/doc/doc_ch/algorithm_det_db.md
@@ -1,4 +1,4 @@
-# DB
+# DB and DB++
- [1. Algorithm Introduction](#1)
- [2. Environment Configuration](#2)
@@ -21,12 +21,24 @@
> Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang
> AAAI, 2020
+> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
+> Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang
+> TPAMI, 2022
+
+
On the ICDAR2015 public text detection dataset, the reproduced results are as follows:
|Model|Backbone|Config|precision|recall|Hmean|Download|
| --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
+|DB++|ResNet50|[configs/det/det_r50_db++_ic15.yml](../../configs/det/det_r50_db++_ic15.yml)|90.89%|82.66%|86.58%|[pretrained model (synthetic data)](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
+
+On the TD_TR public text detection datasets, the reproduced results are as follows:
+
+|Model|Backbone|Config|precision|recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+|DB++|ResNet50|[configs/det/det_r50_db++_td_tr.yml](../../configs/det/det_r50_db++_td_tr.yml)|92.92%|86.48%|89.58%|[pretrained model (synthetic data)](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_td_tr_train.tar)|
@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai
To run inference with the DB text detection model, execute the following command:
```shell
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" --det_algorithm="DB"
```
The visualized text detection results are saved to the `./inference_results` directory by default, and the result file names are prefixed with 'det_res'. An example result is shown below:
@@ -96,4 +108,12 @@ DB模型还支持以下推理部署方式:
pages={11474--11481},
year={2020}
}
-```
\ No newline at end of file
+
+@article{liao2022real,
+ title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion},
+ author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang},
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
+ year={2022},
+ publisher={IEEE}
+}
+```
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index eb81e4cd6dae2542dd07d0e25fe543419f798c9b..ef96f6ec122594afd115b333ffc18fb836253b79 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -86,8 +86,9 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
-|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
+|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+
diff --git a/doc/doc_ch/dataset/ocr_datasets.md b/doc/doc_ch/dataset/ocr_datasets.md
index c6ff2e170f7c30a29e98ed2b1349cae2b84cf441..b7666fd63e6f17b734a17e2d11a0c8614d225964 100644
--- a/doc/doc_ch/dataset/ocr_datasets.md
+++ b/doc/doc_ch/dataset/ocr_datasets.md
@@ -34,6 +34,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
| ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads| [train](https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt) / [test](https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt) |
| ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| included in the image download link |
| total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| included in the image download link |
+| td tr |https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar| included in the image download link |
#### 1.2.1 ICDAR 2015
The ICDAR 2015 dataset contains 1000 training images and 500 test images. It can be downloaded from the link in the table above; registration is required before the first download.
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 28aca7c0d171008156104fbcc786707538fd49ef..bc96cdf2351f10454441e20d319e485019bbec00 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -85,8 +85,9 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
-|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
+|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 63dfda91f8d0eb200d3c635fda43670039375784..8d8f4e1ffab75473315425c030a1576e19d46e24 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,9 +23,10 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
-from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
+
+from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
+ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 312d6dc9ad25bfa73aa9009f932fe6f3d3ca7644..0723e97ae719690ef2e6a500b327b039c7a46f66 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object):
class KieLabelEncode(object):
- def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
+ def __init__(self,
+ character_dict_path,
+ class_path,
+ norm=10,
+ directed=False,
+ **kwargs):
super(KieLabelEncode, self).__init__()
self.dict = dict({'': 0})
+ self.label2classid_map = dict()
with open(character_dict_path, 'r', encoding='utf-8') as fr:
idx = 1
for line in fr:
char = line.strip()
self.dict[char] = idx
idx += 1
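+ # each line of class_path is a class name; its zero-based line index becomes
+ # the numeric class id assigned to that label during encoding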
+ with open(class_path, "r") as fin:
+ lines = fin.readlines()
+ for idx, line in enumerate(lines):
+ line = line.strip("\n")
+ self.label2classid_map[line] = idx
self.norm = norm
self.directed = directed
@@ -408,7 +419,7 @@ class KieLabelEncode(object):
text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind)
if 'label' in ann.keys():
- labels.append(ann['label'])
+ labels.append(self.label2classid_map[ann['label']])
elif 'key_cls' in ann.keys():
labels.append(ann['key_cls'])
else:
@@ -876,15 +887,16 @@ class VQATokenLabelEncode(object):
for info in ocr_info:
if train_re:
# for re
- if len(info["text"]) == 0:
+ if len(info["transcription"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
+ info["bbox"] = self.trans_poly_to_bbox(info["points"])
bbox = self._smooth_box(info["bbox"], height, width)
- text = info["text"]
+ text = info["transcription"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
@@ -944,29 +956,29 @@ class VQATokenLabelEncode(object):
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
- def _load_ocr_info(self, data):
- def trans_poly_to_bbox(poly):
- x1 = np.min([p[0] for p in poly])
- x2 = np.max([p[0] for p in poly])
- y1 = np.min([p[1] for p in poly])
- y2 = np.max([p[1] for p in poly])
- return [x1, y1, x2, y2]
+ def trans_poly_to_bbox(self, poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+ def _load_ocr_info(self, data):
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
- "text": res[1][0],
- "bbox": trans_poly_to_bbox(res[0]),
- "poly": res[0],
+ "transcription": res[1][0],
+ "bbox": self.trans_poly_to_bbox(res[0]),
+ "points": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
- return info_dict["ocr_info"]
+ return info_dict
def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width)
@@ -977,7 +989,7 @@ class VQATokenLabelEncode(object):
def _parse_label(self, label, encode_res):
gt_label = []
- if label.lower() == "other":
+ if label.lower() in ["other", "others", "ignore"]:
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 5397d71ccb466235e64f85e1eb9365ba03d2aa17..04cc2848fb4d25baaf553c6eda235ddb0e86511f 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -205,9 +205,12 @@ class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
+ self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
+ if 'keep_ratio' in kwargs:
+ self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
@@ -237,6 +240,10 @@ class DetResizeForTest(object):
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
+ if self.keep_ratio is True:
+ resize_w = ori_w * resize_h / ori_h
+ N = math.ceil(resize_w / 32)
+ resize_w = N * 32
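+ # keep_ratio: derive the width from the target height and the input aspect
+ # ratio, then round it up to the next multiple of 32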
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index a5025e7985198e7ee40d6c92d8e1814eb1797032..bde175115536a3f644750260082204fe5f10dc05 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -13,7 +13,12 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+from .augment import DistortBBox
__all__ = [
- 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
+ 'VQATokenPad',
+ 'VQASerTokenChunk',
+ 'VQAReTokenChunk',
+ 'VQAReTokenRelation',
+ 'DistortBBox',
]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcdc9685e9855c3a2d8e9f6f5add270f95f15a6c
--- /dev/null
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -0,0 +1,37 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import random
+
+
+class DistortBBox:
+ def __init__(self, prob=0.5, max_scale=1, **kwargs):
+ """Random distort bbox
+ """
+ self.prob = prob
+ self.max_scale = max_scale
+
+ def __call__(self, data):
+ if random.random() > self.prob:
+ return data
+ bbox = np.array(data['bbox'])
+ rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
+ bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
+ # clip the jittered coordinates to the [0, 1000] range expected by the layout models
+ bbox = np.clip(bbox, 0, 1000)
+ data['bbox'] = bbox.tolist()
+ return data
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index 244893d97d0e422c5ca270bdece689e13aba2b07..f9cd4634731a26dd990d6ffac3d8defc8cdf7e97 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch):
- labels = batch[1]
- attention_mask = batch[4]
+ labels = batch[5]
+ attention_mask = batch[2]
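+ # the batch follows the keep_keys order from the SER configs:
+ # [input_ids, bbox, attention_mask, token_type_ids, image, labels]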
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f8959e263ecffb301dff227ff22e5e913375f919..ab2939c24d92fcebcd9afe574cb02a6b113190fa 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -18,9 +18,10 @@ __all__ = ["build_backbone"]
def build_backbone(config, model_type):
if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
- from .det_resnet_vd import ResNet
+ from .det_resnet import ResNet
+ from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
- support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+ support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
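+ # config "name: ResNet" now selects the plain det_resnet backbone (used by the
+ # DB++ configs), while the vd-style det configs request "ResNet_vd" explicitly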
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
diff --git a/ppocr/modeling/backbones/det_resnet.py b/ppocr/modeling/backbones/det_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..87eef11cf0e33c24c0f539c8074b21f589345282
--- /dev/null
+++ b/ppocr/modeling/backbones/det_resnet.py
@@ -0,0 +1,236 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.nn.initializer import Uniform
+
+import math
+
+from paddle.vision.ops import DeformConv2D
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Normal, Constant, XavierUniform
+from .det_resnet_vd import DeformableConvV2, ConvBNLayer
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True,
+ is_dcn=False):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ act="relu", )
+ self.conv1 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters,
+ kernel_size=3,
+ stride=stride,
+ act="relu",
+ is_dcn=is_dcn,
+ dcn_groups=1, )
+ self.conv2 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters * 4,
+ kernel_size=1,
+ act=None, )
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters * 4,
+ kernel_size=1,
+ stride=stride, )
+
+ self.shortcut = shortcut
+
+ self._num_channels_out = num_filters * 4
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=3,
+ stride=stride,
+ act="relu")
+ self.conv1 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters,
+ kernel_size=3,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ stride=stride)
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
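+ """Plain (non-vd) ResNet backbone, selected by "name: ResNet" in the DB++ configs.
+
+ dcn_stage[i]=True enables deformable convolution in stage i (bottleneck
+ variants, layers >= 50); out_indices selects which stage outputs are returned.
+ """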
+ def __init__(self,
+ in_channels=3,
+ layers=50,
+ out_indices=None,
+ dcn_stage=None):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ self.input_image_channel = in_channels
+
+ supported_layers = [18, 34, 50, 101, 152]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.dcn_stage = dcn_stage if dcn_stage is not None else [
+ False, False, False, False
+ ]
+ self.out_indices = out_indices if out_indices is not None else [
+ 0, 1, 2, 3
+ ]
+
+ self.conv = ConvBNLayer(
+ in_channels=self.input_image_channel,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act="relu", )
+ self.pool2d_max = MaxPool2D(
+ kernel_size=3,
+ stride=2,
+ padding=1, )
+
+ self.stages = []
+ self.out_channels = []
+ if layers >= 50:
+ for block in range(len(depth)):
+ shortcut = False
+ block_list = []
+ is_dcn = self.dcn_stage[block]
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ conv_name,
+ BottleneckBlock(
+ num_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ is_dcn=is_dcn))
+ block_list.append(bottleneck_block)
+ shortcut = True
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ shortcut = False
+ block_list = []
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ conv_name,
+ BasicBlock(
+ num_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut))
+ block_list.append(basic_block)
+ shortcut = True
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ y = self.conv(inputs)
+ y = self.pool2d_max(y)
+ out = []
+ for i, block in enumerate(self.stages):
+ y = block(y)
+ if i in self.out_indices:
+ out.append(y)
+ return out
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index 8c955a4af377374f21e7c09f0d10952f2fe1ceed..a421da0ab440e9b87c1c7efc7d2448f8f76ad205 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
-__all__ = ["ResNet"]
+__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
class DeformableConvV2(nn.Layer):
@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
kernel_size,
stride=1,
groups=1,
+ dcn_groups=1,
is_vd_mode=False,
act=None,
is_dcn=False):
@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
- groups=2, #groups,
+ groups=dcn_groups, #groups,
bias_attr=False)
self._batch_norm = nn.BatchNorm(out_channels, act=act)
@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
kernel_size=3,
stride=stride,
act='relu',
- is_dcn=is_dcn)
+ is_dcn=is_dcn,
+ dcn_groups=2)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
return y
-class ResNet(nn.Layer):
+class ResNet_vd(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
dcn_stage=None,
out_indices=None,
**kwargs):
- super(ResNet, self).__init__()
+ super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
for block in range(len(depth)):
block_list = []
shortcut = False
- # is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index ede5b7a35af65fac351277cefccd89b251f5cdb7..2fd1b1b2a78a98dba1930378f4a06783aadd8834 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -74,9 +74,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
position_ids=None,
output_hidden_states=False)
return x
@@ -96,13 +96,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -120,13 +122,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -140,12 +144,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
@@ -161,12 +165,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py
index 93ed2dbfd1fac9bf2d163c54d23a20e16b537981..8c3f52a331db5daafab2a38c0a441edd44eb141d 100644
--- a/ppocr/modeling/necks/db_fpn.py
+++ b/ppocr/modeling/necks/db_fpn.py
@@ -105,9 +105,10 @@ class DSConv(nn.Layer):
class DBFPN(nn.Layer):
- def __init__(self, in_channels, out_channels, **kwargs):
+ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__()
self.out_channels = out_channels
+ self.use_asf = use_asf
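+ # use_asf=True appends the Adaptive Scale Fusion block of DBNet++ (ASFBlock,
+ # defined below) after the four FPN levels are concatenated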
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
+ if self.use_asf is True:
+ self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
+
def forward(self, x):
c2, c3, c4, c5 = x
@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+
+ if self.use_asf is True:
+ fuse = self.asf(fuse, [p5, p4, p3, p2])
+
return fuse
@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
+
+
+class ASFBlock(nn.Layer):
+ """
+ This code is adapted from:
+ https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
+ """
+
+ def __init__(self, in_channels, inter_channels, out_features_num=4):
+ """
+ Adaptive Scale Fusion (ASF) block of DBNet++
+ Args:
+ in_channels: the number of channels in the input data
+ inter_channels: the number of middle channels
+ out_features_num: the number of fused stages
+ """
+ super(ASFBlock, self).__init__()
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in_channels = in_channels
+ self.inter_channels = inter_channels
+ self.out_features_num = out_features_num
+ self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
+
+ self.spatial_scale = nn.Sequential(
+ #Nx1xHxW
+ nn.Conv2D(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ bias_attr=False,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.ReLU(),
+ nn.Conv2D(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=1,
+ bias_attr=False,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.Sigmoid())
+
+ self.channel_scale = nn.Sequential(
+ nn.Conv2D(
+ in_channels=inter_channels,
+ out_channels=out_features_num,
+ kernel_size=1,
+ bias_attr=False,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.Sigmoid())
+
+ def forward(self, fuse_features, features_list):
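+ # fuse_features: concatenation of the four FPN maps; features_list: the
+ # per-stage maps [p5, p4, p3, p2] that are re-weighted by the learned scores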
+ fuse_features = self.conv(fuse_features)
+ spatial_x = paddle.mean(fuse_features, axis=1, keepdim=True)
+ attention_scores = self.spatial_scale(spatial_x) + fuse_features
+ attention_scores = self.channel_scale(attention_scores)
+ assert len(features_list) == self.out_features_num
+
+ out_list = []
+ for i in range(self.out_features_num):
+ out_list.append(attention_scores[:, i:i + 1] * features_list[i])
+ return paddle.concat(out_list, axis=1)
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index fe251f36e736bb1eac8a71a8115c941cbd7443e6..8e05a84011e3c463dd15a3b6bd76f24fa3ab81ef 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -308,3 +308,38 @@ class Const(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class DecayLearningRate(object):
+ """
+ Polynomial decay of the learning rate:
+ new_lr = (lr - end_lr) * (1 - step/decay_steps)**power + end_lr
+ Args:
+ learning_rate(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ epochs(int): total training epochs
+ factor(float): power of the polynomial; should be greater than 0.0 for the learning rate to decay. Default: 0.9
+ end_lr(float): The minimum final learning rate. Default: 0.0.
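+
+ Example (hypothetical values): with learning_rate=0.007, factor=0.9,
+ end_lr=0 and decay_steps=1000, the lr at step 500 is
+ 0.007 * (1 - 500/1000) ** 0.9 ≈ 0.00375.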
+ """
+
+ def __init__(self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ factor=0.9,
+ end_lr=0,
+ **kwargs):
+ super(DecayLearningRate, self).__init__()
+ self.learning_rate = learning_rate
+ self.epochs = epochs + 1
+ self.factor = factor
+ self.end_lr = end_lr
+ self.decay_steps = step_each_epoch * epochs
+
+ def __call__(self):
+ learning_rate = lr.PolynomialDecay(
+ learning_rate=self.learning_rate,
+ decay_steps=self.decay_steps,
+ power=self.factor,
+ end_lr=self.end_lr)
+ return learning_rate
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 782cdea6c58c69e0d728787e0e21e200c9e13790..8a6669f71f5ae6a7a16931e565b43355de5928d9 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, tuple):
+ preds = preds[0]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
- return self._metric(preds, batch[1])
+ return self._metric(preds, batch[5])
else:
return self._infer(preds, **kwargs)
@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]])
return decode_out_list, label_decode_out_list
- def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+ def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
- for pred, attention_mask, segment_offset_id, ocr_info in zip(
- preds, attention_masks, segment_offset_ids, ocr_infos):
+ for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+ ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index 4a25ff8b2fa182faaf4f4ce8909c9ec2e9b55ccc..b881fcab20bc5ca076a0002bd72349768c7d881a 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
- lines = [line.strip() for line in lines]
- if "O" not in lines:
- lines.insert(0, "O")
- labels = []
- for line in lines:
- if line == "O":
- labels.append("O")
- else:
- labels.append("B-" + line)
- labels.append("I-" + line)
- label2id_map = {label: idx for idx, label in enumerate(labels)}
- id2label_map = {idx: label for idx, label in enumerate(labels)}
+ old_lines = [line.strip() for line in lines]
+ lines = ["O"]
+ for line in old_lines:
+ # "O" has already been in lines
+ if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
+ continue
+ lines.append(line)
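+ # every remaining class contributes a "B-" and an "I-" tag below, while the
+ # OTHER/OTHERS/IGNORE classes all collapse into the single "O" tag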
+ labels = ["O"]
+ for line in lines[1:]:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
return label2id_map, id2label_map
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
index 7a8c1674a74f89299de59f7cd120b4577a7499d8..235eb572a3975b4446ae2f2c9ad9c8558d5c5ad8 100644
--- a/ppocr/utils/visual.py
+++ b/ppocr/utils/visual.py
@@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
- font_size=18):
+ font_size=14):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
@@ -40,9 +40,15 @@ def draw_ser_results(image,
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
- text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
- draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
+ if "bbox" in ocr_info:
+ # draw with ocr engine
+ bbox = ocr_info["bbox"]
+ else:
+ # draw with ocr groundtruth
+ bbox = trans_poly_to_bbox(ocr_info["points"])
+ draw_box_txt(bbox, text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
@@ -62,6 +68,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
@@ -80,10 +94,10 @@ def draw_re_results(image,
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
- draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
- font_size, color_head)
- draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
- font_size, color_tail)
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
+ draw, font, font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
+ draw, font, font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
diff --git a/ppstructure/docs/kie.md b/ppstructure/docs/kie.md
index 35498b33478d1010fd2548dfcb8586b4710723a1..315dd9f7bafa6b6160489eab330e8d278b2d119d 100644
--- a/ppstructure/docs/kie.md
+++ b/ppstructure/docs/kie.md
@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类
Training and evaluation use the wildreceipt dataset, which can be downloaded with the following command:
```
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Run prediction:
diff --git a/ppstructure/docs/kie_en.md b/ppstructure/docs/kie_en.md
index 1fe38b0b399e9290526dafa5409673dc87026db7..7b3752223dd765e780d56d146c90bd0f892aac7b 100644
--- a/ppstructure/docs/kie_en.md
+++ b/ppstructure/docs/kie_en.md
@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
```shell
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Download the pretrained model and predict the result:
diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md
index c7dab999ff6e370c56c5495e22e91f117b3d1275..dabce3a5149a88833d38a4395e31ac1f82306c4f 100644
--- a/ppstructure/docs/models_list.md
+++ b/ppstructure/docs/models_list.md
@@ -1,11 +1,11 @@
# PP-Structure 系列模型列表
-- [1. 版面分析模型](#1)
-- [2. OCR和表格识别模型](#2)
- - [2.1 OCR](#21)
- - [2.2 表格识别模型](#22)
-- [3. VQA模型](#3)
-- [4. KIE模型](#4)
+- [1. 版面分析模型](#1-版面分析模型)
+- [2. OCR和表格识别模型](#2-ocr和表格识别模型)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 表格识别模型](#22-表格识别模型)
+- [3. VQA模型](#3-vqa模型)
+- [4. KIE模型](#4-kie模型)
@@ -42,11 +42,11 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE模型
diff --git a/ppstructure/docs/models_list_en.md b/ppstructure/docs/models_list_en.md
index b92c10c241df72c85649b64f915b4266cd3fe410..e133a0bb2a9b017207b5e92ea444aba4633a7457 100644
--- a/ppstructure/docs/models_list_en.md
+++ b/ppstructure/docs/models_list_en.md
@@ -1,11 +1,11 @@
# PP-Structure Model list
-- [1. Layout Analysis](#1)
-- [2. OCR and Table Recognition](#2)
- - [2.1 OCR](#21)
- - [2.2 Table Recognition](#22)
-- [3. VQA](#3)
-- [4. KIE](#4)
+- [1. Layout Analysis](#1-layout-analysis)
+- [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 Table Recognition](#22-table-recognition)
+- [3. VQA](#3-vqa)
+- [4. KIE](#4-kie)
@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- |
-|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh| RE model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 1ad902e7e6be95a6901e3774420fad337f594861..4ae56099b83a46c85ce2dc362c1c6417b324dbe1 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -40,6 +40,13 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
+ # params for vqa
+ parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
+ parser.add_argument("--ser_model_dir", type=str)
+ parser.add_argument(
+ "--ser_dict_path",
+ type=str,
+ default="../train_data/XFUND/class_list_xfun.txt")
# params for inference
parser.add_argument(
"--mode",
@@ -65,7 +72,7 @@ def init_args():
"--recovery",
type=bool,
default=False,
- help='Whether to enable layout of recovery')
+ help='Whether to enable layout recovery')
return parser
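A note on the new VQA flags added to `init_args` above: they are consumed by the SER inference script introduced later in this patch, with `--ser_model_dir` selected when `create_predictor` is called with `mode='ser'`. A hedged, standalone sketch of how the three arguments parse (this mirrors, but is not, the real `init_args`):

```python
import argparse

# Stand-alone mirror of the three VQA arguments added above, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--vqa_algorithm", type=str, default="LayoutXLM")
parser.add_argument("--ser_model_dir", type=str)
parser.add_argument("--ser_dict_path", type=str,
                    default="../train_data/XFUND/class_list_xfun.txt")

args = parser.parse_args(["--ser_model_dir", "../output/ser/infer"])
print(args.vqa_algorithm, args.ser_model_dir, args.ser_dict_path)
```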
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index e3a10671ddb6494eb15073e7ac007aa1e8e6a32a..05635265b5e5eff18429e2d595fc4195381299f5 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -1,19 +1,15 @@
English | [简体中文](README_ch.md)
-- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
- - [1. Introduction](#1-Introduction)
- - [2. Performance](#2-performance)
- - [3. Effect demo](#3-Effect-demo)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. Install](#4-Install)
- - [4.1 Installation dependencies](#41-Install-dependencies)
- - [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
- - [5. Usage](#5-Usage)
- - [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. Reference](#6-Reference-Links)
+- [1. Introduction](#1-introduction)
+- [2. Performance](#2-performance)
+- [3. Effect demo](#3-effect-demo)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. Install](#4-install)
+ - [4.1 Install dependencies](#41-install-dependencies)
+ - [4.2 Install PaddleOCR](#42-install-paddleocr)
+- [5. Usage](#5-usage)
+ - [5.1 Data and Model Preparation](#51-data-and-model-preparation)
+ - [5.2 SER](#52-ser)
+ - [5.3 RE](#53-re)
+- [6. Reference Links](#6-reference-links)
+- [License](#license)
# Document Visual Question Answering
@@ -125,13 +121,13 @@ If you want to experience the prediction process directly, you can download the
* Download the processed dataset
-The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
+The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell
-wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
````
* Convert the dataset
@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
-* Use `OCR engine + SER` tandem prediction
+* `OCR + SER` tandem prediction based on the training engine
-Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
+Use the following command to complete the tandem prediction of `OCR + SER` based on the training engine, taking the SER model based on LayoutXLM as an example:
```shell
-CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg
+CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
-* End-to-end evaluation of `OCR engine + SER` prediction system
+* End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
````
+* Export model
+
+Use the following command to export the SER model, taking the SER model based on LayoutXLM as an example:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
+
+* `OCR + SER` tandem prediction based on the prediction engine
+
+Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field.
### 5.3 RE
@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
+* Export model
+
+coming soon
+
+* `OCR + SER + RE` tandem prediction based on the prediction engine
+
+coming soon
+
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md
index b677dc07bce6c1a752d753b6a1c538b4d3f99271..b421a82d3a1cbe39f5c740bea486ec26593ab20f 100644
--- a/ppstructure/vqa/README_ch.md
+++ b/ppstructure/vqa/README_ch.md
@@ -1,19 +1,19 @@
[English](README.md) | 简体中文
-- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa)
- - [1. 简介](#1-简介)
- - [2. 性能](#2-性能)
- - [3. 效果演示](#3-效果演示)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. 安装](#4-安装)
- - [4.1 安装依赖](#41-安装依赖)
- - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- - [5. 使用](#5-使用)
- - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. 参考链接](#6-参考链接)
+- [1. 简介](#1-简介)
+- [2. 性能](#2-性能)
+- [3. 效果演示](#3-效果演示)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. 安装](#4-安装)
+ - [4.1 安装依赖](#41-安装依赖)
+ - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
+- [5. 使用](#5-使用)
+ - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
+ - [5.2 SER](#52-ser)
+ - [5.3 RE](#53-re)
+- [6. 参考链接](#6-参考链接)
+- [License](#license)
# 文档视觉问答(DOC-VQA)
@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
* 下载处理好的数据集
-处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
+处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)。
下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell
-wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
```
* 转换数据集
@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER`串联预测
+* 基于训练引擎的`OCR + SER`串联预测
-使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
-* 对`OCR引擎 + SER`预测系统进行端到端评估
+* 对`OCR + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
+* 模型导出
+
+使用如下命令即可完成SER模型的模型导出, 以基于LayoutXLM的SER模型为例:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+转换后的模型会存放在`Global.save_inference_dir`字段指定的目录下。
+
+* 基于预测引擎的`OCR + SER`串联预测
+
+使用如下命令即可完成基于预测引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+预测成功后,可视化图片和结果会保存在`output`字段指定的目录下。
### 5.3 RE
@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER + RE`串联预测
+* 基于训练引擎的`OCR + SER + RE`串联预测
-使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER + RE`串联预测, 以基于LayoutXLM的SER和RE模型为例:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
+* 模型导出
+
+coming soon
+
+* 基于预测引擎的`OCR + SER + RE`串联预测
+
+coming soon
+
## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt
deleted file mode 100644
index 508e48112412f62538baf0c78bcf99ec8945196e..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/labels/labels_ser.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-QUESTION
-ANSWER
-HEADER
diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0bbfe72d80d9a16de8b09657a98dc5285bb348
--- /dev/null
+++ b/ppstructure/vqa/predict_vqa_token_ser.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_ser_results
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+
+from paddleocr import PaddleOCR
+
+logger = get_logger()
+
+
+class SerPredictor(object):
+ def __init__(self, args):
+ self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+
+ pre_process_list = [{
+ 'VQATokenLabelEncode': {
+ 'algorithm': args.vqa_algorithm,
+ 'class_path': args.ser_dict_path,
+ 'contains_re': False,
+ 'ocr_engine': self.ocr_engine
+ }
+ }, {
+ 'VQATokenPad': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'VQASerTokenChunk': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'Resize': {
+ 'size': [224, 224]
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [58.395, 57.12, 57.375],
+ 'mean': [123.675, 116.28, 103.53],
+ 'scale': '1',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': [
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
+ 'entities'
+ ]
+ }
+ }]
+ postprocess_params = {
+ 'name': 'VQASerTokenLayoutLMPostProcess',
+ "class_path": args.ser_dict_path,
+ }
+
+ self.preprocess_op = create_operators(pre_process_list,
+ {'infer_mode': True})
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'ser', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ for idx in range(len(self.input_tensor)):
+ expand_input = np.expand_dims(data[idx], axis=0)
+ self.input_tensor[idx].copy_from_cpu(expand_input)
+
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
+ post_result = self.postprocess_op(
+ preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
+ elapse = time.time() - starttime
+ return post_result, elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ ser_predictor = SerPredictor(args)
+ count = 0
+ total_time = 0
+
+ os.makedirs(args.output, exist_ok=True)
+ with open(
+ os.path.join(args.output, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ img = img[:, :, ::-1]
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ ser_res, elapse = ser_predictor(img)
+ ser_res = ser_res[0]
+
+ res_str = '{}\t{}\n'.format(
+ image_file,
+ json.dumps(
+ {
+ "ocr_info": ser_res,
+ }, ensure_ascii=False))
+ f_w.write(res_str)
+
+ img_res = draw_ser_results(
+ image_file,
+ ser_res,
+ font_path="../doc/fonts/simfang.ttf", )
+
+ img_save_path = os.path.join(args.output,
+ os.path.basename(image_file))
+ cv2.imwrite(img_save_path, img_res)
+ logger.info("save vis result to {}".format(img_save_path))
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
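For readers who want to drive the new predictor from Python rather than the command line, a hedged usage sketch that follows the same steps as `main()` above (the module import path and file locations are assumptions; in the repo the script is normally run directly from `ppstructure/` as shown in the README):

```python
import os
import cv2

from ppocr.utils.visual import draw_ser_results
from ppstructure.utility import parse_args
# Assumed import path for the new script; adjust to how the package is laid out locally.
from ppstructure.vqa.predict_vqa_token_ser import SerPredictor

# Expects the SER flags on the command line, e.g.
#   --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt
args = parse_args()
ser_predictor = SerPredictor(args)

image_file = "docs/vqa/input/zh_val_42.jpg"
img = cv2.imread(image_file)[:, :, ::-1]          # BGR -> RGB, as in main() above
ser_res, elapse = ser_predictor(img)

os.makedirs(args.output, exist_ok=True)
vis = draw_ser_results(image_file, ser_res[0], font_path="../doc/fonts/simfang.ttf")
cv2.imwrite(os.path.join(args.output, os.path.basename(image_file)), vis)
print("predict time: {:.3f}s".format(elapse))
```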
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
index 0042ec0baedcc3e7bbecb922d10b93c95219219d..fcd882274c4402ba2a1d34f20ee6e2befa157121 100644
--- a/ppstructure/vqa/requirements.txt
+++ b/ppstructure/vqa/requirements.txt
@@ -1,4 +1,7 @@
sentencepiece
yacs
seqeval
-paddlenlp>=2.2.1
\ No newline at end of file
+paddlenlp>=2.2.1
+pypandoc
+attrdict
+python_docx
\ No newline at end of file
diff --git a/ppstructure/vqa/tools/trans_xfun_data.py b/ppstructure/vqa/tools/trans_xfun_data.py
index 93ec98163c6cec96ec93399c1d41524200ddc499..11d221bea40367f091b3e09dde42e87f2217a617 100644
--- a/ppstructure/vqa/tools/trans_xfun_data.py
+++ b/ppstructure/vqa/tools/trans_xfun_data.py
@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
json_info = json.loads(lines[0])
documents = json_info["documents"]
- label_info = {}
with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents):
+ label_info = []
img_info = document["img"]
document = document["document"]
image_path = img_info["fname"]
- label_info["height"] = img_info["height"]
- label_info["width"] = img_info["width"]
-
- label_info["ocr_info"] = []
-
for doc in document:
- label_info["ocr_info"].append({
- "text": doc["text"],
+ x1, y1, x2, y2 = doc["box"]
+ points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+ label_info.append({
+ "transcription": doc["text"],
"label": doc["label"],
- "bbox": doc["box"],
+ "points": points,
"id": doc["id"],
- "linking": doc["linking"],
- "words": doc["words"]
+ "linking": doc["linking"]
})
fout.write(image_path + "\t" + json.dumps(
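To make the new annotation format concrete, a hedged example of one output line as `transfer_xfun_data` now writes it: the four-point `points` quad is rebuilt from the original XFUND `box`, and `transcription` replaces `text` (field values here are invented):

```python
import json

# One converted record; the original XFUND entry provides box=[x1, y1, x2, y2].
x1, y1, x2, y2 = 38, 14, 172, 42
label_info = [{
    "transcription": "姓名",                                   # sample text, illustrative only
    "label": "question",
    "points": [[x1, y1], [x2, y1], [x2, y2], [x1, y2]],
    "id": 1,
    "linking": [[1, 2]],
}]
line = "zh_train_0.jpg" + "\t" + json.dumps(label_info, ensure_ascii=False) + "\n"
print(line)
```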
diff --git a/test_tipc/configs/det_r50_db++/train_infer_python.txt b/test_tipc/configs/det_r50_db++/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bcf393a52b0e073267aa7423960179d8b5eba4bd
--- /dev/null
+++ b/test_tipc/configs/det_r50_db++/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:det_r50_db++
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/det_r50_db++_ic15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:null
+train_model:./inference/det_r50_db++_train/best_accuracy
+infer_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py --det_algorithm="DB++"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+null:null
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8|16
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
index 8477a4fa74f7a0617104aa83617fc6f61b8234b3..24e4d760c37828c213741b9ff127d55df2f9335a 100644
--- a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
@@ -1,13 +1,13 @@
===========================train_params===========================
model_name:det_r50_vd_east_v2_0
python:python3.7
-gpu_list:0
+gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
-Global.pretrained_model:null
+Global.pretrained_model:./pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 2c9bd2901b52ff7d4b6af483a9aa201aef339099..8cb1a2133a3565fc7e3ad36d4274195cf8790deb 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -59,6 +59,9 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
+ if [[ ${model_name} =~ "det_r50_db++" ]];then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
+ fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
rm -rf ./train_data/ic15_data
@@ -120,6 +123,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
+ if [ ${model_name} == "det_r50_vd_east_v2_0" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh
index c1aa3daa6c0c647d7a9f9980e7698a85ee09ca78..288e6098966be4aaf2953d627e7890963100cb6e 100644
--- a/test_tipc/test_ptq_inference_python.sh
+++ b/test_tipc/test_ptq_inference_python.sh
@@ -139,7 +139,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}_klquant"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
- export_log_path="${LOG_PATH}/_export_${Count}.log"
+ export_log_path="${LOG_PATH}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
diff --git a/test_tipc/test_serving_infer_cpp.sh b/test_tipc/test_serving_infer_cpp.sh
index 6e313cd9493b84eef731a4e0d86d721a212b3ead..0be6a45adf3105f088a96336dddfbe9ac612f19b 100644
--- a/test_tipc/test_serving_infer_cpp.sh
+++ b/test_tipc/test_serving_infer_cpp.sh
@@ -87,8 +87,7 @@ function func_serving(){
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
- unset https_proxy
- unset http_proxy
+
# cpp serving
for gpu_id in ${gpu_value[*]}; do
if [ ${gpu_id} = "null" ]; then
diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh
index b14517265db2f19d0e4c8b9bb8dc325494723527..4ccccc06e23ce086e7dac1f3446aae9130605444 100644
--- a/test_tipc/test_serving_infer_python.sh
+++ b/test_tipc/test_serving_infer_python.sh
@@ -112,8 +112,7 @@ function func_serving(){
cd ${serving_dir_value}
python=${python_list[0]}
- unset https_proxy
- unset http_proxy
+
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index fa68cb2632ee69fe361f99093e7a2352006ed283..907efcec9008f89740971bb6d4253bafb44938c4 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -193,7 +193,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
- export_log_path="${LOG_PATH}/_export_${Count}.log"
+ export_log_path="${LOG_PATH}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
@@ -295,6 +295,7 @@ else
fi
# run train
eval $cmd
+ eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
diff --git a/tools/export_model.py b/tools/export_model.py
index b10d41d5b288258ad895cefa7d8cc243eff10546..65573cf46a9d650b8f833fdec43235de57faf5ac 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -97,6 +97,22 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
+ input_spec = [
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # input_ids
+ paddle.static.InputSpec(
+ shape=[None, 512, 4], dtype="int64"), # bbox
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # attention_mask
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # token_type_ids
+ paddle.static.InputSpec(
+ shape=[None, 3, 224, 224], dtype="int64"), # image
+ ]
+ if arch_config["algorithm"] == "LayoutLM":
+ input_spec.pop(4)
+ model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
@@ -172,7 +188,7 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
- load_model(config, model)
+ load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
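As a quick illustration of the export signature added for the LayoutLM family above, a hedged sketch of dummy inputs matching those `InputSpec` shapes and dtypes (batch size 1, sequence length 512; the `image` input is dropped for plain LayoutLM):

```python
import numpy as np

# Dummy feed matching the exported LayoutXLM/LayoutLMv2 input signature above.
seq_len = 512
feed = {
    "input_ids": np.zeros([1, seq_len], dtype="int64"),
    "bbox": np.zeros([1, seq_len, 4], dtype="int64"),
    "attention_mask": np.ones([1, seq_len], dtype="int64"),
    "token_type_ids": np.zeros([1, seq_len], dtype="int64"),
    "image": np.zeros([1, 3, 224, 224], dtype="int64"),  # omitted for LayoutLM
}
for name, arr in feed.items():
    print(name, arr.shape, arr.dtype)
```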
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 7b6bebf1fbced2de5bb0e4e75840fb8dd7beb374..394a48948b1f284bd405532769b76eeb298668bd 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
+ elif self.det_algorithm == "DB++":
+ postprocess_params['name'] = 'DBPostProcess'
+ postprocess_params["thresh"] = args.det_db_thresh
+ postprocess_params["box_thresh"] = args.det_db_box_thresh
+ postprocess_params["max_candidates"] = 1000
+ postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+ postprocess_params["use_dilation"] = args.use_dilation
+ postprocess_params["score_mode"] = args.det_db_score_mode
+ pre_process_list[1] = {
+ 'NormalizeImage': {
+ 'std': [1.0, 1.0, 1.0],
+ 'mean':
+ [0.48109378172549, 0.45752457890196, 0.40787054090196],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
@@ -231,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
- elif self.det_algorithm in ['DB', 'PSE']:
+ elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
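The `DB++` branch above also swaps in its own `NormalizeImage` constants; a minimal numpy sketch of the equivalent step (scale to [0, 1] first, then subtract the per-channel mean and divide by std, in HWC order), for reference only:

```python
import numpy as np

# DB++ normalization constants from the config above.
mean = np.array([0.48109378172549, 0.45752457890196, 0.40787054090196], dtype="float32")
std = np.array([1.0, 1.0, 1.0], dtype="float32")

img = np.random.randint(0, 256, size=(960, 960, 3), dtype=np.uint8)  # placeholder HWC image
normed = (img.astype("float32") / 255.0 - mean) / std
print(normed.shape, normed.dtype)
```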
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 366212f228eec33f11c825bfaf1e360258af9b2e..7eb77dec74bf283936e1143edcb5b5dfc28365bd 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir
elif mode == 'table':
model_dir = args.table_model_dir
+ elif mode == 'ser':
+ model_dir = args.ser_model_dir
else:
model_dir = args.e2e_model_dir
@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
- for name in input_names:
- input_tensor = predictor.get_input_handle(name)
+ if mode in ['ser', 're']:
+ input_tensor = []
+ for name in input_names:
+ input_tensor.append(predictor.get_input_handle(name))
+ else:
+ for name in input_names:
+ input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
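Because `create_predictor` now returns a list of input handles for `ser` (and `re`), callers feed one array per handle, in the order produced by the SER `KeepKeys` transform. A self-contained sketch of that feeding pattern with a fake handle standing in for the Paddle Inference tensor (shapes follow the SER inputs; everything here is illustrative):

```python
import numpy as np

class FakeInputHandle:
    """Stand-in for a Paddle Inference input handle, for illustration only."""
    def copy_from_cpu(self, arr):
        print("fed tensor with shape", arr.shape)

# Per-sample arrays in KeepKeys order: input_ids, bbox, attention_mask, token_type_ids, ...
data = [
    np.zeros([512], dtype="int64"),
    np.zeros([512, 4], dtype="int64"),
    np.ones([512], dtype="int64"),
    np.zeros([512], dtype="int64"),
]
input_tensor = [FakeInputHandle() for _ in data]
for idx in range(len(input_tensor)):
    input_tensor[idx].copy_from_cpu(np.expand_dims(data[idx], axis=0))  # add batch dim
```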
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
index 0cb0b8702cbd7ea74a7b7fcff69122731578a1bd..346e2e0aeeee695ab49577b6b13dcc058150df1a 100755
--- a/tools/infer_kie.py
+++ b/tools/infer_kie.py
@@ -39,13 +39,12 @@ import time
def read_class_list(filepath):
- dict = {}
+ ret = {}
with open(filepath, "r") as f:
lines = f.readlines()
- for line in lines:
- key, value = line.split(" ")
- dict[key] = value.rstrip()
- return dict
+ for idx, line in enumerate(lines):
+ ret[idx] = line.strip("\n")
+ return ret
def draw_kie_result(batch, node, idx_to_cls, count):
@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
- pred_label = str(node_pred_label[i])
+ pred_label = node_pred_label[i]
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i])
@@ -109,8 +108,7 @@ def main():
save_res_path = config['Global']['save_res_path']
class_path = config['Global']['class_path']
idx_to_cls = read_class_list(class_path)
- if not os.path.exists(os.path.dirname(save_res_path)):
- os.makedirs(os.path.dirname(save_res_path))
+ os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
model.eval()
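The rewritten `read_class_list` above keys the class map by line index instead of parsing `key value` pairs, so the integer node predictions index it directly. A hedged, runnable example (class names are placeholders, not the real wildreceipt list):

```python
import os
import tempfile

def read_class_list(filepath):
    # New behaviour: the class id is the zero-based line index.
    ret = {}
    with open(filepath, "r") as f:
        for idx, line in enumerate(f.readlines()):
            ret[idx] = line.strip("\n")
    return ret

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("Ignore\nStore_name_value\nStore_name_key\n")
print(read_class_list(f.name))  # {0: 'Ignore', 1: 'Store_name_value', 2: 'Store_name_key'}
os.remove(f.name)
```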
diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py
index 83ed72b392e627c161903c3945f57be0abfabc2b..0173a554cace31e20ab47dbe36d132a4dbb2127b 100755
--- a/tools/infer_vqa_token_ser.py
+++ b/tools/infer_vqa_token_ser.py
@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
+
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object):
def __init__(self, config):
global_config = config['Global']
+ self.algorithm = config['Architecture']["algorithm"]
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
- self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+ self.ocr_engine = PaddleOCR(
+ use_angle_cls=False,
+ show_log=False,
+ use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
@@ -80,29 +85,30 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
- 'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
- 'token_type_ids', 'segment_offset_id', 'ocr_info',
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
transforms.append(op)
- global_config['infer_mode'] = True
+ if config["Global"].get("infer_mode", None) is None:
+ global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config)
self.model.eval()
- def __call__(self, img_path):
- with open(img_path, 'rb') as f:
+ def __call__(self, data):
+ with open(data["img_path"], 'rb') as f:
img = f.read()
- data = {'image': img}
+ data["image"] = img
batch = transform(data, self.ops)
batch = to_tensor(batch)
preds = self.model(batch)
+ if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
+ preds = preds[0]
+
post_result = self.post_process_class(
- preds,
- attention_masks=batch[4],
- segment_offset_ids=batch[6],
- ocr_infos=batch[7])
+ preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
@@ -112,20 +118,33 @@ if __name__ == '__main__':
ser_engine = SerPredictor(config)
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ if config["Global"].get("infer_mode", None) is False:
+ data_dir = config['Eval']['dataset']['data_dir']
+ with open(config['Global']['infer_img'], "rb") as f:
+ infer_imgs = f.readlines()
+ else:
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
+ for idx, info in enumerate(infer_imgs):
+ if config["Global"].get("infer_mode", None) is False:
+ data_line = info.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path = os.path.join(data_dir, substr[0])
+ data = {'img_path': img_path, 'label': substr[1]}
+ else:
+ img_path = info
+ data = {'img_path': img_path}
+
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
- result, _ = ser_engine(img_path)
+ result, _ = ser_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
@@ -133,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
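With the change above, setting `Global.infer_mode` to `False` makes `Global.infer_img` point at a ground-truth label file instead of an image; each line is split into an image path (joined with `Eval.dataset.data_dir`) and its JSON annotation. A hedged sketch of that parsing step (paths and the sample annotation are illustrative):

```python
import json
import os

data_dir = "train_data/XFUND/zh_val/image"   # example Eval.dataset.data_dir value
# One ground-truth line: "<image name>\t<json annotation>" as produced by trans_xfun_data.py.
data_line = ('zh_val_42.jpg\t[{"transcription": "姓名", "label": "question", '
             '"points": [[38, 14], [172, 14], [172, 42], [38, 42]], "id": 1, "linking": []}]')

substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
data = {"img_path": img_path, "label": substr[1]}
print(data["img_path"])
print(json.loads(data["label"])[0]["label"])  # question
```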
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 6210f7f3c24227c9d366b08ce93ccfe4df849ce1..20ab1fe176c3be75f7a7b01a8d77df6419c58c75 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
-from tools.program import ArgsParser, load_config, merge_config, check_gpu
+from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor
@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
- ser_inputs.pop(1)
+ ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch
@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval()
def __call__(self, img_path):
- ser_results, ser_inputs = self.ser_engine(img_path)
- paddle.save(ser_inputs, 'ser_inputs.npy')
- paddle.save(ser_results, 'ser_results.npy')
+ ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
- check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
- os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
result = ser_re_engine(img_path)
result = result[0]
@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
diff --git a/tools/program.py b/tools/program.py
index aa3ba82c44d6afba725a8059dc7f8cae41947b3d..7bd54ba083b912bd489efb1a763a8169685f2d9a 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -576,8 +576,8 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
- 'ViTSTR', 'ABINet'
+ 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++'
]
if use_xpu: