Commit 1bcbd318 authored by 文幕地方

add layoutlmv2

Parent 6fe387ce
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2048
  infer_img: doc/vqa/input/zh_val_21.jpg
  save_res_path: ./output/re/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForRe
    pretrained: True
    checkpoints:

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    name: Piecewise
    values: [0.000005, 0.00005]
    decay_epochs: [10]
    warmup_epoch: 0
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
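A quick aside on the &anchor/*alias pairs in this config: they resolve at YAML load time, so algorithm, class_path and max_seq_len are defined once and reused across sections. A minimal sketch with PyYAML; the file name is a hypothetical local copy of the config above:

import yaml

# Sketch: verify that the &anchor/*alias pairs resolve to shared values.
cfg = yaml.safe_load(open('re_layoutlmv2.yml'))  # hypothetical file name
assert cfg['Architecture']['algorithm'] == 'LayoutLMv2'

# VQATokenLabelEncode is the second entry of the Train transforms list; its
# 'algorithm' field was written as *algorithm, so it loads as the same string.
encode = cfg['Train']['dataset']['transforms'][1]['VQATokenLabelEncode']
assert encode['algorithm'] == cfg['Architecture']['algorithm']

# max_seq_len is anchored in VQATokenPad and aliased in VQAReTokenChunk.
pad = cfg['Train']['dataset']['transforms'][2]['VQATokenPad']
chunk = cfg['Train']['dataset']['transforms'][4]['VQAReTokenChunk']
assert pad['max_seq_len'] == chunk['max_seq_len'] == 512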
@@ -34,7 +34,10 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    learning_rate: 0.00005
+    name: Piecewise
+    values: [0.000005, 0.00005]
+    decay_epochs: [10]
+    warmup_epoch: 0
   regularizer:
     name: L2
     factor: 0.00000
@@ -81,7 +84,7 @@ Train:
     shuffle: True
     drop_last: False
     batch_size_per_card: 8
-    num_workers: 4
+    num_workers: 8
     collate_fn: ListCollator

 Eval:
@@ -118,5 +121,5 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 8
-    num_workers: 4
+    num_workers: 8
     collate_fn: ListCollator
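The first hunk above swaps the fixed 5e-5 learning rate for the same two-stage Piecewise schedule the new config uses. A minimal sketch of what that schedule does, assuming it maps onto Paddle's paddle.optimizer.lr.PiecewiseDecay with decay_epochs as epoch-level boundaries:

import paddle

# Sketch of the Piecewise schedule above: 5e-6 for epochs 0-9, then 5e-5
# from epoch 10 onward. warmup_epoch: 0 means no warmup ramp in front of it.
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[10], values=[0.000005, 0.00005])

for epoch in range(12):
    print(epoch, scheduler.get_lr())  # 5e-06 until epoch 9, 5e-05 afterwards
    scheduler.step()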
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: doc/vqa/input/zh_val_0.jpg
  save_res_path: ./output/ser/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
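One detail worth noting between the two configs: the SER file normalizes with scale: 1 and pixel-scale mean/std, while the RE file uses scale: 1./255. with fractional ImageNet statistics. The two parameterizations are the same transform, as this small numpy check shows:

import numpy as np

# (x / 255 - m) / s == (x - 255*m) / (255*s); the SER config just bakes the
# 255 factor into the mean/std (0.485*255 = 123.675, 0.229*255 = 58.395, ...).
x = np.random.randint(0, 256, size=(224, 224, 3)).astype('float32')

re_style = (x / 255. - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
ser_style = (x - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])

assert np.allclose(re_style, ser_style, atol=1e-5)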
@@ -799,7 +799,7 @@ class VQATokenLabelEncode(object):
                  ocr_engine=None,
                  **kwargs):
         super(VQATokenLabelEncode, self).__init__()
-        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
+        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
         from ppocr.utils.utility import load_vqa_bio_label_maps
         tokenizer_dict = {
             'LayoutXLM': {
@@ -809,6 +809,10 @@ class VQATokenLabelEncode(object):
             'LayoutLM': {
                 'class': LayoutLMTokenizer,
                 'pretrained_model': 'layoutlm-base-uncased'
+            },
+            'LayoutLMv2': {
+                'class': LayoutLMv2Tokenizer,
+                'pretrained_model': 'layoutlmv2-base-uncased'
             }
         }
         self.contains_re = contains_re
...
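With the new entry in place, the config's algorithm field selects both the tokenizer class and the checkpoint name from tokenizer_dict. A minimal sketch of what the 'LayoutLMv2' entry resolves to at runtime, assuming paddlenlp's usual from_pretrained loading:

from paddlenlp.transformers import LayoutLMv2Tokenizer

# Sketch: tokenizer_dict['LayoutLMv2'] maps to this class/checkpoint pair.
tokenizer = LayoutLMv2Tokenizer.from_pretrained('layoutlmv2-base-uncased')
tokens = tokenizer.tokenize('invoice number')  # WordPiece tokens, as with LayoutLM v1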
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
+

 class VQASerTokenChunk(object):
     def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
@@ -39,6 +41,8 @@ class VQASerTokenChunk(object):
                     encoded_inputs_example[key] = data[key]

             encoded_inputs_all.append(encoded_inputs_example)
+        if len(encoded_inputs_all) == 0:
+            return None
         return encoded_inputs_all[0]

@@ -101,17 +105,18 @@ class VQAReTokenChunk(object):
                 "entities": self.reformat(entities_in_this_span),
                 "relations": self.reformat(relations_in_this_span),
             })
-            item['entities']['label'] = [
-                self.entities_labels[x] for x in item['entities']['label']
-            ]
-            encoded_inputs_all.append(item)
+            if len(item['entities']) > 0:
+                item['entities']['label'] = [
+                    self.entities_labels[x] for x in item['entities']['label']
+                ]
+                encoded_inputs_all.append(item)
+        if len(encoded_inputs_all) == 0:
+            return None
         return encoded_inputs_all[0]

     def reformat(self, data):
-        new_data = {}
+        new_data = defaultdict(list)
         for item in data:
             for k, v in item.items():
-                if k not in new_data:
-                    new_data[k] = []
                 new_data[k].append(v)
         return new_data
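For clarity, the defaultdict rewrite of reformat() keeps the same behavior, minus the manual key initialization: it pivots a list of per-entity dicts into a dict of per-key lists, which is what the 'label' remapping above indexes into. A standalone sketch:

from collections import defaultdict

def reformat(data):
    # Pivot [{'start': 0, 'label': 'QUESTION'}, ...] into
    # {'start': [0, ...], 'label': ['QUESTION', ...]}.
    new_data = defaultdict(list)
    for item in data:
        for k, v in item.items():
            new_data[k].append(v)
    return new_data

entities = [{'start': 0, 'end': 4, 'label': 'QUESTION'},
            {'start': 5, 'end': 9, 'label': 'ANSWER'}]
print(dict(reformat(entities)))
# {'start': [0, 5], 'end': [4, 9], 'label': ['QUESTION', 'ANSWER']}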
@@ -45,8 +45,11 @@ def build_backbone(config, model_type):
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
     elif model_type == 'vqa':
-        from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
-        support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
+        from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
+        support_dict = [
+            "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
+            "LayoutXLMForSer", 'LayoutXLMForRe'
+        ]
     else:
         raise NotImplementedError
...
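Extending support_dict is enough here because the rest of build_backbone resolves the class from the config's name field. A sketch of that dispatch pattern, with the exact tail of the function assumed to follow PaddleOCR's usual name-lookup factory convention:

def build_from_name(config, support_dict, namespace):
    # Resolve config['name'] (e.g. 'LayoutLMv2ForRe') to a class and
    # instantiate it with the remaining keys as keyword arguments.
    module_name = config.pop('name')
    assert module_name in support_dict, \
        'backbone only supports {}'.format(support_dict)
    return namespace[module_name](**config)

# e.g. build_from_name({'name': 'LayoutLMv2ForRe', 'pretrained': True,
#                       'checkpoints': None}, support_dict, globals())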
@@ -21,12 +21,14 @@ from paddle import nn

 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
 from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
+from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction

 __all__ = ["LayoutXLMForSer", 'LayoutLMForSer']

 pretrained_model_dict = {
     LayoutXLMModel: 'layoutxlm-base-uncased',
-    LayoutLMModel: 'layoutlm-base-uncased'
+    LayoutLMModel: 'layoutlm-base-uncased',
+    LayoutLMv2Model: 'layoutlmv2-base-uncased'
 }
@@ -58,12 +60,34 @@ class NLPBaseModel(nn.Layer):
         self.out_channels = 1


-class LayoutXLMForSer(NLPBaseModel):
+class LayoutLMForSer(NLPBaseModel):
     def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
-        super(LayoutXLMForSer, self).__init__(
-            LayoutXLMModel,
-            LayoutXLMForTokenClassification,
+        super(LayoutLMForSer, self).__init__(
+            LayoutLMModel,
+            LayoutLMForTokenClassification,
+            'ser',
+            pretrained,
+            checkpoints,
+            num_classes=num_classes)
+
+    def forward(self, x):
+        x = self.model(
+            input_ids=x[0],
+            bbox=x[2],
+            attention_mask=x[4],
+            token_type_ids=x[5],
+            position_ids=None,
+            output_hidden_states=False)
+        return x
+
+
+class LayoutLMv2ForSer(NLPBaseModel):
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+                 **kwargs):
+        super(LayoutLMv2ForSer, self).__init__(
+            LayoutLMv2Model,
+            LayoutLMv2ForTokenClassification,
             'ser',
             pretrained,
             checkpoints,
@@ -82,12 +106,12 @@ class LayoutXLMForSer(NLPBaseModel):
         return x[0]


-class LayoutLMForSer(NLPBaseModel):
+class LayoutXLMForSer(NLPBaseModel):
     def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
-        super(LayoutLMForSer, self).__init__(
-            LayoutLMModel,
-            LayoutLMForTokenClassification,
+        super(LayoutXLMForSer, self).__init__(
+            LayoutXLMModel,
+            LayoutXLMForTokenClassification,
             'ser',
             pretrained,
             checkpoints,
@@ -97,10 +121,33 @@ class LayoutLMForSer(NLPBaseModel):
         x = self.model(
             input_ids=x[0],
             bbox=x[2],
+            image=x[3],
             attention_mask=x[4],
             token_type_ids=x[5],
             position_ids=None,
-            output_hidden_states=False)
+            head_mask=None,
+            labels=None)
         return x[0]
+
+
+class LayoutLMv2ForRe(NLPBaseModel):
+    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+        super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
+                                              LayoutLMv2ForRelationExtraction,
+                                              're', pretrained, checkpoints)
+
+    def forward(self, x):
+        x = self.model(
+            input_ids=x[0],
+            bbox=x[1],
+            labels=None,
+            image=x[2],
+            attention_mask=x[3],
+            token_type_ids=x[4],
+            position_ids=None,
+            head_mask=None,
+            entities=x[5],
+            relations=x[6])
+        return x
...
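The positional indices in the forward() methods above come straight from the KeepKeys order declared in the configs; spelled out as a sketch:

# SER batch order, from the ser_layoutlmv2 config's keep_keys:
SER_KEYS = ['input_ids', 'labels', 'bbox', 'image', 'attention_mask',
            'token_type_ids']
# -> the SER forward reads x[0]=input_ids, x[2]=bbox, x[3]=image,
#    x[4]=attention_mask, x[5]=token_type_ids; labels (x[1]) go to the loss.

# RE batch order, from the re_layoutlmv2 config's keep_keys:
RE_KEYS = ['input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids',
           'entities', 'relations']
# -> LayoutLMv2ForRe.forward reads x[1]=bbox, x[2]=image, x[5]=entities,
#    x[6]=relations, matching the call above.

def named(batch, keys):
    # Convenience sketch: zip a collated positional batch back to names.
    return dict(zip(keys, batch))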
@@ -24,6 +24,8 @@ The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing library
 |:---:|:---:|:---:| :---:|
 | LayoutXLM | RE | 0.7483 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
 | LayoutXLM | SER | 0.9038 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+| LayoutLMv2 | RE | 0.6777 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
+| LayoutLMv2 | SER | 0.8544 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
 | LayoutLM | SER | 0.7731 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
...