diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml index 8d0ffadf0607687c96b2e3ecdb9fae059af6bf51..29005bfbd4c085cc2514cb9d5eca1825c6dc48a4 100644 --- a/configs/vqa/re/layoutxlm.yml +++ b/configs/vqa/re/layoutxlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_21.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutXLMForRe - pretrained_model: *pretrained_model + pretrained: True checkpoints: Loss: diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml index f29153a2e9bea44f758a38558aa4c45a25f53213..805a3993ffd9f76716755cf7fa2cfc5d440462e5 100644 --- a/configs/vqa/ser/layoutlm.yml +++ b/configs/vqa/ser/layoutlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_0.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutLMForSer - pretrained_model: *pretrained_model + pretrained: True checkpoints: num_classes: &num_classes 7 diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml index 14041eb26246f396b15f3b11754bda36609a47b9..54b1899c68a7e7b07fd13d69d49ece302662d00c 100644 --- a/configs/vqa/ser/layoutxlm.yml +++ b/configs/vqa/ser/layoutxlm.yml @@ -8,7 +8,6 @@ Global: # evaluation is run every 10 iterations after the 0th iteration eval_batch_step: [ 0, 19 ] cal_metric_during_train: False - pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file save_inference_dir: use_visualdl: False infer_img: doc/vqa/input/zh_val_42.jpg @@ -20,7 +19,7 @@ Architecture: Transform: Backbone: name: LayoutXLMForSer - pretrained_model: *pretrained_model + pretrained: True checkpoints: num_classes: &num_classes 7 diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index a2f46fc6b55304f7abca411d25f7864ced8c3887..0e98155514cdd055680f32b529fdce631384a37f 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -24,21 +24,32 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification __all__ = ["LayoutXLMForSer", 'LayoutLMForSer'] +pretrained_model_dict = { + LayoutXLMModel: 'layoutxlm-base-uncased', + LayoutLMModel: 'layoutlm-base-uncased' +} + class NLPBaseModel(nn.Layer): def __init__(self, base_model_class, model_class, type='ser', - pretrained_model=None, + pretrained=True, checkpoints=None, **kwargs): super(NLPBaseModel, self).__init__() - assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None" if checkpoints is not None: self.model = model_class.from_pretrained(checkpoints) else: - base_model = base_model_class.from_pretrained(pretrained_model) + pretrained_model_name = pretrained_model_dict[base_model_class] + if pretrained: + base_model = base_model_class.from_pretrained( + pretrained_model_name) + else: + base_model = base_model_class( + **base_model_class.pretrained_init_configuration[ + pretrained_model_name]) if type == 'ser': self.model = model_class( base_model, num_classes=kwargs['num_classes'], dropout=None) @@ -48,16 +59,13 @@ class NLPBaseModel(nn.Layer): class LayoutXLMForSer(NLPBaseModel): - def __init__(self, - num_classes, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, + def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): super(LayoutXLMForSer, self).__init__( LayoutXLMModel, LayoutXLMForTokenClassification, 'ser', - pretrained_model, + pretrained, checkpoints, num_classes=num_classes) @@ -75,16 +83,13 @@ class LayoutXLMForSer(NLPBaseModel): class LayoutLMForSer(NLPBaseModel): - def __init__(self, - num_classes, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, + def __init__(self, num_classes, pretrained=True, checkpoints=None, **kwargs): super(LayoutLMForSer, self).__init__( LayoutLMModel, LayoutLMForTokenClassification, 'ser', - pretrained_model, + pretrained, checkpoints, num_classes=num_classes) @@ -100,13 +105,10 @@ class LayoutLMForSer(NLPBaseModel): class LayoutXLMForRe(NLPBaseModel): - def __init__(self, - pretrained_model='layoutxlm-base-uncased', - checkpoints=None, - **kwargs): - super(LayoutXLMForRe, self).__init__( - LayoutXLMModel, LayoutXLMForRelationExtraction, 're', - pretrained_model, checkpoints) + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutXLMForRe, self).__init__(LayoutXLMModel, + LayoutXLMForRelationExtraction, + 're', pretrained, checkpoints) def forward(self, x): x = self.model(