From 1bcbd3181568d07b12000deb6bbe0a148ebffef3 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Sat, 12 Feb 2022 07:17:38 +0000 Subject: [PATCH] add layoutlmv2 --- configs/vqa/re/layoutlmv2.yml | 125 ++++++++++++++++++ configs/vqa/re/layoutxlm.yml | 11 +- configs/vqa/ser/layoutlmv2.yml | 121 +++++++++++++++++ ppocr/data/imaug/label_ops.py | 6 +- ppocr/data/imaug/vqa/token/vqa_token_chunk.py | 19 ++- ppocr/modeling/backbones/__init__.py | 7 +- ppocr/modeling/backbones/vqa_layoutlm.py | 67 ++++++++-- ppstructure/vqa/README.md | 2 + 8 files changed, 334 insertions(+), 24 deletions(-) create mode 100644 configs/vqa/re/layoutlmv2.yml create mode 100644 configs/vqa/ser/layoutlmv2.yml diff --git a/configs/vqa/re/layoutlmv2.yml b/configs/vqa/re/layoutlmv2.yml new file mode 100644 index 00000000..9daa2a96 --- /dev/null +++ b/configs/vqa/re/layoutlmv2.yml @@ -0,0 +1,125 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/re_layoutlmv2/ + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 19 ] + cal_metric_during_train: False + save_inference_dir: + use_visualdl: False + seed: 2048 + infer_img: doc/vqa/input/zh_val_21.jpg + save_res_path: ./output/re/ + +Architecture: + model_type: vqa + algorithm: &algorithm "LayoutLMv2" + Transform: + Backbone: + name: LayoutLMv2ForRe + pretrained: True + checkpoints: + +Loss: + name: LossFromOutput + key: loss + reduction: mean + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + clip_norm: 10 + lr: + name: Piecewise + values: [0.000005, 0.00005] + decay_epochs: [10] + warmup_epoch: 0 + regularizer: + name: L2 + factor: 0.00000 + +PostProcess: + name: VQAReTokenLayoutLMPostProcess + +Metric: + name: VQAReTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/xfun_normalize_train.json + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: True + algorithm: *algorithm + class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQAReTokenRelation: + - VQAReTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 8 + collate_fn: ListCollator + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/xfun_normalize_val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: True + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQAReTokenRelation: + - VQAReTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 8 + collate_fn: ListCollator diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml index ca6b0d29..d413b174 100644 --- a/configs/vqa/re/layoutxlm.yml +++ b/configs/vqa/re/layoutxlm.yml @@ -21,7 +21,7 @@ Architecture: Backbone: name: LayoutXLMForRe pretrained: True - checkpoints: + checkpoints: Loss: name: LossFromOutput @@ -34,7 +34,10 @@ Optimizer: beta2: 0.999 clip_norm: 10 lr: - learning_rate: 0.00005 + name: Piecewise + values: [0.000005, 0.00005] + decay_epochs: [10] + warmup_epoch: 0 regularizer: name: L2 factor: 0.00000 @@ -81,7 +84,7 @@ Train: shuffle: True drop_last: False batch_size_per_card: 8 - num_workers: 4 + num_workers: 8 collate_fn: ListCollator Eval: @@ -118,5 +121,5 @@ Eval: shuffle: False drop_last: False batch_size_per_card: 8 - num_workers: 4 + num_workers: 8 collate_fn: ListCollator diff --git a/configs/vqa/ser/layoutlmv2.yml b/configs/vqa/ser/layoutlmv2.yml new file mode 100644 index 00000000..33406252 --- /dev/null +++ b/configs/vqa/ser/layoutlmv2.yml @@ -0,0 +1,121 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/ser_layoutlmv2/ + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 19 ] + cal_metric_during_train: False + save_inference_dir: + use_visualdl: False + seed: 2022 + infer_img: doc/vqa/input/zh_val_0.jpg + save_res_path: ./output/ser/ + +Architecture: + model_type: vqa + algorithm: &algorithm "LayoutLMv2" + Transform: + Backbone: + name: LayoutLMv2ForSer + pretrained: True + checkpoints: + num_classes: &num_classes 7 + +Loss: + name: VQASerTokenLayoutLMLoss + num_classes: *num_classes + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Linear + learning_rate: 0.00005 + epochs: *epoch_num + warmup_epoch: 2 + regularizer: + + name: L2 + factor: 0.00000 + +PostProcess: + name: VQASerTokenLayoutLMPostProcess + class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + +Metric: + name: VQASerTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/xfun_normalize_train.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/xfun_normalize_val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 4 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 786647f1..ef962b17 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -799,7 +799,7 @@ class VQATokenLabelEncode(object): ocr_engine=None, **kwargs): super(VQATokenLabelEncode, self).__init__() - from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer + from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer from ppocr.utils.utility import load_vqa_bio_label_maps tokenizer_dict = { 'LayoutXLM': { @@ -809,6 +809,10 @@ class VQATokenLabelEncode(object): 'LayoutLM': { 'class': LayoutLMTokenizer, 'pretrained_model': 'layoutlm-base-uncased' + }, + 'LayoutLMv2': { + 'class': LayoutLMv2Tokenizer, + 'pretrained_model': 'layoutlmv2-base-uncased' } } self.contains_re = contains_re diff --git a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py index deb55b4d..1fa949e6 100644 --- a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py +++ b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict + class VQASerTokenChunk(object): def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): @@ -39,6 +41,8 @@ class VQASerTokenChunk(object): encoded_inputs_example[key] = data[key] encoded_inputs_all.append(encoded_inputs_example) + if len(encoded_inputs_all) == 0: + return None return encoded_inputs_all[0] @@ -101,17 +105,18 @@ class VQAReTokenChunk(object): "entities": self.reformat(entities_in_this_span), "relations": self.reformat(relations_in_this_span), }) - item['entities']['label'] = [ - self.entities_labels[x] for x in item['entities']['label'] - ] - encoded_inputs_all.append(item) + if len(item['entities']) > 0: + item['entities']['label'] = [ + self.entities_labels[x] for x in item['entities']['label'] + ] + encoded_inputs_all.append(item) + if len(encoded_inputs_all) == 0: + return None return encoded_inputs_all[0] def reformat(self, data): - new_data = {} + new_data = defaultdict(list) for item in data: for k, v in item.items(): - if k not in new_data: - new_data[k] = [] new_data[k].append(v) return new_data diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index a7db52d2..b34b7550 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -45,8 +45,11 @@ def build_backbone(config, model_type): from .table_mobilenet_v3 import MobileNetV3 support_dict = ["ResNet", "MobileNetV3"] elif model_type == 'vqa': - from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe - support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe'] + from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe + support_dict = [ + "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe', + "LayoutXLMForSer", 'LayoutXLMForRe' + ] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index 0e981555..ede5b7a3 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -21,12 +21,14 @@ from paddle import nn from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification +from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction __all__ = ["LayoutXLMForSer", 'LayoutLMForSer'] pretrained_model_dict = { LayoutXLMModel: 'layoutxlm-base-uncased', - LayoutLMModel: 'layoutlm-base-uncased' + LayoutLMModel: 'layoutlm-base-uncased', + LayoutLMv2Model: 'layoutlmv2-base-uncased' } @@ -58,12 +60,34 @@ class NLPBaseModel(nn.Layer): self.out_channels = 1 -class LayoutXLMForSer(NLPBaseModel): +class LayoutLMForSer(NLPBaseModel): def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): - super(LayoutXLMForSer, self).__init__( - LayoutXLMModel, - LayoutXLMForTokenClassification, + super(LayoutLMForSer, self).__init__( + LayoutLMModel, + LayoutLMForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + output_hidden_states=False) + return x + + +class LayoutLMv2ForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutLMv2ForSer, self).__init__( + LayoutLMv2Model, + LayoutLMv2ForTokenClassification, 'ser', pretrained, checkpoints, @@ -82,12 +106,12 @@ class LayoutXLMForSer(NLPBaseModel): return x[0] -class LayoutLMForSer(NLPBaseModel): +class LayoutXLMForSer(NLPBaseModel): def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): - super(LayoutLMForSer, self).__init__( - LayoutLMModel, - LayoutLMForTokenClassification, + super(LayoutXLMForSer, self).__init__( + LayoutXLMModel, + LayoutXLMForTokenClassification, 'ser', pretrained, checkpoints, @@ -97,10 +121,33 @@ class LayoutLMForSer(NLPBaseModel): x = self.model( input_ids=x[0], bbox=x[2], + image=x[3], attention_mask=x[4], token_type_ids=x[5], position_ids=None, - output_hidden_states=False) + head_mask=None, + labels=None) + return x[0] + + +class LayoutLMv2ForRe(NLPBaseModel): + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model, + LayoutLMv2ForRelationExtraction, + 're', pretrained, checkpoints) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[1], + labels=None, + image=x[2], + attention_mask=x[3], + token_type_ids=x[4], + position_ids=None, + head_mask=None, + entities=x[5], + relations=x[6]) return x diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index 7f4ca119..4de815af 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -24,6 +24,8 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进 |:---:|:---:|:---:| :---:| | LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | | LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | +| LayoutLMv2 | RE | 0.6777 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | +| LayoutLMv2 | SER | 0.8544 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | | LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | -- GitLab