提交 1bcbd318 编写于 作者: 文幕地方's avatar 文幕地方

add layoutlmv2

上级 6fe387ce
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2048
infer_img: doc/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
name: Piecewise
values: [0.000005, 0.00005]
decay_epochs: [10]
warmup_epoch: 0
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
......@@ -21,7 +21,7 @@ Architecture:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
checkpoints:
Loss:
name: LossFromOutput
......@@ -34,7 +34,10 @@ Optimizer:
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
name: Piecewise
values: [0.000005, 0.00005]
decay_epochs: [10]
warmup_epoch: 0
regularizer:
name: L2
factor: 0.00000
......@@ -81,7 +84,7 @@ Train:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
num_workers: 8
collate_fn: ListCollator
Eval:
......@@ -118,5 +121,5 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
num_workers: 8
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
......@@ -799,7 +799,7 @@ class VQATokenLabelEncode(object):
ocr_engine=None,
**kwargs):
super(VQATokenLabelEncode, self).__init__()
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
from ppocr.utils.utility import load_vqa_bio_label_maps
tokenizer_dict = {
'LayoutXLM': {
......@@ -809,6 +809,10 @@ class VQATokenLabelEncode(object):
'LayoutLM': {
'class': LayoutLMTokenizer,
'pretrained_model': 'layoutlm-base-uncased'
},
'LayoutLMv2': {
'class': LayoutLMv2Tokenizer,
'pretrained_model': 'layoutlmv2-base-uncased'
}
}
self.contains_re = contains_re
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
......@@ -39,6 +41,8 @@ class VQASerTokenChunk(object):
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
if len(encoded_inputs_all) == 0:
return None
return encoded_inputs_all[0]
......@@ -101,17 +105,18 @@ class VQAReTokenChunk(object):
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
if len(item['entities']) > 0:
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
if len(encoded_inputs_all) == 0:
return None
return encoded_inputs_all[0]
def reformat(self, data):
new_data = {}
new_data = defaultdict(list)
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
......@@ -45,8 +45,11 @@ def build_backbone(config, model_type):
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
support_dict = [
"LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
"LayoutXLMForSer", 'LayoutXLMForRe'
]
else:
raise NotImplementedError
......
......@@ -21,12 +21,14 @@ from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
pretrained_model_dict = {
LayoutXLMModel: 'layoutxlm-base-uncased',
LayoutLMModel: 'layoutlm-base-uncased'
LayoutLMModel: 'layoutlm-base-uncased',
LayoutLMv2Model: 'layoutlmv2-base-uncased'
}
......@@ -58,12 +60,34 @@ class NLPBaseModel(nn.Layer):
self.out_channels = 1
class LayoutXLMForSer(NLPBaseModel):
class LayoutLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
pretrained,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
output_hidden_states=False)
return x
class LayoutLMv2ForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model,
LayoutLMv2ForTokenClassification,
'ser',
pretrained,
checkpoints,
......@@ -82,12 +106,12 @@ class LayoutXLMForSer(NLPBaseModel):
return x[0]
class LayoutLMForSer(NLPBaseModel):
class LayoutXLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
pretrained,
checkpoints,
......@@ -97,10 +121,33 @@ class LayoutLMForSer(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[2],
image=x[3],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
output_hidden_states=False)
head_mask=None,
labels=None)
return x[0]
class LayoutLMv2ForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
LayoutLMv2ForRelationExtraction,
're', pretrained, checkpoints)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[1],
labels=None,
image=x[2],
attention_mask=x[3],
token_type_ids=x[4],
position_ids=None,
head_mask=None,
entities=x[5],
relations=x[6])
return x
......
......@@ -24,6 +24,8 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutLMv2 | RE | 0.6777 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
| LayoutLMv2 | SER | 0.8544 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册