提交 cd7b2ea9 编写于 作者: 文幕地方's avatar 文幕地方

add pretrained params to backbone

上级 9ecfc348
......@@ -8,7 +8,6 @@ Global:
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
save_inference_dir:
use_visualdl: False
infer_img: doc/vqa/input/zh_val_21.jpg
......@@ -20,7 +19,7 @@ Architecture:
Transform:
Backbone:
name: LayoutXLMForRe
pretrained_model: *pretrained_model
pretrained: True
checkpoints:
Loss:
......
......@@ -8,7 +8,6 @@ Global:
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file
save_inference_dir:
use_visualdl: False
infer_img: doc/vqa/input/zh_val_0.jpg
......@@ -20,7 +19,7 @@ Architecture:
Transform:
Backbone:
name: LayoutLMForSer
pretrained_model: *pretrained_model
pretrained: True
checkpoints:
num_classes: &num_classes 7
......
......@@ -8,7 +8,6 @@ Global:
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
save_inference_dir:
use_visualdl: False
infer_img: doc/vqa/input/zh_val_42.jpg
......@@ -20,7 +19,7 @@ Architecture:
Transform:
Backbone:
name: LayoutXLMForSer
pretrained_model: *pretrained_model
pretrained: True
checkpoints:
num_classes: &num_classes 7
......
......@@ -24,21 +24,32 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
pretrained_model_dict = {
LayoutXLMModel: 'layoutxlm-base-uncased',
LayoutLMModel: 'layoutlm-base-uncased'
}
class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
type='ser',
pretrained_model=None,
pretrained=True,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None"
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
else:
base_model = base_model_class.from_pretrained(pretrained_model)
pretrained_model_name = pretrained_model_dict[base_model_class]
if pretrained:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
base_model = base_model_class(
**base_model_class.pretrained_init_configuration[
pretrained_model_name])
if type == 'ser':
self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None)
......@@ -48,16 +59,13 @@ class NLPBaseModel(nn.Layer):
class LayoutXLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
pretrained_model,
pretrained,
checkpoints,
num_classes=num_classes)
......@@ -75,16 +83,13 @@ class LayoutXLMForSer(NLPBaseModel):
class LayoutLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
pretrained_model,
pretrained,
checkpoints,
num_classes=num_classes)
......@@ -100,13 +105,10 @@ class LayoutLMForSer(NLPBaseModel):
class LayoutXLMForRe(NLPBaseModel):
def __init__(self,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutXLMForRe, self).__init__(
LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
pretrained_model, checkpoints)
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
LayoutXLMForRelationExtraction,
're', pretrained, checkpoints)
def forward(self, x):
x = self.model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册