From cd7b2ea923e96dd9d5302fcaba8a619d98c3c7dc Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Thu, 6 Jan 2022 03:35:30 +0000 Subject: [PATCH] add pretrained params to backbone --- configs/vqa/re/layoutxlm.yml | 3 +- configs/vqa/ser/layoutlm.yml | 3 +- configs/vqa/ser/layoutxlm.yml | 3 +- ppocr/modeling/backbones/vqa_layoutlm.py | 42 +++++++++++++----------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml index 8d0ffadf..29005bfb 100644 --- a/configs/vqa/re/layoutxlm.yml +++ b/configs/vqa/re/layoutxlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_21.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutXLMForRe - pretrained_model: *pretrained_model + pretrained: True checkpoints: Loss: diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml index f29153a2..805a3993 100644 --- a/configs/vqa/ser/layoutlm.yml +++ b/configs/vqa/ser/layoutlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_0.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutLMForSer - pretrained_model: *pretrained_model + pretrained: True checkpoints: num_classes: &num_classes 7 diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml index 14041eb2..54b1899c 100644 --- a/configs/vqa/ser/layoutxlm.yml +++ b/configs/vqa/ser/layoutxlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_42.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutXLMForSer - pretrained_model: *pretrained_model + pretrained: True checkpoints: num_classes: &num_classes 7 diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index a2f46fc6..0e981555 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -24,21 +24,32 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification __all__ = ["LayoutXLMForSer", 'LayoutLMForSer'] +pretrained_model_dict = { + LayoutXLMModel: 'layoutxlm-base-uncased', + LayoutLMModel: 'layoutlm-base-uncased' +} + class NLPBaseModel(nn.Layer): def __init__(self, base_model_class, model_class, type='ser', - pretrained_model=None, + pretrained=True, checkpoints=None, **kwargs): super(NLPBaseModel, self).__init__() - assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None" if checkpoints is not None: self.model = model_class.from_pretrained(checkpoints) else: - base_model = base_model_class.from_pretrained(pretrained_model) + pretrained_model_name = pretrained_model_dict[base_model_class] + if pretrained: + base_model = base_model_class.from_pretrained( + pretrained_model_name) + else: + base_model = base_model_class( + **base_model_class.pretrained_init_configuration[ + pretrained_model_name]) if type == 'ser': self.model = model_class( base_model, num_classes=kwargs['num_classes'], dropout=None) @@ -48,16 +59,13 @@ class NLPBaseModel(nn.Layer): class LayoutXLMForSer(NLPBaseModel): - def __init__(self, - num_classes, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, + def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): super(LayoutXLMForSer, self).__init__( LayoutXLMModel, LayoutXLMForTokenClassification, 'ser', - pretrained_model, + pretrained, checkpoints, num_classes=num_classes) @@ -75,16 +83,13 @@ class LayoutXLMForSer(NLPBaseModel): class LayoutLMForSer(NLPBaseModel): - def __init__(self, - num_classes, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, + def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): super(LayoutLMForSer, self).__init__( LayoutLMModel, LayoutLMForTokenClassification, 'ser', - pretrained_model, + pretrained, checkpoints, num_classes=num_classes) @@ -100,13 +105,10 @@ class LayoutLMForSer(NLPBaseModel): class LayoutXLMForRe(NLPBaseModel): - def __init__(self, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, - **kwargs): - super(LayoutXLMForRe, self).__init__( - LayoutXLMModel, LayoutXLMForRelationExtraction, 're', - pretrained_model, checkpoints) + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutXLMForRe, self).__init__(LayoutXLMModel, + LayoutXLMForRelationExtraction, + 're', pretrained, checkpoints) def forward(self, x): x = self.model( -- GitLab