Commit cd7b2ea9 authored by 文幕地方

add pretrained params to backbone

Parent 9ecfc348
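From a caller's point of view, the change replaces the pretrained_model name argument with a boolean pretrained flag. A minimal sketch of the new signature, assuming the backbone classes live in ppocr/modeling/backbones/vqa_layoutlm.py as in the PaddleOCR repo (the import path is an assumption; num_classes=7 comes from the SER configs in the diff below):

# Sketch only: illustrates the interface change in this commit, not part of the diff.
from ppocr.modeling.backbones.vqa_layoutlm import LayoutXLMForSer  # assumed module path

# Before this commit the caller passed a pretrained model name:
#   LayoutXLMForSer(num_classes=7, pretrained_model='layoutxlm-base-uncased')
# After this commit a boolean chooses between pretrained and randomly
# initialized weights; a non-empty checkpoints path still takes precedence.
backbone = LayoutXLMForSer(num_classes=7, pretrained=True, checkpoints=None)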
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_21.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutXLMForRe
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:

 Loss:
...
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_0.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutLMForSer
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:
     num_classes: &num_classes 7
...
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_42.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutXLMForSer
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:
     num_classes: &num_classes 7
...
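The three config hunks above make the matching change on the YAML side: the Global-level pretrained_model anchor is dropped and each Backbone block gets a pretrained flag instead. A minimal sketch of how the RE Backbone block maps onto the backbone constructor (again assuming the PaddleOCR module path; unlike the SER variants, the RE backbone takes no num_classes):

from ppocr.modeling.backbones.vqa_layoutlm import LayoutXLMForRe  # assumed module path

# Equivalent of the Backbone block in the RE config above:
#   Backbone:
#     name: LayoutXLMForRe
#     pretrained: True
#     checkpoints:
backbone = LayoutXLMForRe(pretrained=True, checkpoints=None)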
@@ -24,21 +24,32 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
 __all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
 
+pretrained_model_dict = {
+    LayoutXLMModel: 'layoutxlm-base-uncased',
+    LayoutLMModel: 'layoutlm-base-uncased'
+}
+
 
 class NLPBaseModel(nn.Layer):
     def __init__(self,
                  base_model_class,
                  model_class,
                  type='ser',
-                 pretrained_model=None,
+                 pretrained=True,
                  checkpoints=None,
                  **kwargs):
         super(NLPBaseModel, self).__init__()
-        assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None"
         if checkpoints is not None:
             self.model = model_class.from_pretrained(checkpoints)
         else:
-            base_model = base_model_class.from_pretrained(pretrained_model)
+            pretrained_model_name = pretrained_model_dict[base_model_class]
+            if pretrained:
+                base_model = base_model_class.from_pretrained(
+                    pretrained_model_name)
+            else:
+                base_model = base_model_class(
+                    **base_model_class.pretrained_init_configuration[
+                        pretrained_model_name])
             if type == 'ser':
                 self.model = model_class(
                     base_model, num_classes=kwargs['num_classes'], dropout=None)
@@ -48,16 +59,13 @@ class NLPBaseModel(nn.Layer):
 
 class LayoutXLMForSer(NLPBaseModel):
-    def __init__(self,
-                 num_classes,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
         super(LayoutXLMForSer, self).__init__(
             LayoutXLMModel,
             LayoutXLMForTokenClassification,
             'ser',
-            pretrained_model,
+            pretrained,
             checkpoints,
             num_classes=num_classes)
@@ -75,16 +83,13 @@ class LayoutXLMForSer(NLPBaseModel):
 
 class LayoutLMForSer(NLPBaseModel):
-    def __init__(self,
-                 num_classes,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
         super(LayoutLMForSer, self).__init__(
             LayoutLMModel,
             LayoutLMForTokenClassification,
             'ser',
-            pretrained_model,
+            pretrained,
             checkpoints,
             num_classes=num_classes)
@@ -100,13 +105,10 @@ class LayoutLMForSer(NLPBaseModel):
 
 class LayoutXLMForRe(NLPBaseModel):
-    def __init__(self,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
-                 **kwargs):
-        super(LayoutXLMForRe, self).__init__(
-            LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
-            pretrained_model, checkpoints)
+    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+        super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
+                                             LayoutXLMForRelationExtraction,
+                                             're', pretrained, checkpoints)
 
     def forward(self, x):
         x = self.model(
...
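For reference, a sketch of what the new pretrained=False branch does: instead of downloading weights, the base model is built from paddlenlp's stored architecture configuration. The calls mirror the diff above; exact config keys and behavior depend on the installed paddlenlp version.

from paddlenlp.transformers import LayoutXLMModel

name = 'layoutxlm-base-uncased'
# pretrained_init_configuration maps model names to constructor kwargs, so this
# builds a randomly initialized LayoutXLM with the pretrained architecture.
base_model = LayoutXLMModel(**LayoutXLMModel.pretrained_init_configuration[name])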