Commit d10d8432, authored by andyjpaddle

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

@@ -28,7 +28,7 @@ Architecture:
   algorithm: DB
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 18
   Neck:
     name: DBFPN
...
@@ -45,7 +45,7 @@ Architecture:
   algorithm: DB
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 18
   Neck:
     name: DBFPN
...
@@ -61,7 +61,7 @@ Architecture:
   model_type: det
   algorithm: DB
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     in_channels: 3
     layers: 50
   Neck:
...
@@ -25,7 +25,7 @@ Architecture:
   model_type: det
   algorithm: DB
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     in_channels: 3
     layers: 50
   Neck:
@@ -40,7 +40,7 @@ Architecture:
   model_type: det
   algorithm: DB
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     in_channels: 3
     layers: 50
   Neck:
...
@@ -20,7 +20,7 @@ Architecture:
   algorithm: DB
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 18
     disable_se: True
   Neck:
...
Global:
  debug: false
  use_gpu: true
  epoch_num: 1000
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_r50_icdar15/
  save_epoch_step: 200
  eval_batch_step:
  - 0
  - 2000
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
  model_type: det
  algorithm: DB++
  Transform: null
  Backbone:
    name: ResNet
    layers: 50
    dcn_stage: [False, True, True, True]
  Neck:
    name: DBFPN
    out_channels: 256
    use_asf: True
  Head:
    name: DBHead
    k: 50
Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: BCELoss
  alpha: 5
  beta: 10
  ohem_ratio: 3
Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: DecayLearningRate
    learning_rate: 0.007
    epochs: 1000
    factor: 0.9
    end_lr: 0
  weight_decay: 0.0001
PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DetMetric
  main_indicator: hmean
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list:
    - 1.0
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 10
        keep_ratio: true
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 4
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization
    label_file_list:
    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
        image_shape:
        - 1152
        - 2048
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2
profiler_options: null
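A note on the `PostProcess` block above: `thresh` binarizes the probability map, `box_thresh` filters low-score boxes, and `unclip_ratio` controls how far each shrunken text region is expanded back out. As a rough sketch of that expansion step (this mirrors the standard DB unclip formula; the helper name and input polygon are illustrative, not PaddleOCR's exact code):

```python
# Sketch of how unclip_ratio expands a detected box in DB-style post-processing.
# Assumes the standard DB formulation: offset distance = area * ratio / perimeter.
import numpy as np
import pyclipper
from shapely.geometry import Polygon


def unclip(box_contour, unclip_ratio=1.5):
    """Expand a polygon outward by distance = area * unclip_ratio / perimeter."""
    poly = Polygon(box_contour)
    distance = poly.area * unclip_ratio / poly.length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box_contour, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return np.array(offset.Execute(distance)[0])


# With unclip_ratio=1.5, a 100x20 rectangle (area 2000, perimeter 240) is
# offset outward by 2000 * 1.5 / 240 = 12.5 px on every side.
expanded = unclip([(0, 0), (100, 0), (100, 20), (0, 20)], unclip_ratio=1.5)
```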
Global:
  debug: false
  use_gpu: true
  epoch_num: 1000
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_r50_td_tr/
  save_epoch_step: 200
  eval_batch_step:
  - 0
  - 2000
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
  model_type: det
  algorithm: DB++
  Transform: null
  Backbone:
    name: ResNet
    layers: 50
    dcn_stage: [False, True, True, True]
  Neck:
    name: DBFPN
    out_channels: 256
    use_asf: True
  Head:
    name: DBHead
    k: 50
Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: BCELoss
  alpha: 5
  beta: 10
  ohem_ratio: 3
Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: DecayLearningRate
    learning_rate: 0.007
    epochs: 1000
    factor: 0.9
    end_lr: 0
  weight_decay: 0.0001
PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.5
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DetMetric
  main_indicator: hmean
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
    - ./train_data/TD_TR/TD500/train_gt_labels.txt
    - ./train_data/TD_TR/TR400/gt_labels.txt
    ratio_list:
    - 1.0
    - 1.0
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 10
        keep_ratio: true
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 4
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
    - ./train_data/TD_TR/TD500/test_gt_labels.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
        image_shape:
        - 736
        - 736
        keep_ratio: True
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.48109378172549
        - 0.45752457890196
        - 0.40787054090196
        std:
        - 1.0
        - 1.0
        - 1.0
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2
profiler_options: null
@@ -20,7 +20,7 @@ Architecture:
   algorithm: DB
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 50
   Neck:
     name: DBFPN
...
@@ -21,7 +21,7 @@ Architecture:
   algorithm: FCE
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 50
     dcn_stage: [False, True, True, True]
     out_indices: [1,2,3]
...
@@ -20,7 +20,7 @@ Architecture:
   algorithm: EAST
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 50
   Neck:
     name: EASTFPN
...
@@ -20,7 +20,7 @@ Architecture:
   algorithm: PSE
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 50
   Neck:
     name: FPN
...
@@ -20,7 +20,7 @@ Architecture:
   algorithm: DB
   Transform:
   Backbone:
-    name: ResNet
+    name: ResNet_vd
     layers: 18
     disable_se: True
   Neck:
...
@@ -11,7 +11,7 @@ Global:
   save_inference_dir:
   use_visualdl: False
   seed: 2048
-  infer_img: doc/vqa/input/zh_val_21.jpg
+  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
   save_res_path: ./output/re/

 Architecture:
@@ -21,7 +21,7 @@ Architecture:
   Backbone:
     name: LayoutLMv2ForRe
     pretrained: True
     checkpoints:

 Loss:
   name: LossFromOutput
@@ -52,7 +52,7 @@ Train:
     name: SimpleDataSet
     data_dir: train_data/XFUND/zh_train/image
     label_file_list:
-    - train_data/XFUND/zh_train/xfun_normalize_train.json
+    - train_data/XFUND/zh_train/train.json
     ratio_list: [ 1.0 ]
     transforms:
     - DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
     - VQATokenLabelEncode: # Class handling label
         contains_re: True
         algorithm: *algorithm
-        class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+        class_path: &class_path train_data/XFUND/class_list_xfun.txt
     - VQATokenPad:
         max_seq_len: &max_seq_len 512
         return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -90,7 +90,7 @@ Eval:
     name: SimpleDataSet
     data_dir: train_data/XFUND/zh_val/image
     label_file_list:
-    - train_data/XFUND/zh_val/xfun_normalize_val.json
+    - train_data/XFUND/zh_val/val.json
     transforms:
     - DecodeImage: # load image
         img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
...
@@ -11,7 +11,7 @@ Global:
   save_inference_dir:
   use_visualdl: False
   seed: 2022
-  infer_img: doc/vqa/input/zh_val_21.jpg
+  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
   save_res_path: ./output/re/

 Architecture:
@@ -52,7 +52,7 @@ Train:
     name: SimpleDataSet
     data_dir: train_data/XFUND/zh_train/image
     label_file_list:
-    - train_data/XFUND/zh_train/xfun_normalize_train.json
+    - train_data/XFUND/zh_train/train.json
     ratio_list: [ 1.0 ]
     transforms:
     - DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
     - VQATokenLabelEncode: # Class handling label
         contains_re: True
         algorithm: *algorithm
-        class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+        class_path: &class_path train_data/XFUND/class_list_xfun.txt
     - VQATokenPad:
         max_seq_len: &max_seq_len 512
         return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -90,7 +90,7 @@ Eval:
     name: SimpleDataSet
     data_dir: train_data/XFUND/zh_val/image
     label_file_list:
-    - train_data/XFUND/zh_val/xfun_normalize_val.json
+    - train_data/XFUND/zh_val/val.json
     transforms:
     - DecodeImage: # load image
         img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
...
@@ -11,7 +11,7 @@ Global:
   save_inference_dir:
   use_visualdl: False
   seed: 2022
-  infer_img: doc/vqa/input/zh_val_0.jpg
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
   save_res_path: ./output/ser/

 Architecture:
@@ -77,7 +77,7 @@ Train:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -112,7 +112,7 @@ Eval:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
...
@@ -11,7 +11,7 @@ Global:
   save_inference_dir:
   use_visualdl: False
   seed: 2022
-  infer_img: doc/vqa/input/zh_val_0.jpg
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
   save_res_path: ./output/ser/

 Architecture:
@@ -78,7 +78,7 @@ Train:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -113,7 +113,7 @@ Eval:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
...
@@ -43,7 +43,7 @@ Optimizer:

 PostProcess:
   name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+  class_path: &class_path train_data/XFUND/class_list_xfun.txt

 Metric:
   name: VQASerTokenMetric
@@ -78,7 +78,7 @@ Train:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -113,7 +113,7 @@ Eval:
         order: 'hwc'
     - ToCHWImage:
     - KeepKeys:
-        keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
...
-# DB
+# DB and DB++

 - [1. Algorithm Introduction](#1)
 - [2. Environment Setup](#2)
@@ -21,12 +21,24 @@
 > Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang
 > AAAI, 2020

+> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
+> Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang
+> TPAMI, 2022

 On the ICDAR2015 public text detection dataset, the reproduced results are as follows:

 |Model|Backbone|Config|precision|recall|Hmean|Download|
 | --- | --- | --- | --- | --- | --- | --- |
 |DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
 |DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
+|DB++|ResNet50|[configs/det/det_r50_db++_ic15.yml](../../configs/det/det_r50_db++_ic15.yml)|90.89%|82.66%|86.58%|[synthetic-data pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
+
+On the TD_TR public text detection dataset, the reproduced results are as follows:
+
+|Model|Backbone|Config|precision|recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+|DB++|ResNet50|[configs/det/det_r50_db++_td_tr.yml](../../configs/det/det_r50_db++_td_tr.yml)|92.92%|86.48%|89.58%|[synthetic-data pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_td_tr_train.tar)|

 <a name="2"></a>
@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai
 For DB text detection model inference, run the following command:
 ```shell
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" --det_algorithm="DB"
 ```

 The visualized detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is as follows:
@@ -96,4 +108,12 @@ The DB model also supports the following inference and deployment modes:
 pages={11474--11481},
 year={2020}
 }
-```
\ No newline at end of file
+@article{liao2022real,
+title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion},
+author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang},
+journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
+year={2022},
+publisher={IEEE}
+}
+```
...
@@ -34,6 +34,7 @@ The image annotation before json.dumps encoding is a list of dicts; each dict
 | ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads| [train](https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt) / [test](https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt) |
 | ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| included in the image download link |
 | total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| included in the image download link |
+| td tr |https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar| included in the image download link |

 #### 1.2.1 ICDAR 2015

 The ICDAR 2015 dataset contains 1000 training images and 500 test images. It can be downloaded via the link in the table above; registration is required on first download.
...
@@ -912,7 +912,7 @@ class VQATokenLabelEncode(object):
             label = info['label']
             gt_label = self._parse_label(label, encode_res)
             # construct entities for re
             if train_re:
                 if gt_label[0] != self.label2id_map["O"]:
                     entity_id_to_index_map[info["id"]] = len(entities)
...
@@ -205,9 +205,12 @@ class DetResizeForTest(object):
     def __init__(self, **kwargs):
         super(DetResizeForTest, self).__init__()
         self.resize_type = 0
+        self.keep_ratio = False
         if 'image_shape' in kwargs:
             self.image_shape = kwargs['image_shape']
             self.resize_type = 1
+            if 'keep_ratio' in kwargs:
+                self.keep_ratio = kwargs['keep_ratio']
         elif 'limit_side_len' in kwargs:
             self.limit_side_len = kwargs['limit_side_len']
             self.limit_type = kwargs.get('limit_type', 'min')
@@ -237,6 +240,10 @@ class DetResizeForTest(object):
     def resize_image_type1(self, img):
         resize_h, resize_w = self.image_shape
         ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        if self.keep_ratio is True:
+            resize_w = ori_w * resize_h / ori_h
+            N = math.ceil(resize_w / 32)
+            resize_w = N * 32
         ratio_h = float(resize_h) / ori_h
         ratio_w = float(resize_w) / ori_w
         img = cv2.resize(img, (int(resize_w), int(resize_h)))
...
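A standalone sketch of the `keep_ratio` arithmetic added above: the evaluation height is fixed (736 in the TD_TR config) and the width follows the original aspect ratio, rounded up to a multiple of 32. The helper name is hypothetical; the math mirrors the diff exactly:

```python
import math


def keep_ratio_width(ori_h, ori_w, resize_h):
    """Mirror of the keep_ratio branch in DetResizeForTest.resize_image_type1."""
    resize_w = ori_w * resize_h / ori_h          # preserve the aspect ratio
    resize_w = math.ceil(resize_w / 32) * 32     # round up to a multiple of 32
    return resize_w


# A 1000x1500 (h x w) image resized to height 736 gets width
# ceil(1500 * 736 / 1000 / 32) * 32 = ceil(34.5) * 32 = 1120.
assert keep_ratio_width(1000, 1500, 736) == 1120
```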
@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
         self.ignore_index = self.loss_class.ignore_index

     def forward(self, predicts, batch):
-        labels = batch[1]
-        attention_mask = batch[4]
+        labels = batch[5]
+        attention_mask = batch[2]
         if attention_mask is not None:
             active_loss = attention_mask.reshape([-1, ]) == 1
             active_outputs = predicts.reshape(
...
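The new indices follow from the reordered `keep_keys` in the SER configs in this commit: the dataloader now yields `input_ids, bbox, attention_mask, token_type_ids, image, labels`, so `attention_mask` is `batch[2]` and `labels` is `batch[5]`. An illustrative map (not part of the commit) to keep the positions straight:

```python
# SER batch positions after this commit, per the configs' keep_keys order.
SER_BATCH_LAYOUT = {
    "input_ids": 0,
    "bbox": 1,
    "attention_mask": 2,
    "token_type_ids": 3,
    "image": 4,
    "labels": 5,
}


def unpack_ser_batch(batch):
    """Illustrative helper mirroring the new batch[2]/batch[5] accesses."""
    return {name: batch[idx] for name, idx in SER_BATCH_LAYOUT.items()}
```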
@@ -18,9 +18,10 @@ __all__ = ["build_backbone"]
 def build_backbone(config, model_type):
     if model_type == "det" or model_type == "table":
         from .det_mobilenet_v3 import MobileNetV3
-        from .det_resnet_vd import ResNet
+        from .det_resnet import ResNet
+        from .det_resnet_vd import ResNet_vd
         from .det_resnet_vd_sast import ResNet_SAST
-        support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+        support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
     elif model_type == "rec" or model_type == "cls":
         from .rec_mobilenet_v3 import MobileNetV3
         from .rec_resnet_vd import ResNet
...
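With both classes registered, the YAML `Backbone.name` field decides which backbone is built. A minimal sketch of that name-based dispatch, under the assumption that `build_backbone` resolves the class by name from `support_dict` (the dict-based lookup here is illustrative, not the repository's exact code):

```python
def build_backbone_sketch(config, model_type="det"):
    # Assumes the PaddleOCR repo root is on PYTHONPATH.
    from ppocr.modeling.backbones.det_resnet import ResNet
    from ppocr.modeling.backbones.det_resnet_vd import ResNet_vd

    support_dict = {"ResNet": ResNet, "ResNet_vd": ResNet_vd}
    config = dict(config)                 # do not mutate the caller's config
    name = config.pop("name")
    assert name in support_dict, f"{name} not supported for {model_type}"
    return support_dict[name](**config)


# `Backbone: {name: ResNet, ...}` now selects the new plain ResNet used by
# DB++, while `name: ResNet_vd` keeps the ResNet-vd variant used elsewhere.
backbone = build_backbone_sketch(
    {"name": "ResNet", "layers": 50, "dcn_stage": [False, True, True, True]})
```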
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform

from .det_resnet_vd import DeformableConvV2, ConvBNLayer


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 is_dcn=False):
        super(BottleneckBlock, self).__init__()
        self.conv0 = ConvBNLayer(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=1,
            act="relu", )
        self.conv1 = ConvBNLayer(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            act="relu",
            is_dcn=is_dcn,
            dcn_groups=1, )
        self.conv2 = ConvBNLayer(
            in_channels=num_filters,
            out_channels=num_filters * 4,
            kernel_size=1,
            act=None, )

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=num_channels,
                out_channels=num_filters * 4,
                kernel_size=1,
                stride=stride, )

        self.shortcut = shortcut
        self._num_channels_out = num_filters * 4

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            act="relu")
        self.conv1 = ConvBNLayer(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            act=None)

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=num_channels,
                out_channels=num_filters,
                kernel_size=1,
                stride=stride)

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y


class ResNet(nn.Layer):
    def __init__(self,
                 in_channels=3,
                 layers=50,
                 out_indices=None,
                 dcn_stage=None):
        super(ResNet, self).__init__()

        self.layers = layers
        self.input_image_channel = in_channels

        supported_layers = [18, 34, 50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        num_channels = [64, 256, 512,
                        1024] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        self.dcn_stage = dcn_stage if dcn_stage is not None else [
            False, False, False, False
        ]
        self.out_indices = out_indices if out_indices is not None else [
            0, 1, 2, 3
        ]

        self.conv = ConvBNLayer(
            in_channels=self.input_image_channel,
            out_channels=64,
            kernel_size=7,
            stride=2,
            act="relu", )
        self.pool2d_max = MaxPool2D(
            kernel_size=3,
            stride=2,
            padding=1, )

        self.stages = []
        self.out_channels = []
        if layers >= 50:
            for block in range(len(depth)):
                shortcut = False
                block_list = []
                is_dcn = self.dcn_stage[block]
                for i in range(depth[block]):
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        conv_name,
                        BottleneckBlock(
                            num_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            num_filters=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            is_dcn=is_dcn))
                    block_list.append(bottleneck_block)
                    shortcut = True
                if block in self.out_indices:
                    self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                shortcut = False
                block_list = []
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        conv_name,
                        BasicBlock(
                            num_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            num_filters=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut))
                    block_list.append(basic_block)
                    shortcut = True
                if block in self.out_indices:
                    self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        out = []
        for i, block in enumerate(self.stages):
            y = block(y)
            if i in self.out_indices:
                out.append(y)
        return out
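A quick shape check of the new backbone, assuming a Paddle install and that the file above is importable as `ppocr.modeling.backbones.det_resnet` (settings mirror the DB++ configs):

```python
import paddle
from ppocr.modeling.backbones.det_resnet import ResNet

# DB++ setting: ResNet-50 with DCN enabled on the last three stages.
model = ResNet(in_channels=3, layers=50,
               dcn_stage=[False, True, True, True])

x = paddle.randn([1, 3, 640, 640])
feats = model(x)
# Four stage outputs at strides 4/8/16/32; for layers=50 the channel widths
# recorded in model.out_channels are [256, 512, 1024, 2048].
for f, c in zip(feats, model.out_channels):
    print(f.shape, c)
```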
@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
 from paddle.regularizer import L2Decay
 from paddle.nn.initializer import Normal, Constant, XavierUniform

-__all__ = ["ResNet"]
+__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]

 class DeformableConvV2(nn.Layer):
@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
                  kernel_size,
                  stride=1,
                  groups=1,
+                 dcn_groups=1,
                  is_vd_mode=False,
                  act=None,
                  is_dcn=False):
@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
             kernel_size=kernel_size,
             stride=stride,
             padding=(kernel_size - 1) // 2,
-            groups=2,  #groups,
+            groups=dcn_groups,  #groups,
             bias_attr=False)
         self._batch_norm = nn.BatchNorm(out_channels, act=act)
@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
             kernel_size=3,
             stride=stride,
             act='relu',
-            is_dcn=is_dcn)
+            is_dcn=is_dcn,
+            dcn_groups=2)
         self.conv2 = ConvBNLayer(
             in_channels=out_channels,
             out_channels=out_channels * 4,
@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
         return y

-class ResNet(nn.Layer):
+class ResNet_vd(nn.Layer):
     def __init__(self,
                  in_channels=3,
                  layers=50,
                  dcn_stage=None,
                  out_indices=None,
                  **kwargs):
-        super(ResNet, self).__init__()
+        super(ResNet_vd, self).__init__()
         self.layers = layers
         supported_layers = [18, 34, 50, 101, 152, 200]
@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
         for block in range(len(depth)):
             block_list = []
             shortcut = False
-            # is_dcn = self.dcn_stage[block]
             for i in range(depth[block]):
                 basic_block = self.add_sublayer(
                     'bb_%d_%d' % (block, i),
...
@@ -74,9 +74,9 @@ class LayoutLMForSer(NLPBaseModel):
     def forward(self, x):
         x = self.model(
             input_ids=x[0],
-            bbox=x[2],
-            attention_mask=x[4],
-            token_type_ids=x[5],
+            bbox=x[1],
+            attention_mask=x[2],
+            token_type_ids=x[3],
             position_ids=None,
             output_hidden_states=False)
         return x
@@ -96,13 +96,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
     def forward(self, x):
         x = self.model(
             input_ids=x[0],
-            bbox=x[2],
-            image=x[3],
-            attention_mask=x[4],
-            token_type_ids=x[5],
+            bbox=x[1],
+            attention_mask=x[2],
+            token_type_ids=x[3],
+            image=x[4],
             position_ids=None,
             head_mask=None,
             labels=None)
+        if not self.training:
+            return x
         return x[0]
@@ -120,13 +122,15 @@ class LayoutXLMForSer(NLPBaseModel):
     def forward(self, x):
         x = self.model(
             input_ids=x[0],
-            bbox=x[2],
-            image=x[3],
-            attention_mask=x[4],
-            token_type_ids=x[5],
+            bbox=x[1],
+            attention_mask=x[2],
+            token_type_ids=x[3],
+            image=x[4],
             position_ids=None,
             head_mask=None,
             labels=None)
+        if not self.training:
+            return x
         return x[0]
@@ -140,12 +144,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
         x = self.model(
             input_ids=x[0],
             bbox=x[1],
-            labels=None,
-            image=x[2],
-            attention_mask=x[3],
-            token_type_ids=x[4],
+            attention_mask=x[2],
+            token_type_ids=x[3],
+            image=x[4],
             position_ids=None,
             head_mask=None,
+            labels=None,
             entities=x[5],
             relations=x[6])
         return x
@@ -161,12 +165,12 @@ class LayoutXLMForRe(NLPBaseModel):
         x = self.model(
             input_ids=x[0],
             bbox=x[1],
-            labels=None,
-            image=x[2],
-            attention_mask=x[3],
-            token_type_ids=x[4],
+            attention_mask=x[2],
+            token_type_ids=x[3],
+            image=x[4],
             position_ids=None,
             head_mask=None,
+            labels=None,
             entities=x[5],
             relations=x[6])
         return x
...
@@ -105,9 +105,10 @@ class DSConv(nn.Layer):

 class DBFPN(nn.Layer):
-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
         super(DBFPN, self).__init__()
         self.out_channels = out_channels
+        self.use_asf = use_asf
         weight_attr = paddle.nn.initializer.KaimingUniform()

         self.in2_conv = nn.Conv2D(
@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)

+        if self.use_asf is True:
+            self.asf = ASFBlock(self.out_channels, self.out_channels // 4)

     def forward(self, x):
         c2, c3, c4, c5 = x
@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
         p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)

         fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+
+        if self.use_asf is True:
+            fuse = self.asf(fuse, [p5, p4, p3, p2])
         return fuse
@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
         fuse = paddle.concat([p5, p4, p3, p2], axis=1)
         return fuse
+
+
+class ASFBlock(nn.Layer):
+    """
+    This code is referred from:
+    https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
+    """
+
+    def __init__(self, in_channels, inter_channels, out_features_num=4):
+        """
+        Adaptive Scale Fusion (ASF) block of DBNet++
+        Args:
+            in_channels: the number of channels in the input data
+            inter_channels: the number of middle channels
+            out_features_num: the number of fused stages
+        """
+        super(ASFBlock, self).__init__()
+        weight_attr = paddle.nn.initializer.KaimingUniform()
+        self.in_channels = in_channels
+        self.inter_channels = inter_channels
+        self.out_features_num = out_features_num
+        self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
+
+        self.spatial_scale = nn.Sequential(
+            #Nx1xHxW
+            nn.Conv2D(
+                in_channels=1,
+                out_channels=1,
+                kernel_size=3,
+                bias_attr=False,
+                padding=1,
+                weight_attr=ParamAttr(initializer=weight_attr)),
+            nn.ReLU(),
+            nn.Conv2D(
+                in_channels=1,
+                out_channels=1,
+                kernel_size=1,
+                bias_attr=False,
+                weight_attr=ParamAttr(initializer=weight_attr)),
+            nn.Sigmoid())
+
+        self.channel_scale = nn.Sequential(
+            nn.Conv2D(
+                in_channels=inter_channels,
+                out_channels=out_features_num,
+                kernel_size=1,
+                bias_attr=False,
+                weight_attr=ParamAttr(initializer=weight_attr)),
+            nn.Sigmoid())
+
+    def forward(self, fuse_features, features_list):
+        fuse_features = self.conv(fuse_features)
+        spatial_x = paddle.mean(fuse_features, axis=1, keepdim=True)
+        attention_scores = self.spatial_scale(spatial_x) + fuse_features
+        attention_scores = self.channel_scale(attention_scores)
+        assert len(features_list) == self.out_features_num
+
+        out_list = []
+        for i in range(self.out_features_num):
+            out_list.append(attention_scores[:, i:i + 1] * features_list[i])
+        return paddle.concat(out_list, axis=1)
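A shape-level sketch of how `DBFPN` drives the new block: `fuse` concatenates four pyramid maps of `out_channels // 4` channels each, and ASF re-weights each scale with a learned per-pixel score before re-concatenating. Assumes the `ASFBlock` above is in scope and `out_channels: 256` as in the DB++ configs:

```python
import paddle

N, C, H, W = 2, 256, 160, 160           # out_channels: 256 from the config
asf = ASFBlock(in_channels=C, inter_channels=C // 4)

# Each pyramid level carries C // 4 channels after DBFPN's output convs.
p5, p4, p3, p2 = (paddle.randn([N, C // 4, H, W]) for _ in range(4))
fuse = paddle.concat([p5, p4, p3, p2], axis=1)   # [N, 256, H, W]

out = asf(fuse, [p5, p4, p3, p2])
assert out.shape == [N, C, H, W]        # same shape, per-scale re-weighted
```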
@@ -308,3 +308,38 @@ class Const(object):
             end_lr=self.learning_rate,
             last_epoch=self.last_epoch)
         return learning_rate
+
+
+class DecayLearningRate(object):
+    """
+    DecayLearningRate learning rate decay
+    new_lr = (lr - end_lr) * (1 - epoch/decay_steps)**power + end_lr
+    Args:
+        learning_rate(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        epochs(int): total training epochs
+        factor(float): power of the polynomial; should be greater than 0.0 to get learning rate decay. Default: 0.9
+        end_lr(float): the minimum final learning rate. Default: 0.0.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 step_each_epoch,
+                 epochs,
+                 factor=0.9,
+                 end_lr=0,
+                 **kwargs):
+        super(DecayLearningRate, self).__init__()
+        self.learning_rate = learning_rate
+        self.epochs = epochs + 1
+        self.factor = factor
+        self.end_lr = end_lr
+        self.decay_steps = step_each_epoch * epochs
+
+    def __call__(self):
+        learning_rate = lr.PolynomialDecay(
+            learning_rate=self.learning_rate,
+            decay_steps=self.decay_steps,
+            power=self.factor,
+            end_lr=self.end_lr)
+        return learning_rate
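Plugging the DB++ config values into the docstring's formula gives a feel for the schedule: 0.007 initial LR, power 0.9, end LR 0. The 125 steps/epoch figure below is only illustrative; the actual value depends on dataset size and batch size:

```python
def poly_lr(step, base_lr=0.007, decay_steps=125000, power=0.9, end_lr=0.0):
    """new_lr = (lr - end_lr) * (1 - step/decay_steps)**power + end_lr"""
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


# Halfway through training the LR has fallen to roughly
# 0.007 * 0.5**0.9 = 0.00375; at the final step it reaches end_lr = 0.
print(poly_lr(0), poly_lr(62500), poly_lr(125000))
```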
@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
             self.id2label_map_for_show[val] = key

     def __call__(self, preds, batch=None, *args, **kwargs):
+        if isinstance(preds, tuple):
+            preds = preds[0]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
         if batch is not None:
-            return self._metric(preds, batch[1])
+            return self._metric(preds, batch[5])
         else:
             return self._infer(preds, **kwargs)
@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
                     j]])
         return decode_out_list, label_decode_out_list

-    def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+    def _infer(self, preds, segment_offset_ids, ocr_infos):
         results = []
-        for pred, attention_mask, segment_offset_id, ocr_info in zip(
-                preds, attention_masks, segment_offset_ids, ocr_infos):
+        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+                                                     ocr_infos):
             pred = np.argmax(pred, axis=1)
             pred = [self.id2label_map[idx] for idx in pred]
...
 # PP-Structure Model List

-- [1. Layout Analysis Models](#1)
-- [2. OCR and Table Recognition Models](#2)
-  - [2.1 OCR](#21)
-  - [2.2 Table Recognition Models](#22)
-- [3. VQA Models](#3)
-- [4. KIE Models](#4)
+- [1. Layout Analysis Models](#1-版面分析模型)
+- [2. OCR and Table Recognition Models](#2-ocr和表格识别模型)
+  - [2.1 OCR](#21-ocr)
+  - [2.2 Table Recognition Models](#22-表格识别模型)
+- [3. VQA Models](#3-vqa模型)
+- [4. KIE Models](#4-kie模型)

 <a name="1"></a>
@@ -42,11 +42,11 @@
 |Model|Description|Inference Model Size|Download|
 | --- | --- | --- | --- |
-|ser_LayoutXLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLMv2|778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLMv2|778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
 |re_LayoutLMv2_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutLMv2|765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLM|430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLM|430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |

 <a name="4"></a>
 ## 4. KIE Models
...
 # PP-Structure Model list

-- [1. Layout Analysis](#1)
-- [2. OCR and Table Recognition](#2)
-  - [2.1 OCR](#21)
-  - [2.2 Table Recognition](#22)
-- [3. VQA](#3)
-- [4. KIE](#4)
+- [1. Layout Analysis](#1-layout-analysis)
+- [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
+  - [2.1 OCR](#21-ocr)
+  - [2.2 Table Recognition](#22-table-recognition)
+- [3. VQA](#3-vqa)
+- [4. KIE](#4-kie)

 <a name="1"></a>
@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
 |model| description |inference model size|download|
 | --- |----------------------------------------------------------------| --- | --- |
-|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
 |re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |

 <a name="4"></a>
 ## 4. KIE
...
@@ -40,6 +40,13 @@ def init_args():
         type=ast.literal_eval,
         default=None,
         help='label map according to ppstructure/layout/README_ch.md')
+    # params for vqa
+    parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
+    parser.add_argument("--ser_model_dir", type=str)
+    parser.add_argument(
+        "--ser_dict_path",
+        type=str,
+        default="../train_data/XFUND/class_list_xfun.txt")
     # params for inference
     parser.add_argument(
         "--mode",
@@ -65,7 +72,7 @@ def init_args():
         "--recovery",
         type=bool,
         default=False,
         help='Whether to enable layout of recovery')
     return parser
...
 English | [简体中文](README_ch.md)

-- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
-  - [1. Introduction](#1-Introduction)
-  - [2. Performance](#2-performance)
-  - [3. Effect demo](#3-Effect-demo)
-    - [3.1 SER](#31-ser)
-    - [3.2 RE](#32-re)
-  - [4. Install](#4-Install)
-    - [4.1 Installation dependencies](#41-Install-dependencies)
-    - [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
-  - [5. Usage](#5-Usage)
-    - [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
-    - [5.2 SER](#52-ser)
-    - [5.3 RE](#53-re)
-  - [6. Reference](#6-Reference-Links)
+- [1 Introduction](#1-introduction)
+- [2. Performance](#2-performance)
+- [3. Effect demo](#3-effect-demo)
+  - [3.1 SER](#31-ser)
+  - [3.2 RE](#32-re)
+- [4. Install](#4-install)
+  - [4.1 Install dependencies](#41-install-dependencies)
+  - [5.3 RE](#53-re)
+- [6. Reference Links](#6-reference-links)
+- [License](#license)

 # Document Visual Question Answering
@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
 ```
 Finally, `precision`, `recall`, `hmean` and other indicators will be printed

-* Use `OCR engine + SER` tandem prediction
+* `OCR + SER` tandem prediction based on the training engine

-Use the following command to complete the tandem prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
+Use the following command to complete the tandem prediction of `OCR + SER`, taking the SER model based on LayoutXLM as an example:

 ```shell
-CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
 ```
 Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.

-* End-to-end evaluation of `OCR engine + SER` prediction system
+* End-to-end evaluation of `OCR + SER` prediction system

 First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
 export CUDA_VISIBLE_DEVICES=0
 python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
 ```
+
+* Export model
+
+Use the following command to complete the model export of the SER model, taking the SER model based on LayoutXLM as an example:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+
+The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
+
+* `OCR + SER` tandem prediction based on the prediction engine
+
+Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+
+After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field.

 <a name="53"></a>
 ### 5.3 RE
@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed

 Use the following command to complete the tandem prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
 ```shell
 export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
 ```
 Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
+
+* Export model
+
+coming soon
+
+* `OCR + SER + RE` tandem prediction based on the prediction engine
+
+coming soon

 ## 6. Reference Links

 - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
...
[English](README.md) | 简体中文 [English](README.md) | 简体中文
- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa) - [1. 简介](#1-简介)
- [1. 简介](#1-简介) - [2. 性能](#2-性能)
- [2. 性能](#2-性能) - [3. 效果演示](#3-效果演示)
- [3. 效果演示](#3-效果演示) - [3.1 SER](#31-ser)
- [3.1 SER](#31-ser) - [3.2 RE](#32-re)
- [3.2 RE](#32-re) - [4. 安装](#4-安装)
- [4. 安装](#4-安装) - [4.1 安装依赖](#41-安装依赖)
- [4.1 安装依赖](#41-安装依赖) - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa) - [5. 使用](#5-使用)
- [5. 使用](#5-使用) - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- [5.1 数据和预训练模型准备](#51-数据和预训练模型准备) - [5.2 SER](#52-ser)
- [5.2 SER](#52-ser) - [5.3 RE](#53-re)
- [5.3 RE](#53-re) - [6. 参考链接](#6-参考链接)
- [6. 参考链接](#6-参考链接) - [License](#license)
# 文档视觉问答(DOC-VQA) # 文档视觉问答(DOC-VQA)
...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o ...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标 最终会打印出`precision`, `recall`, `hmean`等指标
* `OCR + SER` cascade prediction based on the training engine

Run the following command to perform `OCR + SER` cascade prediction with the training engine, taking the LayoutXLM-based SER model as an example:

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```

The visualized prediction image and the prediction result text file are saved in the directory configured by the `config.Global.save_res_path` field; the text file is named `infer_results.txt`.

* End-to-end evaluation of the `OCR + SER` prediction system

First run the `tools/infer_vqa_token_ser.py` script to generate predictions on the dataset, then evaluate with the following command.
...@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
* Model export

Run the following command to export the SER model to an inference model, taking the LayoutXLM-based SER model as an example:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
The converted model is saved in the directory specified by the `Global.save_inference_dir` field.
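The export step produces a static-graph model that can be loaded with the Paddle inference API. A minimal loading sketch; the file names `inference.pdmodel`/`inference.pdiparams` follow the usual Paddle 2.x export layout and are an assumption here:

```python
from paddle import inference

# Minimal sketch: load the exported SER model with the Paddle inference API.
# The file names below follow the usual Paddle 2.x export layout (assumed).
config = inference.Config("output/ser/infer/inference.pdmodel",
                          "output/ser/infer/inference.pdiparams")
predictor = inference.create_predictor(config)

# SER models expose several named inputs (input_ids, bbox, attention_mask,
# token_type_ids, image), each with its own input handle.
print(predictor.get_input_names())
```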
* `OCR + SER` cascade prediction based on the inference engine

Run the following command to perform `OCR + SER` cascade prediction with the inference engine, taking the LayoutXLM-based SER model as an example:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction succeeds, the visualized images and results are saved in the directory specified by the `--output` argument; the text results are written to `infer.txt` in the same tab-separated image-path + JSON format.
### 5.3 RE

...@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
```

Finally, metrics such as `precision`, `recall`, and `hmean` are printed.
* `OCR + SER + RE` cascade prediction based on the training engine

Run the following command to perform `OCR + SER + RE` cascade prediction with the training engine, taking the LayoutXLM-based SER and RE models as an example:

```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```

The visualized prediction image and the prediction result text file are saved in the directory configured by the `config.Global.save_res_path` field; the text file is named `infer_results.txt`.
* Model export

coming soon

* `OCR + SER + RE` cascade prediction based on the inference engine

coming soon

## 6. Reference Links

- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
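
# predict_vqa_token_ser.py: SER predictor built on the Paddle inference
# engine. It runs the PP-OCR engine plus a LayoutXLM-style SER model that
# was exported with tools/export_model.py (see the VQA README above).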
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from paddleocr import PaddleOCR
logger = get_logger()
class SerPredictor(object):
    def __init__(self, args):
        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)

        # Preprocessing pipeline: run OCR and encode tokens/boxes, pad and
        # chunk to a 512-token sequence, then resize and normalize the image.
        pre_process_list = [{
            'VQATokenLabelEncode': {
                'algorithm': args.vqa_algorithm,
                'class_path': args.ser_dict_path,
                'contains_re': False,
                'ocr_engine': self.ocr_engine
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            }
        }]
        postprocess_params = {
            'name': 'VQASerTokenLayoutLMPostProcess',
            "class_path": args.ser_dict_path,
        }

        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'ser', logger)

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img = data[0]
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        starttime = time.time()

        # The SER predictor exposes one input handle per model input
        # (input_ids, bbox, attention_mask, token_type_ids, image).
        for idx in range(len(self.input_tensor)):
            expand_input = np.expand_dims(data[idx], axis=0)
            self.input_tensor[idx].copy_from_cpu(expand_input)

        self.predictor.run()

        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        preds = outputs[0]

        post_result = self.postprocess_op(
            preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
        elapse = time.time() - starttime
        return post_result, elapse
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    ser_predictor = SerPredictor(args)
    count = 0
    total_time = 0

    os.makedirs(args.output, exist_ok=True)
    with open(
            os.path.join(args.output, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for image_file in image_file_list:
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            if not flag:
                # cv2.imread returns BGR; the SER pipeline expects RGB.
                img = img[:, :, ::-1]
            ser_res, elapse = ser_predictor(img)
            ser_res = ser_res[0]

            res_str = '{}\t{}\n'.format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": ser_res,
                    }, ensure_ascii=False))
            f_w.write(res_str)

            img_res = draw_ser_results(
                image_file,
                ser_res,
                font_path="../doc/fonts/simfang.ttf", )
            img_save_path = os.path.join(args.output,
                                         os.path.basename(image_file))
            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))

            # Skip the first image when accumulating time (warm-up run).
            if count > 0:
                total_time += elapse
            count += 1
            logger.info("Predict time of {}: {}".format(image_file, elapse))


if __name__ == "__main__":
    main(parse_args())
sentencepiece
yacs
seqeval
paddlenlp>=2.2.1
pypandoc
attrdict
python_docx
===========================train_params===========================
model_name:det_r50_db++
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_r50_db++_ic15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_db++_train/best_accuracy
infer_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
infer_quant:False
inference:tools/infer/predict_det.py --det_algorithm="DB++"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -59,6 +59,9 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
if [[ ${model_name} =~ "det_r50_db++" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
rm -rf ./train_data/ic15_data
......
...@@ -97,6 +97,22 @@ def export_single_model(model,
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if arch_config["algorithm"] == "LayoutLM":
            # LayoutLM (v1) has no visual branch, so drop the image input.
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
...@@ -172,7 +188,7 @@ def main():
        config["Architecture"]["Head"]["out_channels"] = char_num
    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config['Architecture']["model_type"])
    model.eval()
    save_path = config["Global"]["save_inference_dir"]
......
...@@ -67,6 +67,23 @@ class TextDetector(object):
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "DB++":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
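            # DB++ is trained with its own normalization (scale 1./255,
            # dataset-specific mean, unit std; see the det_r50_db++ config),
            # so the default NormalizeImage op is overridden for inference.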
            pre_process_list[1] = {
                'NormalizeImage': {
                    'std': [1.0, 1.0, 1.0],
                    'mean':
                    [0.48109378172549, 0.45752457890196, 0.40787054090196],
                    'scale': '1./255.',
                    'order': 'hwc'
                }
            }
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
...@@ -231,7 +248,7 @@ class TextDetector(object):
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
......
...@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    else:
        model_dir = args.e2e_model_dir
...@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
    # create predictor
    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()
    if mode in ['ser', 're']:
        # SER/RE models take several named inputs (input_ids, bbox,
        # attention_mask, ...), so return every input handle as a list.
        input_tensor = []
        for name in input_names:
            input_tensor.append(predictor.get_input_handle(name))
    else:
        for name in input_names:
            input_tensor = predictor.get_input_handle(name)
    output_tensors = get_output_tensors(args, mode, predictor)
    return predictor, input_tensor, output_tensors, config
......
...@@ -44,6 +44,7 @@ def to_tensor(data):
    from collections import defaultdict
    data_dict = defaultdict(list)
    to_tensor_idxs = []

    for idx, v in enumerate(data):
        if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
            if idx not in to_tensor_idxs:
...@@ -57,6 +58,7 @@ def to_tensor(data):

class SerPredictor(object):
    def __init__(self, config):
        global_config = config['Global']
        self.algorithm = config['Architecture']["algorithm"]
        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
...@@ -70,7 +72,10 @@ class SerPredictor(object):
        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            show_log=False,
            use_gpu=global_config['use_gpu'])

        # create data ops
        transforms = []
...@@ -80,8 +85,8 @@ class SerPredictor(object):
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
...@@ -99,11 +104,11 @@ class SerPredictor(object):
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)
        if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
            # These backbones return a tuple; the first element is the logits.
            preds = preds[0]
        post_result = self.post_process_class(
            preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
        return post_result, batch
...@@ -138,8 +143,6 @@ if __name__ == '__main__':
        save_img_path = os.path.join(
            config['Global']['save_res_path'],
            os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
        result, _ = ser_engine(data)
        result = result[0]
...@@ -149,3 +152,6 @@ if __name__ == '__main__':
            }, ensure_ascii=False) + "\n")
        img_res = draw_ser_results(img_path, result)
        cv2.imwrite(save_img_path, img_res)
        logger.info("process: [{}/{}], save result to {}".format(
            idx, len(infer_imgs), save_img_path))
...@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor
...@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
    # remove ocr_info, segment_offset_id and labels from the ser input;
    # with keep_keys order [input_ids, bbox, attention_mask, token_type_ids,
    # image, labels, segment_offset_id, ocr_info, entities] they sit at
    # indices 7, 6 and 5.
    ser_inputs.pop(7)
    ser_inputs.pop(6)
    ser_inputs.pop(5)
    return ser_inputs, entity_idx_dict_batch
...@@ -131,9 +131,7 @@ class SerRePredictor(object):
        self.model.eval()

    def __call__(self, img_path):
        ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        preds = self.model(re_input)
        post_result = self.post_process_class(
...@@ -155,7 +153,6 @@ def preprocess():
    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)
...@@ -185,9 +182,7 @@ if __name__ == '__main__':
    for idx, img_path in enumerate(infer_imgs):
        save_img_path = os.path.join(
            config['Global']['save_res_path'],
            os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
        result = ser_re_engine(img_path)
        result = result[0]
...@@ -197,3 +192,6 @@ if __name__ == '__main__':
            }, ensure_ascii=False) + "\n")
        img_res = draw_re_results(img_path, result)
        cv2.imwrite(save_img_path, img_res)
        logger.info("process: [{}/{}], save result to {}".format(
            idx, len(infer_imgs), save_img_path))
...@@ -577,7 +577,7 @@ def preprocess(is_train=False):
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++'
    ]
    if use_xpu:
......