From 9e4ae9dc122dda2c77ca55e63a520231ae26bebd Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Sat, 6 Aug 2022 15:41:20 +0800
Subject: [PATCH] add vqa code (#7096)

* add vqa code

* add order ocr info

* rename tb-yx order

* polish configs

* add trt offline-tuning

* fix seed and remove unused configs
---
 .../re_layoutlmv2_xfund_zh.yml}               |   4 +-
 .../re_layoutxlm_xfund_zh.yml}                |   8 +-
 .../ser_layoutlm_xfund_zh.yml}                |   5 +-
 .../ser_layoutlmv2_xfund_zh.yml}              |   1 +
 .../ser_layoutxlm_xfund_zh.yml}               |   1 +
 configs/kie/{ => sdmgr}/kie_unet_sdmgr.yml    |   0
 .../re_vi_layoutxlm_xfund_zh.yml}             |  40 ++--
 .../re_vi_layoutxlm_xfund_zh_udml.yml         | 175 +++++++++++++++++
 .../ser_vi_layoutxlm_xfund_zh.yml}            |  40 ++--
 .../ser_vi_layoutxlm_xfund_zh_udml.yml        | 183 ++++++++++++++++++
 configs/vqa/re/layoutlmv2_funsd.yml           | 125 ------------
 configs/vqa/ser/layoutlm_sroie.yml            | 124 ------------
 configs/vqa/ser/layoutlmv2_funsd.yml          | 123 ------------
 configs/vqa/ser/layoutlmv2_sroie.yml          | 123 ------------
 configs/vqa/ser/layoutxlm_funsd.yml           | 123 ------------
 configs/vqa/ser/layoutxlm_sroie.yml           | 123 ------------
 configs/vqa/ser/layoutxlm_wildreceipt.yml     | 123 ------------
 ppocr/data/imaug/label_ops.py                 |  20 +-
 ppocr/data/imaug/vqa/__init__.py              |   2 -
 ppocr/data/imaug/vqa/augment.py               |  30 ++-
 ppocr/losses/basic_loss.py                    |  19 +-
 ppocr/losses/combined_loss.py                 |   3 +
 ppocr/losses/distillation_loss.py             | 132 +++++++++++++
 ppocr/losses/vqa_token_layoutlm_loss.py       |  14 +-
 ppocr/metrics/distillation_metric.py          |   2 +
 ppocr/modeling/architectures/base_model.py    |  34 ++--
 ppocr/modeling/backbones/vqa_layoutlm.py      | 127 ++++++++----
 ppocr/postprocess/__init__.py                 |  37 +++-
 .../vqa_token_re_layoutlm_postprocess.py      |  22 +++
 .../vqa_token_ser_layoutlm_postprocess.py     |  22 +++
 ppocr/utils/save_load.py                      |  16 +-
 tools/infer/utility.py                        |  14 +-
 tools/infer_vqa_token_ser_re.py               |  30 ++-
 33 files changed, 833 insertions(+), 1012 deletions(-)
 rename configs/{vqa/re/layoutlmv2_xund_zh.yml => kie/layoutlm_series/re_layoutlmv2_xfund_zh.yml} (98%)
 rename configs/{vqa/re/layoutxlm_xfund_zh.yml => kie/layoutlm_series/re_layoutxlm_xfund_zh.yml} (95%)
 rename configs/{vqa/ser/layoutlm_xfund_zh.yml => kie/layoutlm_series/ser_layoutlm_xfund_zh.yml} (96%)
 rename configs/{vqa/ser/layoutlmv2_xfund_zh.yml => kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml} (99%)
 rename configs/{vqa/ser/layoutxlm_xfund_zh.yml => kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml} (99%)
 rename configs/kie/{ => sdmgr}/kie_unet_sdmgr.yml (100%)
 rename configs/{vqa/re/layoutxlm_funsd.yml => kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml} (73%)
 create mode 100644 configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
 rename configs/{vqa/ser/layoutlm_funsd.yml => kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml} (72%)
 create mode 100644 configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml
 delete mode 100644 configs/vqa/re/layoutlmv2_funsd.yml
 delete mode 100644 configs/vqa/ser/layoutlm_sroie.yml
 delete mode 100644 configs/vqa/ser/layoutlmv2_funsd.yml
 delete mode 100644 configs/vqa/ser/layoutlmv2_sroie.yml
 delete mode 100644 configs/vqa/ser/layoutxlm_funsd.yml
 delete mode 100644 configs/vqa/ser/layoutxlm_sroie.yml
 delete mode 100644 configs/vqa/ser/layoutxlm_wildreceipt.yml
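Two options recur in the configs below: `mode: vi` selects the
visual-backbone-free ("wo-backbone") pretrained weights registered in
ppocr/modeling/backbones/vqa_layoutlm.py, and `order_method: "tb-yx"`
re-orders OCR boxes top-to-bottom by y, then left-to-right by x within a
text line, before tokenization. A minimal standalone sketch of the
`order_by_tbyx` helper this patch adds in ppocr/data/imaug/vqa/augment.py;
the sample boxes and expected output are hypothetical, chosen only to show
the 20-pixel tolerance at work:

    from copy import deepcopy

    def order_by_tbyx(ocr_info):
        # initial sort: top-to-bottom (y1), then left-to-right (x1)
        res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
        for i in range(len(res) - 1):
            for j in range(i, 0, -1):
                # bubble a box leftwards past a neighbour on the same visual
                # line (y1 within 20 px) whose x1 lies further right
                if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
                        res[j + 1]["bbox"][0] < res[j]["bbox"][0]:
                    res[j], res[j + 1] = deepcopy(res[j + 1]), deepcopy(res[j])
                else:
                    break
        return res

    # "B" starts a touch lower than "C" on the same line, so the plain
    # (y, x) sort yields [A, C, B]; the tolerance pass restores [A, B, C]
    boxes = [{"text": "A", "bbox": [10, 10, 40, 28]},
             {"text": "C", "bbox": [60, 40, 90, 58]},
             {"text": "B", "bbox": [10, 42, 40, 60]}]
    print([b["text"] for b in order_by_tbyx(boxes)])  # ['A', 'B', 'C']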
diff --git a/configs/vqa/re/layoutlmv2_xund_zh.yml b/configs/kie/layoutlm_series/re_layoutlmv2_xfund_zh.yml
similarity index 98%
rename from configs/vqa/re/layoutlmv2_xund_zh.yml
rename to configs/kie/layoutlm_series/re_layoutlmv2_xfund_zh.yml
index 986b9b5c..4b330d8d 100644
--- a/configs/vqa/re/layoutlmv2_xund_zh.yml
+++ b/configs/kie/layoutlm_series/re_layoutlmv2_xfund_zh.yml
@@ -6,11 +6,11 @@ Global:
   save_model_dir: ./output/re_layoutlmv2_xfund_zh
   save_epoch_step: 2000
   # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
   save_inference_dir:
   use_visualdl: False
-  seed: 2048
+  seed: 2022
   infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
   save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
 
diff --git a/configs/vqa/re/layoutxlm_xfund_zh.yml b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
similarity index 95%
rename from configs/vqa/re/layoutxlm_xfund_zh.yml
rename to configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
index d8585bb7..a092106e 100644
--- a/configs/vqa/re/layoutxlm_xfund_zh.yml
+++ b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
@@ -1,9 +1,9 @@
 Global:
   use_gpu: True
-  epoch_num: &epoch_num 200
+  epoch_num: &epoch_num 130
   log_smooth_window: 10
   print_batch_step: 10
-  save_model_dir: ./output/re_layoutxlm/
+  save_model_dir: ./output/re_layoutxlm_xfund_zh
   save_epoch_step: 2000
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
@@ -12,7 +12,7 @@ Global:
   use_visualdl: False
   seed: 2022
   infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
-  save_res_path: ./output/re/
+  save_res_path: ./output/re_layoutxlm_xfund_zh/res/
 
 Architecture:
   model_type: vqa
@@ -81,7 +81,7 @@ Train:
   loader:
     shuffle: True
     drop_last: False
-    batch_size_per_card: 8
+    batch_size_per_card: 2
     num_workers: 8
     collate_fn: ListCollator
 
diff --git a/configs/vqa/ser/layoutlm_xfund_zh.yml b/configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml
similarity index 96%
rename from configs/vqa/ser/layoutlm_xfund_zh.yml
rename to configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml
index 99763c19..8c754dd8 100644
--- a/configs/vqa/ser/layoutlm_xfund_zh.yml
+++ b/configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml
@@ -6,13 +6,13 @@ Global:
   save_model_dir: ./output/ser_layoutlm_xfund_zh
   save_epoch_step: 2000
   # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
   save_inference_dir:
   use_visualdl: False
   seed: 2022
   infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
-  save_res_path: ./output/ser_layoutlm_xfund_zh/res/
+  save_res_path: ./output/ser_layoutlm_xfund_zh/res
 
 Architecture:
   model_type: vqa
@@ -55,6 +55,7 @@ Train:
     data_dir: train_data/XFUND/zh_train/image
     label_file_list:
       - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
     transforms:
       - DecodeImage: # load image
           img_mode: RGB
diff --git a/configs/vqa/ser/layoutlmv2_xfund_zh.yml b/configs/kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml
similarity index 99%
rename from configs/vqa/ser/layoutlmv2_xfund_zh.yml
rename to configs/kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml
index ebdc5f31..3c0ffabe 100644
--- a/configs/vqa/ser/layoutlmv2_xfund_zh.yml
+++ b/configs/kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml
@@ -27,6 +27,7 @@ Architecture:
 Loss:
   name: VQASerTokenLayoutLMLoss
   num_classes: *num_classes
+  key: "backbone_out"
 
 Optimizer:
   name: AdamW
diff --git a/configs/vqa/ser/layoutxlm_xfund_zh.yml b/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml
similarity index 99%
rename from configs/vqa/ser/layoutxlm_xfund_zh.yml
rename to configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml
index 68df7d9f..18f87bde 100644
--- a/configs/vqa/ser/layoutxlm_xfund_zh.yml
+++ b/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml
@@ -27,6 +27,7 @@ Architecture:
 Loss:
   name: VQASerTokenLayoutLMLoss
   num_classes: *num_classes
+  key: "backbone_out"
 
 Optimizer:
   name: AdamW
diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/sdmgr/kie_unet_sdmgr.yml
similarity index 100%
rename from configs/kie/kie_unet_sdmgr.yml
rename to configs/kie/sdmgr/kie_unet_sdmgr.yml
diff --git a/configs/vqa/re/layoutxlm_funsd.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
similarity index 73%
rename from configs/vqa/re/layoutxlm_funsd.yml
rename to configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
index af28be10..89f7d5c3 100644
--- a/configs/vqa/re/layoutxlm_funsd.yml
+++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
@@ -1,18 +1,18 @@
 Global:
   use_gpu: True
-  epoch_num: &epoch_num 200
+  epoch_num: &epoch_num 130
   log_smooth_window: 10
   print_batch_step: 10
-  save_model_dir: ./output/re_layoutxlm_funsd
+  save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
   save_epoch_step: 2000
   # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
   save_inference_dir:
   use_visualdl: False
   seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/re_layoutxlm_funsd/res/
+  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
+  save_res_path: ./output/re/xfund_zh/with_gt
 
 Architecture:
   model_type: vqa
@@ -21,6 +21,7 @@ Architecture:
   Backbone:
     name: LayoutXLMForRe
     pretrained: True
+    mode: vi
    checkpoints:
 
 Loss:
@@ -50,10 +51,9 @@ Metric:
 Train:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
+    data_dir: train_data/XFUND/zh_train/image
     label_file_list:
-      - ./train_data/FUNSD/train_v4.json
-      # - ./train_data/FUNSD/train.json
+      - train_data/XFUND/zh_train/train.json
     ratio_list: [ 1.0 ]
     transforms:
       - DecodeImage: # load image
@@ -62,8 +62,9 @@ Train:
       - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
-          class_path: &class_path ./train_data/FUNSD/class_list.txt
+          class_path: &class_path train_data/XFUND/class_list_xfun.txt
          use_textline_bbox_info: &use_textline_bbox_info True
+          order_method: &order_method "tb-yx"
       - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
@@ -79,22 +80,20 @@ Train:
          order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
+          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
   loader:
-    shuffle: False
+    shuffle: True
     drop_last: False
-    batch_size_per_card: 8
-    num_workers: 16
+    batch_size_per_card: 2
+    num_workers: 4
     collate_fn: ListCollator
 
 Eval:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/testing_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/test_v4.json
-      # - ./train_data/FUNSD/test.json
+    data_dir: train_data/XFUND/zh_val/image
+    label_file_list:
+      - train_data/XFUND/zh_val/val.json
     transforms:
       - DecodeImage: # load image
          img_mode: RGB
@@ -104,6 +103,7 @@ Eval:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
+          order_method: *order_method
       - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
@@ -119,11 +119,11 @@ Eval:
          order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
     collate_fn: ListCollator
+
diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
new file mode 100644
index 00000000..c1bfdb6c
--- /dev/null
+++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
@@ -0,0 +1,175 @@
+Global:
+  use_gpu: True
+  epoch_num: &epoch_num 130
+  log_smooth_window: 10
+  print_batch_step: 10
+  save_model_dir: ./output/re_vi_layoutxlm_xfund_zh_udml
+  save_epoch_step: 2000
+  # evaluation is run every 10 iterations after the 0th iteration
+  eval_batch_step: [ 0, 19 ]
+  cal_metric_during_train: False
+  save_inference_dir:
+  use_visualdl: False
+  seed: 2022
+  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
+  save_res_path: ./output/re/xfund_zh/with_gt
+
+Architecture:
+  model_type: &model_type "vqa"
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Teacher:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: &algorithm "LayoutXLM"
+      Transform:
+      Backbone:
+        name: LayoutXLMForRe
+        pretrained: True
+        mode: vi
+        checkpoints:
+    Student:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: *algorithm
+      Transform:
+      Backbone:
+        name: LayoutXLMForRe
+        pretrained: True
+        mode: vi
+        checkpoints:
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationLossFromOutput:
+      weight: 1.0
+      model_name_list: ["Student", "Teacher"]
+      key: loss
+      reduction: mean
+  - DistillationVQADistanceLoss:
+      weight: 0.5
+      mode: "l2"
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: hidden_states_5
+      name: "loss_5"
+  - DistillationVQADistanceLoss:
+      weight: 0.5
+      mode: "l2"
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: hidden_states_8
+      name: "loss_8"
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 10
+  lr:
+    learning_rate: 0.00005
+    warmup_epoch: 10
+  regularizer:
+    name: L2
+    factor: 0.00000
+
+PostProcess:
+  name: DistillationRePostProcess
+  model_name: ["Student", "Teacher"]
+  key: null
+
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: VQAReTokenMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: train_data/XFUND/zh_train/image
+    label_file_list:
+      - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - VQATokenLabelEncode: # Class handling label
+          contains_re: True
+          algorithm: *algorithm
+          class_path: &class_path train_data/XFUND/class_list_xfun.txt
+          use_textline_bbox_info: &use_textline_bbox_info True
+          # [None, "tb-yx"]
+          order_method: &order_method "tb-yx"
+      - VQATokenPad:
+          max_seq_len: &max_seq_len 512
+          return_attention_mask: True
+      - VQAReTokenRelation:
+      - VQAReTokenChunk:
+          max_seq_len: *max_seq_len
+      - Resize:
+          size: [224,224]
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 2
+    num_workers: 4
+    collate_fn: ListCollator
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: train_data/XFUND/zh_val/image
+    label_file_list:
+      - train_data/XFUND/zh_val/val.json
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - VQATokenLabelEncode: # Class handling label
+          contains_re: True
+          algorithm: *algorithm
+          class_path: *class_path
+          use_textline_bbox_info: *use_textline_bbox_info
+          order_method: *order_method
+      - VQATokenPad:
+          max_seq_len: *max_seq_len
+          return_attention_mask: True
+      - VQAReTokenRelation:
+      - VQAReTokenChunk:
+          max_seq_len: *max_seq_len
+      - Resize:
+          size: [224,224]
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 8
+    collate_fn: ListCollator
+
+
diff --git a/configs/vqa/ser/layoutlm_funsd.yml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml
similarity index 72%
rename from configs/vqa/ser/layoutlm_funsd.yml
rename to configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml
index 0ef3502b..d54125db 100644
--- a/configs/vqa/ser/layoutlm_funsd.yml
+++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml
@@ -3,30 +3,38 @@ Global:
   epoch_num: &epoch_num 200
   log_smooth_window: 10
   print_batch_step: 10
-  save_model_dir: ./output/ser_layoutlm_funsd
+  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
   save_epoch_step: 2000
   # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
   save_inference_dir:
   use_visualdl: False
   seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/ser_layoutlm_funsd/res/
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
+  # if you want to predict using the groundtruth ocr info,
+  # you can use the following config
+  # infer_img: train_data/XFUND/zh_val/val.json
+  # infer_mode: False
+
+  save_res_path: ./output/ser/xfund_zh/res
 
 Architecture:
   model_type: vqa
-  algorithm: &algorithm "LayoutLM"
+  algorithm: &algorithm "LayoutXLM"
   Transform:
   Backbone:
-    name: LayoutLMForSer
+    name: LayoutXLMForSer
     pretrained: True
     checkpoints:
+    # one of base or vi
+    mode: vi
     num_classes: &num_classes 7
 
 Loss:
   name: VQASerTokenLayoutLMLoss
   num_classes: *num_classes
+  key: "backbone_out"
 
 Optimizer:
   name: AdamW
@@ -43,7 +51,7 @@ Optimizer:
 
 PostProcess:
   name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/FUNSD/class_list.txt
+  class_path: &class_path train_data/XFUND/class_list_xfun.txt
 
 Metric:
   name: VQASerTokenMetric
@@ -52,9 +60,10 @@ Metric:
 Train:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
+    data_dir: train_data/XFUND/zh_train/image
     label_file_list:
-      - ./train_data/FUNSD/train.json
+      - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
     transforms:
       - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
@@ -64,6 +73,8 @@ Train:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: &use_textline_bbox_info True
+          # one of [None, "tb-yx"]
+          order_method: &order_method "tb-yx"
       - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
@@ -78,8 +89,7 @@ Train:
          order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
@@ -89,9 +99,9 @@ Train:
 Eval:
   dataset:
     name: SimpleDataSet
-    data_dir: train_data/FUNSD/testing_data/images/
+    data_dir: train_data/XFUND/zh_val/image
     label_file_list:
-      - ./train_data/FUNSD/test.json
+      - train_data/XFUND/zh_val/val.json
     transforms:
       - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
@@ -101,6 +111,7 @@ Eval:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
+          order_method: *order_method
       - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
@@ -115,8 +126,7 @@ Eval:
          order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
diff --git a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml
new file mode 100644
index 00000000..6f0961c8
--- /dev/null
+++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml
@@ -0,0 +1,183 @@
+Global:
+  use_gpu: True
+  epoch_num: &epoch_num 200
+  log_smooth_window: 10
+  print_batch_step: 10
+  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
+  save_epoch_step: 2000
+  # evaluation is run every 10 iterations after the 0th iteration
+  eval_batch_step: [ 0, 19 ]
+  cal_metric_during_train: False
+  save_inference_dir:
+  use_visualdl: False
+  seed: 2022
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
+  save_res_path: ./output/ser_layoutxlm_xfund_zh/res
+
+
+Architecture:
+  model_type: &model_type "vqa"
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Teacher:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: &algorithm "LayoutXLM"
+      Transform:
+      Backbone:
+        name: LayoutXLMForSer
+        pretrained: True
+        # one of base or vi
+        mode: vi
+        checkpoints:
+        num_classes: &num_classes 7
+    Student:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: *algorithm
+      Transform:
+      Backbone:
+        name: LayoutXLMForSer
+        pretrained: True
+        # one of base or vi
+        mode: vi
+        checkpoints:
+        num_classes: *num_classes
+
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationVQASerTokenLayoutLMLoss:
+      weight: 1.0
+      model_name_list: ["Student", "Teacher"]
+      key: backbone_out
+      num_classes: *num_classes
+  - DistillationSERDMLLoss:
+      weight: 1.0
+      act: "softmax"
+      use_log: true
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: backbone_out
+  - DistillationVQADistanceLoss:
+      weight: 0.5
+      mode: "l2"
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: hidden_states_5
+      name: "loss_5"
+  - DistillationVQADistanceLoss:
+      weight: 0.5
+      mode: "l2"
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: hidden_states_8
+      name: "loss_8"
+
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Linear
+    learning_rate: 0.00005
+    epochs: *epoch_num
+    warmup_epoch: 10
+  regularizer:
+    name: L2
+    factor: 0.00000
+
+PostProcess:
+  name: DistillationSerPostProcess
+  model_name: ["Student", "Teacher"]
+  key: backbone_out
+  class_path: &class_path train_data/XFUND/class_list_xfun.txt
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: VQASerTokenMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: train_data/XFUND/zh_train/image
+    label_file_list:
+      - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - VQATokenLabelEncode: # Class handling label
+          contains_re: False
+          algorithm: *algorithm
+          class_path: *class_path
+          # one of [None, "tb-yx"]
+          order_method: &order_method "tb-yx"
+      - VQATokenPad:
+          max_seq_len: &max_seq_len 512
+          return_attention_mask: True
+      - VQASerTokenChunk:
+          max_seq_len: *max_seq_len
+      - Resize:
+          size: [224,224]
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 4
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: train_data/XFUND/zh_val/image
+    label_file_list:
+      - train_data/XFUND/zh_val/val.json
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - VQATokenLabelEncode: # Class handling label
+          contains_re: False
+          algorithm: *algorithm
+          class_path: *class_path
+          order_method: *order_method
+      - VQATokenPad:
+          max_seq_len: *max_seq_len
+          return_attention_mask: True
+      - VQASerTokenChunk:
+          max_seq_len: *max_seq_len
+      - Resize:
+          size: [224,224]
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
diff --git a/configs/vqa/re/layoutlmv2_funsd.yml b/configs/vqa/re/layoutlmv2_funsd.yml
deleted file mode 100644
index 1c3d8f78..00000000
--- a/configs/vqa/re/layoutlmv2_funsd.yml
+++ /dev/null
@@ -1,125 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/re_layoutlmv2_funsd
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/re_layoutlmv2_funsd/res/
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutLMv2"
-  Transform:
-  Backbone:
-    name: LayoutLMv2ForRe
-    pretrained: True
-    checkpoints:
-
-Loss:
-  name: LossFromOutput
-  key: loss
-  reduction: mean
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  clip_norm: 10
-  lr:
-    learning_rate: 0.00005
-    warmup_epoch: 10
-  regularizer:
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQAReTokenLayoutLMPostProcess
-
-Metric:
-  name: VQAReTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/train.json
-    ratio_list: [ 1.0 ]
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: True
-          algorithm: *algorithm
-          class_path: &class_path train_data/FUNSD/class_list.txt
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQAReTokenRelation:
-      - VQAReTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1./255.
-          mean: [0.485, 0.456, 0.406]
-          std: [0.229, 0.224, 0.225]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 8
-    collate_fn: ListCollator
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/testing_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/test.json
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: True
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQAReTokenRelation:
-      - VQAReTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1./255.
-          mean: [0.485, 0.456, 0.406]
-          std: [0.229, 0.224, 0.225]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 8
-    collate_fn: ListCollator
diff --git a/configs/vqa/ser/layoutlm_sroie.yml b/configs/vqa/ser/layoutlm_sroie.yml
deleted file mode 100644
index 6abb1151..00000000
--- a/configs/vqa/ser/layoutlm_sroie.yml
+++ /dev/null
@@ -1,124 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutlm_sroie
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 200 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/SROIE/test/X00016469670.jpg
-  save_res_path: ./output/ser_layoutlm_sroie/res/
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutLM"
-  Transform:
-  Backbone:
-    name: LayoutLMForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 9
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/SROIE/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/SROIE/train
-    label_file_list:
-      - ./train_data/SROIE/train.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-          use_textline_bbox_info: &use_textline_bbox_info True
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/SROIE/test
-    label_file_list:
-      - ./train_data/SROIE/test.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-          use_textline_bbox_info: *use_textline_bbox_info
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/configs/vqa/ser/layoutlmv2_funsd.yml b/configs/vqa/ser/layoutlmv2_funsd.yml
deleted file mode 100644
index 438edc1a..00000000
--- a/configs/vqa/ser/layoutlmv2_funsd.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutlmv2_funsd
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 100 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/ser_layoutlmv2_funsd/res/
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutLMv2"
-  Transform:
-  Backbone:
-    name: LayoutLMv2ForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 7
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path train_data/FUNSD/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/train.json
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/testing_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/test.json
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/configs/vqa/ser/layoutlmv2_sroie.yml b/configs/vqa/ser/layoutlmv2_sroie.yml
deleted file mode 100644
index 549beb8e..00000000
--- a/configs/vqa/ser/layoutlmv2_sroie.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutlmv2_sroie
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 200 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/SROIE/test/X00016469670.jpg
-  save_res_path: ./output/ser_layoutlmv2_sroie/res/
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutLMv2"
-  Transform:
-  Backbone:
-    name: LayoutLMv2ForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 9
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/SROIE/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/SROIE/train
-    label_file_list:
-      - ./train_data/SROIE/train.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/SROIE/test
-    label_file_list:
-      - ./train_data/SROIE/test.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/configs/vqa/ser/layoutxlm_funsd.yml b/configs/vqa/ser/layoutxlm_funsd.yml
deleted file mode 100644
index be1e9d4f..00000000
--- a/configs/vqa/ser/layoutxlm_funsd.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutxlm_funsd
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: output/ser_layoutxlm_funsd/res/
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutXLM"
-  Transform:
-  Backbone:
-    name: LayoutXLMForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 7
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/FUNSD/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/train.json
-    ratio_list: [ 1.0 ]
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: train_data/FUNSD/testing_data/images/
-    label_file_list:
-      - ./train_data/FUNSD/test.json
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/configs/vqa/ser/layoutxlm_sroie.yml b/configs/vqa/ser/layoutxlm_sroie.yml
deleted file mode 100644
index dd63d888..00000000
--- a/configs/vqa/ser/layoutxlm_sroie.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 200
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutxlm_sroie
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 200 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data/SROIE/test/X00016469670.jpg
-  save_res_path: res_img_aug_with_gt
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutXLM"
-  Transform:
-  Backbone:
-    name: LayoutXLMForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 9
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/SROIE/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/SROIE/train
-    label_file_list:
-      - ./train_data/SROIE/train.txt
-    ratio_list: [ 1.0 ]
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: train_data/SROIE/test
-    label_file_list:
-      - ./train_data/SROIE/test.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/configs/vqa/ser/layoutxlm_wildreceipt.yml b/configs/vqa/ser/layoutxlm_wildreceipt.yml
deleted file mode 100644
index 92c03942..00000000
--- a/configs/vqa/ser/layoutxlm_wildreceipt.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-Global:
-  use_gpu: True
-  epoch_num: &epoch_num 100
-  log_smooth_window: 10
-  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutxlm_wildreceipt
-  save_epoch_step: 2000
-  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 200 ]
-  cal_metric_during_train: False
-  save_inference_dir:
-  use_visualdl: False
-  seed: 2022
-  infer_img: train_data//wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
-  save_res_path: ./output/ser_layoutxlm_wildreceipt/res
-
-Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutXLM"
-  Transform:
-  Backbone:
-    name: LayoutXLMForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 51
-
-Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
-
-Optimizer:
-  name: AdamW
-  beta1: 0.9
-  beta2: 0.999
-  lr:
-    name: Linear
-    learning_rate: 0.00005
-    epochs: *epoch_num
-    warmup_epoch: 2
-  regularizer:
-    name: L2
-    factor: 0.00000
-
-PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/wildreceipt/class_list.txt
-
-Metric:
-  name: VQASerTokenMetric
-  main_indicator: hmean
-
-Train:
-  dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/wildreceipt/
-    label_file_list:
-      - ./train_data/wildreceipt/wildreceipt_train.txt
-    ratio_list: [ 1.0 ]
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: &max_seq_len 512
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: True
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
-
-Eval:
-  dataset:
-    name: SimpleDataSet
-    data_dir: train_data/wildreceipt
-    label_file_list:
-      - ./train_data/wildreceipt/wildreceipt_test.txt
-    transforms:
-      - DecodeImage: # load image
-          img_mode: RGB
-          channel_first: False
-      - VQATokenLabelEncode: # Class handling label
-          contains_re: False
-          algorithm: *algorithm
-          class_path: *class_path
-      - VQATokenPad:
-          max_seq_len: *max_seq_len
-          return_attention_mask: True
-      - VQASerTokenChunk:
-          max_seq_len: *max_seq_len
-      - Resize:
-          size: [224,224]
-      - NormalizeImage:
-          scale: 1
-          mean: [ 123.675, 116.28, 103.53 ]
-          std: [ 58.395, 57.12, 57.375 ]
-          order: 'hwc'
-      - ToCHWImage:
-      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
-  loader:
-    shuffle: False
-    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 4
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index d84f9662..180fe97b 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -24,6 +24,7 @@ from shapely.geometry import LineString, Point, Polygon
 import json
 import copy
 from ppocr.utils.logging import get_logger
+from ppocr.data.imaug.vqa.augment import order_by_tbyx
 
 
 class ClsLabelEncode(object):
@@ -870,6 +871,7 @@ class VQATokenLabelEncode(object):
                  add_special_ids=False,
                  algorithm='LayoutXLM',
                  use_textline_bbox_info=True,
+                 order_method=None,
                  infer_mode=False,
                  ocr_engine=None,
                  **kwargs):
@@ -899,6 +901,8 @@ class VQATokenLabelEncode(object):
         self.infer_mode = infer_mode
         self.ocr_engine = ocr_engine
         self.use_textline_bbox_info = use_textline_bbox_info
+        self.order_method = order_method
+        assert self.order_method in [None, "tb-yx"]
 
     def split_bbox(self, bbox, text, tokenizer):
         words = text.split()
@@ -938,6 +942,14 @@ class VQATokenLabelEncode(object):
 
         # load bbox and label info
         ocr_info = self._load_ocr_info(data)
+        for idx in range(len(ocr_info)):
+            if "bbox" not in ocr_info[idx]:
+                ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
+                    "points"])
+
+        if self.order_method == "tb-yx":
+            ocr_info = order_by_tbyx(ocr_info)
+
         # for re
         train_re = self.contains_re and not self.infer_mode
         if train_re:
@@ -1052,10 +1064,10 @@ class VQATokenLabelEncode(object):
         return data
 
     def trans_poly_to_bbox(self, poly):
-        x1 = np.min([p[0] for p in poly])
-        x2 = np.max([p[0] for p in poly])
-        y1 = np.min([p[1] for p in poly])
-        y2 = np.max([p[1] for p in poly])
+        x1 = int(np.min([p[0] for p in poly]))
+        x2 = int(np.max([p[0] for p in poly]))
+        y1 = int(np.min([p[1] for p in poly]))
+        y2 = int(np.max([p[1] for p in poly]))
         return [x1, y1, x2, y2]
 
     def _load_ocr_info(self, data):
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index bde17511..34189bce 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 
 from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
-from .augment import DistortBBox
 
 __all__ = [
     'VQATokenPad',
     'VQASerTokenChunk',
     'VQAReTokenChunk',
     'VQAReTokenRelation',
-    'DistortBBox',
 ]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
index fcdc9685..b95fcdf0 100644
--- a/ppocr/data/imaug/vqa/augment.py
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -16,22 +16,18 @@ import os
 import sys
 import numpy as np
 import random
+from copy import deepcopy
 
 
-class DistortBBox:
-    def __init__(self, prob=0.5, max_scale=1, **kwargs):
-        """Random distort bbox
-        """
-        self.prob = prob
-        self.max_scale = max_scale
-
-    def __call__(self, data):
-        if random.random() > self.prob:
-            return data
-        bbox = np.array(data['bbox'])
-        rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
-        bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
-        data['bbox'] = np.clip(data['bbox'], 0, 1000)
-        data['bbox'] = bbox.tolist()
-        sys.stdout.flush()
-        return data
+def order_by_tbyx(ocr_info):
+    res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+    for i in range(len(res) - 1):
+        for j in range(i, 0, -1):
+            if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+                    (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+                tmp = deepcopy(res[j])
+                res[j] = deepcopy(res[j + 1])
+                res[j + 1] = deepcopy(tmp)
+            else:
+                break
+    return res
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 74490791..da9faa08 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -63,18 +63,21 @@ class KLJSLoss(object):
 
     def __call__(self, p1, p2, reduction="mean"):
         if self.mode.lower() == 'kl':
-            loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+            loss = paddle.multiply(p2,
+                                   paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
             loss += paddle.multiply(
-                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
             loss *= 0.5
         elif self.mode.lower() == "js":
-            loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+            loss = paddle.multiply(
+                p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
             loss += paddle.multiply(
-                p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+                p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
             loss *= 0.5
         else:
-            raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
-
+            raise ValueError(
+                "The mode.lower() of KLJSLoss should be one of ['kl', 'js']")
+
         if reduction == "mean":
             loss = paddle.mean(loss, axis=[1, 2])
         elif reduction == "none" or reduction is None:
@@ -154,7 +157,9 @@ class LossFromOutput(nn.Layer):
         self.reduction = reduction
 
     def forward(self, predicts, batch):
-        loss = predicts[self.key]
+        loss = predicts
+        if self.key is not None and isinstance(predicts, dict):
+            loss = loss[self.key]
         if self.reduction == 'mean':
             loss = paddle.mean(loss)
         elif self.reduction == 'sum':
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index f4cdee8f..8d697d54 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -24,6 +24,9 @@ from .distillation_loss import DistillationCTCLoss
 from .distillation_loss import DistillationSARLoss
 from .distillation_loss import DistillationDMLLoss
 from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
+from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
+from .distillation_loss import DistillationLossFromOutput
+from .distillation_loss import DistillationVQADistanceLoss
 
 
 class CombinedLoss(nn.Layer):
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 565b066d..87fed623 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -21,8 +21,10 @@ from .rec_ctc_loss import CTCLoss
 from .rec_sar_loss import SARLoss
 from .basic_loss import DMLLoss
 from .basic_loss import DistanceLoss
+from .basic_loss import LossFromOutput
 from .det_db_loss import DBLoss
 from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
 
 
 def _sum_loss(loss_dict):
@@ -322,3 +324,133 @@ class DistillationDistanceLoss(DistanceLoss):
                 loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                                idx)] = loss
         return loss_dict
+
+
+class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
+    def __init__(self,
+                 num_classes,
+                 model_name_list=[],
+                 key=None,
+                 name="loss_ser"):
+        super().__init__(num_classes=num_classes)
+        self.model_name_list = model_name_list
+        self.key = key
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            out = predicts[model_name]
+            if self.key is not None:
+                out = out[self.key]
+            loss = super().forward(out, batch)
+            loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+        return loss_dict
+
+
+class DistillationLossFromOutput(LossFromOutput):
+    def __init__(self,
+                 reduction="none",
+                 model_name_list=[],
+                 dist_key=None,
+                 key="loss",
+                 name="loss_re"):
+        super().__init__(key=key, reduction=reduction)
+        self.model_name_list = model_name_list
+        self.name = name
+        self.dist_key = dist_key
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            out = predicts[model_name]
+            if self.dist_key is not None:
+                out = out[self.dist_key]
+            loss = super().forward(out, batch)
+            loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+        return loss_dict
+
+
+class DistillationSERDMLLoss(DMLLoss):
+    """
+    """
+
+    def __init__(self,
+                 act="softmax",
+                 use_log=True,
+                 num_classes=7,
+                 model_name_pairs=[],
+                 key=None,
+                 name="loss_dml_ser"):
+        super().__init__(act=act, use_log=use_log)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.name = name
+        self.num_classes = num_classes
+        self.model_name_pairs = model_name_pairs
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            out1 = out1.reshape([-1, out1.shape[-1]])
+            out2 = out2.reshape([-1, out2.shape[-1]])
+
+            attention_mask = batch[2]
+            if attention_mask is not None:
+                active_output = attention_mask.reshape([-1, ]) == 1
+                out1 = out1[active_output]
+                out2 = out2[active_output]
+
+            loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
+                                                                        out2)
+
+        return loss_dict
+
+
+class DistillationVQADistanceLoss(DistanceLoss):
+    def __init__(self,
+                 mode="l2",
+                 model_name_pairs=[],
+                 key=None,
+                 name="loss_distance",
+                 **kargs):
+        super().__init__(mode=mode, **kargs)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name + "_l2"
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            attention_mask = batch[2]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+                if attention_mask is not None:
+                    max_len = attention_mask.shape[-1]
+                    out1 = out1[:, :max_len]
+                    out2 = out2[:, :max_len]
+                out1 = out1.reshape([-1, out1.shape[-1]])
+                out2 = out2.reshape([-1, out2.shape[-1]])
+            if attention_mask is not None:
+                active_output = attention_mask.reshape([-1, ]) == 1
+                out1 = out1[active_output]
+                out2 = out2[active_output]
+
+            loss = super().forward(out1, out2)
+            if isinstance(loss, dict):
+                for key in loss:
+                    loss_dict["{}_{}nohu_{}".format(self.name, key,
+                                                    idx)] = loss[key]
+            else:
+                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
+                                               idx)] = loss
+        return loss_dict
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index f9cd4634..5d564c0e 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -17,26 +17,30 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import nn
+from ppocr.losses.basic_loss import DMLLoss
 
 
 class VQASerTokenLayoutLMLoss(nn.Layer):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, key=None):
         super().__init__()
         self.loss_class = nn.CrossEntropyLoss()
         self.num_classes = num_classes
         self.ignore_index = self.loss_class.ignore_index
+        self.key = key
 
     def forward(self, predicts, batch):
+        if isinstance(predicts, dict) and self.key is not None:
+            predicts = predicts[self.key]
         labels = batch[5]
         attention_mask = batch[2]
         if attention_mask is not None:
             active_loss = attention_mask.reshape([-1, ]) == 1
-            active_outputs = predicts.reshape(
+            active_output = predicts.reshape(
                 [-1, self.num_classes])[active_loss]
-            active_labels = labels.reshape([-1, ])[active_loss]
-            loss = self.loss_class(active_outputs, active_labels)
+            active_label = labels.reshape([-1, ])[active_loss]
+            loss = self.loss_class(active_output, active_label)
         else:
             loss = self.loss_class(
                 predicts.reshape([-1, self.num_classes]),
                 labels.reshape([-1, ]))
-        return {'loss': loss}
+        return {'loss': loss}
\ No newline at end of file
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index c440cebd..e2cbc4dc 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -19,6 +19,8 @@ from .rec_metric import RecMetric
 from .det_metric import DetMetric
 from .e2e_metric import E2EMetric
 from .cls_metric import ClsMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
 
 
 class DistillationMetric(object):
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index c6b50d48..ed2a909c 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -73,28 +73,40 @@ class BaseModel(nn.Layer):
         self.return_all_feats = config.get("return_all_feats", False)
 
     def forward(self, x, data=None):
+        y = dict()
         if self.use_transform:
             x = self.transform(x)
         x = self.backbone(x)
-        y["backbone_out"] = x
+        if isinstance(x, dict):
+            y.update(x)
+        else:
+            y["backbone_out"] = x
+        final_name = "backbone_out"
         if self.use_neck:
             x = self.neck(x)
-            y["neck_out"] = x
+            if isinstance(x, dict):
+                y.update(x)
+            else:
+                y["neck_out"] = x
+            final_name = "neck_out"
         if self.use_head:
             x = self.head(x, targets=data)
-        # for multi head, save ctc neck out for udml
-        if isinstance(x, dict) and 'ctc_neck' in x.keys():
-            y["neck_out"] = x["ctc_neck"]
-            y["head_out"] = x
-        elif isinstance(x, dict):
-            y.update(x)
-        else:
-            y["head_out"] = x
+            # for multi head, save ctc neck out for udml
+            if isinstance(x, dict) and 'ctc_neck' in x.keys():
+                y["neck_out"] = x["ctc_neck"]
+                y["head_out"] = x
+            elif isinstance(x, dict):
+                y.update(x)
+            else:
+                y["head_out"] = x
+            final_name = "head_out"
         if self.return_all_feats:
             if self.training:
                 return y
+            elif isinstance(x, dict):
+                return x
             else:
-                return {"head_out": y["head_out"]}
+                return {final_name: x}
         else:
             return x
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 34dd9d10..d4ced350 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -22,13 +22,22 @@ from paddle import nn
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
 from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
 from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import AutoModel
 
-__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
 
 pretrained_model_dict = {
-    LayoutXLMModel: 'layoutxlm-base-uncased',
-    LayoutLMModel: 'layoutlm-base-uncased',
-    LayoutLMv2Model: 'layoutlmv2-base-uncased'
+    LayoutXLMModel: {
+        "base": "layoutxlm-base-uncased",
+        "vi": "layoutxlm-wo-backbone-base-uncased",
+    },
+    LayoutLMModel: {
+        "base": "layoutlm-base-uncased",
+    },
+    LayoutLMv2Model: {
+        "base": "layoutlmv2-base-uncased",
+        "vi": "layoutlmv2-wo-backbone-base-uncased",
+    },
 }
 
 
@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
     def __init__(self,
                  base_model_class,
                  model_class,
-                 type='ser',
+                 mode="base",
+                 type="ser",
                  pretrained=True,
                  checkpoints=None,
                  **kwargs):
         super(NLPBaseModel, self).__init__()
-        if checkpoints is not None:
+        if checkpoints is not None:  # load the trained model
             self.model = model_class.from_pretrained(checkpoints)
-        elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
-            self.model = model_class.from_pretrained(pretrained)
-        else:
-            pretrained_model_name = pretrained_model_dict[base_model_class]
+        else:  # load the pretrained-model
+            pretrained_model_name = pretrained_model_dict[base_model_class][
+                mode]
             if pretrained is True:
                 base_model = base_model_class.from_pretrained(
                     pretrained_model_name)
             else:
-                base_model = base_model_class(
-                    **base_model_class.pretrained_init_configuration[
-                        pretrained_model_name])
-            if type == 'ser':
+                base_model = base_model_class.from_pretrained(pretrained)
+            if type == "ser":
                 self.model = model_class(
-                    base_model, num_classes=kwargs['num_classes'], dropout=None)
+                    base_model, num_classes=kwargs["num_classes"], dropout=None)
             else:
                 self.model = model_class(base_model, dropout=None)
             self.out_channels = 1
+        self.use_visual_backbone = True
 
 
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 34dd9d10..d4ced350 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -22,13 +22,22 @@ from paddle import nn
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
 from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
 from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import AutoModel
 
-__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
 
 pretrained_model_dict = {
-    LayoutXLMModel: 'layoutxlm-base-uncased',
-    LayoutLMModel: 'layoutlm-base-uncased',
-    LayoutLMv2Model: 'layoutlmv2-base-uncased'
+    LayoutXLMModel: {
+        "base": "layoutxlm-base-uncased",
+        "vi": "layoutxlm-wo-backbone-base-uncased",
+    },
+    LayoutLMModel: {
+        "base": "layoutlm-base-uncased",
+    },
+    LayoutLMv2Model: {
+        "base": "layoutlmv2-base-uncased",
+        "vi": "layoutlmv2-wo-backbone-base-uncased",
+    },
 }
 
 
@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
     def __init__(self,
                  base_model_class,
                  model_class,
-                 type='ser',
+                 mode="base",
+                 type="ser",
                  pretrained=True,
                  checkpoints=None,
                  **kwargs):
         super(NLPBaseModel, self).__init__()
-        if checkpoints is not None:
+        if checkpoints is not None:  # load the trained model
             self.model = model_class.from_pretrained(checkpoints)
-        elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
-            self.model = model_class.from_pretrained(pretrained)
-        else:
-            pretrained_model_name = pretrained_model_dict[base_model_class]
+        else:  # load the pretrained model
+            pretrained_model_name = pretrained_model_dict[base_model_class][
+                mode]
             if pretrained is True:
                 base_model = base_model_class.from_pretrained(
                     pretrained_model_name)
             else:
-                base_model = base_model_class(
-                    **base_model_class.pretrained_init_configuration[
-                        pretrained_model_name])
-            if type == 'ser':
+                base_model = base_model_class.from_pretrained(pretrained)
+            if type == "ser":
                 self.model = model_class(
-                    base_model, num_classes=kwargs['num_classes'], dropout=None)
+                    base_model, num_classes=kwargs["num_classes"], dropout=None)
             else:
                 self.model = model_class(base_model, dropout=None)
         self.out_channels = 1
+        self.use_visual_backbone = True
 
 
 class LayoutLMForSer(NLPBaseModel):
-    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+    def __init__(self,
+                 num_classes,
+                 pretrained=True,
+                 checkpoints=None,
+                 mode="base",
                  **kwargs):
         super(LayoutLMForSer, self).__init__(
             LayoutLMModel,
             LayoutLMForTokenClassification,
-            'ser',
+            mode,
+            "ser",
             pretrained,
             checkpoints,
-            num_classes=num_classes)
+            num_classes=num_classes, )
+        self.use_visual_backbone = False
 
     def forward(self, x):
         x = self.model(
@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel):
 
 
 class LayoutLMv2ForSer(NLPBaseModel):
-    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+    def __init__(self,
+                 num_classes,
+                 pretrained=True,
+                 checkpoints=None,
+                 mode="base",
                  **kwargs):
         super(LayoutLMv2ForSer, self).__init__(
             LayoutLMv2Model,
             LayoutLMv2ForTokenClassification,
-            'ser',
+            mode,
+            "ser",
             pretrained,
             checkpoints,
             num_classes=num_classes)
+        self.use_visual_backbone = True
+        if hasattr(self.model.layoutlmv2, "use_visual_backbone"
+                   ) and self.model.layoutlmv2.use_visual_backbone is False:
+            self.use_visual_backbone = False
 
     def forward(self, x):
+        if self.use_visual_backbone is True:
+            image = x[4]
+        else:
+            image = None
         x = self.model(
             input_ids=x[0],
             bbox=x[1],
             attention_mask=x[2],
             token_type_ids=x[3],
-            image=x[4],
+            image=image,
             position_ids=None,
             head_mask=None,
             labels=None)
-        if not self.training:
+        if self.training:
+            res = {"backbone_out": x[0]}
+            res.update(x[1])
+            return res
+        else:
             return x
-        return x[0]
 
 
 class LayoutXLMForSer(NLPBaseModel):
-    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+    def __init__(self,
+                 num_classes,
+                 pretrained=True,
+                 checkpoints=None,
+                 mode="base",
                  **kwargs):
         super(LayoutXLMForSer, self).__init__(
             LayoutXLMModel,
             LayoutXLMForTokenClassification,
-            'ser',
+            mode,
+            "ser",
             pretrained,
             checkpoints,
             num_classes=num_classes)
+        self.use_visual_backbone = True
 
     def forward(self, x):
+        if self.use_visual_backbone is True:
+            image = x[4]
+        else:
+            image = None
         x = self.model(
             input_ids=x[0],
             bbox=x[1],
             attention_mask=x[2],
             token_type_ids=x[3],
-            image=x[4],
+            image=image,
             position_ids=None,
             head_mask=None,
             labels=None)
-        if not self.training:
+        if self.training:
+            res = {"backbone_out": x[0]}
+            res.update(x[1])
+            return res
+        else:
             return x
-        return x[0]
 
 
 class LayoutLMv2ForRe(NLPBaseModel):
-    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
-        super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
-                                              LayoutLMv2ForRelationExtraction,
-                                              're', pretrained, checkpoints)
+    def __init__(self, pretrained=True, checkpoints=None, mode="base",
+                 **kwargs):
+        super(LayoutLMv2ForRe, self).__init__(
+            LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
+            pretrained, checkpoints)
 
     def forward(self, x):
         x = self.model(
@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel):
 
 
 class LayoutXLMForRe(NLPBaseModel):
-    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
-        super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
-                                             LayoutXLMForRelationExtraction,
-                                             're', pretrained, checkpoints)
+    def __init__(self, pretrained=True, checkpoints=None, mode="base",
+                 **kwargs):
+        super(LayoutXLMForRe, self).__init__(
+            LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
+            pretrained, checkpoints)
+        self.use_visual_backbone = True
+        if hasattr(self.model.layoutxlm, "use_visual_backbone"
+                   ) and self.model.layoutxlm.use_visual_backbone is False:
+            self.use_visual_backbone = False
 
     def forward(self, x):
+        if self.use_visual_backbone is True:
+            image = x[4]
+        else:
+            image = None
         x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
-           image=x[4],
+           image=image,
            position_ids=None,
            head_mask=None,
            labels=None,
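Throughout this backbone, `mode` picks either the "base" checkpoints or the new "vi" (visual-backbone-free) ones from `pretrained_model_dict`, and `use_visual_backbone` then decides at forward time whether the image tensor is passed on or replaced with None. A toy sketch of that switch; the list stands in for the real batch layout [input_ids, bbox, attention_mask, token_type_ids, image]:

    # illustrative only: strings stand in for the real tensors
    x = ["input_ids", "bbox", "attention_mask", "token_type_ids", "image"]
    use_visual_backbone = False  # False for the "vi" variants
    image = x[4] if use_visual_backbone else None
    print(image)  # None -> the model consumes text + layout features only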
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index eeebc580..6fa871a4 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -31,21 +31,38 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     SPINLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
-from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
-from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
 from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
 
 
 def build_post_process(config, global_config=None):
     support_dict = [
-        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
-        'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
-        'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
-        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
-        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
-        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
-        'TableMasterLabelDecode', 'SPINLabelDecode'
+        'DBPostProcess',
+        'EASTPostProcess',
+        'SASTPostProcess',
+        'FCEPostProcess',
+        'CTCLabelDecode',
+        'AttnLabelDecode',
+        'ClsPostProcess',
+        'SRNLabelDecode',
+        'PGPostProcess',
+        'DistillationCTCLabelDecode',
+        'TableLabelDecode',
+        'DistillationDBPostProcess',
+        'NRTRLabelDecode',
+        'SARLabelDecode',
+        'SEEDLabelDecode',
+        'VQASerTokenLayoutLMPostProcess',
+        'VQAReTokenLayoutLMPostProcess',
+        'PRENLabelDecode',
+        'DistillationSARLabelDecode',
+        'ViTSTRLabelDecode',
+        'ABINetLabelDecode',
+        'TableMasterLabelDecode',
+        'SPINLabelDecode',
+        'DistillationSerPostProcess',
+        'DistillationRePostProcess',
     ]
 
     if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
index 1d55d13d..96c25d9a 100644
--- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -49,3 +49,25 @@ class VQAReTokenLayoutLMPostProcess(object):
                 result.append((ocr_info_head, ocr_info_tail))
             results.append(result)
         return results
+
+
+class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
+    """
+    DistillationRePostProcess
+    """
+
+    def __init__(self, model_name=["Student"], key=None, **kwargs):
+        super().__init__(**kwargs)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+        self.key = key
+
+    def __call__(self, preds, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            output[name] = super().__call__(pred, *args, **kwargs)
+        return output
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 8a6669f7..5541da90 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -93,3 +93,25 @@ class VQASerTokenLayoutLMPostProcess(object):
                 ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
             results.append(ocr_info)
         return results
+
+
+class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
+    """
+    DistillationSerPostProcess
+    """
+
+    def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
+        super().__init__(class_path, **kwargs)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+        self.key = key
+
+    def __call__(self, preds, batch=None, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            output[name] = super().__call__(pred, batch=batch, *args, **kwargs)
+        return output
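Both Distillation*PostProcess wrappers follow one pattern: for each named sub-model (typically just "Student"), pull that model's prediction out of the nested prediction dict, optionally descend one level further via `key`, and delegate to the parent post-process. A minimal sketch of the unwrapping (prediction values are placeholders):

    preds = {
        "Student": {"backbone_out": "student_logits"},
        "Teacher": {"backbone_out": "teacher_logits"},
    }
    model_name, key = ["Student"], "backbone_out"
    output = {name: preds[name][key] for name in model_name}
    print(output)  # {'Student': 'student_logits'}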
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3647111f..8fded687 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -55,6 +55,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
     best_model_dict = {}
 
     if model_type == 'vqa':
+        # NOTE: for vqa model, resume training is not supported now
+        if config["Architecture"]["algorithm"] in ["Distillation"]:
+            return best_model_dict
         checkpoints = config['Architecture']['Backbone']['checkpoints']
         # load vqa method metric
         if checkpoints:
@@ -78,6 +81,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
                 logger.warning(
                     "{}.pdopt is not exists, params of optimizer is not loaded".
                     format(checkpoints))
+
         return best_model_dict
 
     if checkpoints:
@@ -166,15 +170,19 @@ def save_model(model,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+    if config['Architecture']["model_type"] != 'vqa':
+        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
     if config['Architecture']["model_type"] != 'vqa':
         paddle.save(model.state_dict(), model_prefix + '.pdparams')
         metric_prefix = model_prefix
-    else:
+    else:  # for vqa system, we follow the save/load rules in NLP
         if config['Global']['distributed']:
-            model._layers.backbone.model.save_pretrained(model_prefix)
+            arch = model._layers
         else:
-            model.backbone.model.save_pretrained(model_prefix)
+            arch = model
+        if config["Architecture"]["algorithm"] in ["Distillation"]:
+            arch = arch.Student
+        arch.backbone.model.save_pretrained(model_prefix)
         metric_prefix = os.path.join(model_prefix, 'metric')
     # save metric and config
     with open(metric_prefix + '.states', 'wb') as f:
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 7eb77dec..9345106e 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -38,6 +38,7 @@ def init_args():
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
+    parser.add_argument("--shape_info_filename", type=str, default=None)
     parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--gpu_mem", type=int, default=500)
 
@@ -204,9 +205,18 @@ def create_predictor(args, mode, logger):
                 workspace_size=1 << 30,
                 precision_mode=precision,
                 max_batch_size=args.max_batch_size,
-                min_subgraph_size=args.min_subgraph_size,
+                min_subgraph_size=args.min_subgraph_size,  # skip the minimum trt subgraph
                 use_calib_mode=False)
-            # skip the minmum trt subgraph
+
+            # collect shape
+            if args.shape_info_filename is not None:
+                if not os.path.exists(args.shape_info_filename):
+                    config.collect_shape_range_info(args.shape_info_filename)
+                    logger.info(f"collect dynamic shape info into: {args.shape_info_filename}")
+                else:
+                    logger.info(f"dynamic shape info file ({args.shape_info_filename}) already exists, no need to generate again.")
+                config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename, True)
+
             use_dynamic_shape = True
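The new --shape_info_filename flag implements offline TensorRT shape tuning: the first run records the min/max/opt shape ranges of every tensor into a file, and any run that finds the file reuses the tuned ranges instead of re-tuning. A sketch of the same flow against the raw Paddle Inference API (model paths are placeholders for a real exported model; flag values are illustrative):

    import os
    from paddle.inference import Config

    config = Config("inference.pdmodel", "inference.pdiparams")
    config.enable_use_gpu(500, 0)
    config.enable_tensorrt_engine(max_batch_size=10, min_subgraph_size=15)

    shape_file = "det_shape_info.txt"
    if not os.path.exists(shape_file):
        config.collect_shape_range_info(shape_file)  # offline tuning pass
    # True allows building missing shape info at runtime as a fallback
    config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)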
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 20ab1fe1..51378bda 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -113,10 +113,13 @@ def make_input(ser_inputs, ser_results):
 
 class SerRePredictor(object):
     def __init__(self, config, ser_config):
+        global_config = config['Global']
+        if "infer_mode" in global_config:
+            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
+
         self.ser_engine = SerPredictor(ser_config)
 
         # init re model
-        global_config = config['Global']
 
         # build post process
         self.post_process_class = build_post_process(config['PostProcess'],
@@ -130,8 +133,8 @@ class SerRePredictor(object):
 
         self.model.eval()
 
-    def __call__(self, img_path):
-        ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
+    def __call__(self, data):
+        ser_results, ser_inputs = self.ser_engine(data)
         re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
         preds = self.model(re_input)
         post_result = self.post_process_class(
@@ -173,18 +176,33 @@ if __name__ == '__main__':
 
     ser_re_engine = SerRePredictor(config, ser_config)
 
-    infer_imgs = get_image_file_list(config['Global']['infer_img'])
+    if config["Global"].get("infer_mode", None) is False:
+        data_dir = config['Eval']['dataset']['data_dir']
+        with open(config['Global']['infer_img'], "rb") as f:
+            infer_imgs = f.readlines()
+    else:
+        infer_imgs = get_image_file_list(config['Global']['infer_img'])
+
     with open(
             os.path.join(config['Global']['save_res_path'],
                          "infer_results.txt"),
             "w",
             encoding='utf-8') as fout:
-        for idx, img_path in enumerate(infer_imgs):
+        for idx, info in enumerate(infer_imgs):
+            if config["Global"].get("infer_mode", None) is False:
+                data_line = info.decode('utf-8')
+                substr = data_line.strip("\n").split("\t")
+                img_path = os.path.join(data_dir, substr[0])
+                data = {'img_path': img_path, 'label': substr[1]}
+            else:
+                img_path = info
+                data = {'img_path': img_path}
+
             save_img_path = os.path.join(
                 config['Global']['save_res_path'],
                 os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
-            result = ser_re_engine(img_path)
+
+            result = ser_re_engine(data)
             result = result[0]
             fout.write(img_path + "\t" + json.dumps(
                 {
--
GitLab
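With infer_mode set to False, the script reads a label file rather than globbing images: each line is a tab-separated pair of an image path (relative to Eval.dataset.data_dir) and its annotation, which is exactly what the per-line branch above parses. A self-contained sketch of that parsing (paths and label content are illustrative):

    data_dir = "train_data/XFUND/zh_val/image"
    data_line = b"zh_val_42.jpg\t[]"  # one raw line from the label file
    substr = data_line.decode("utf-8").strip("\n").split("\t")
    data = {"img_path": data_dir + "/" + substr[0], "label": substr[1]}
    print(data["img_path"])  # train_data/XFUND/zh_val/image/zh_val_42.jpg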