提交 a323fce6 编写于 作者: 文幕地方's avatar 文幕地方

vqa code integrated into ppocr training system

上级 1ded2ac4
Global:
use_gpu: True
epoch_num: 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 38 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutxlm-base-uncased
save_inference_dir:
use_visualdl: False
infer_img: ppstructure/vqa/images/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained_model: *pretrained_model
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
regularizer:
name: Const
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutlm-base-uncased
save_inference_dir:
use_visualdl: False
infer_img: ppstructure/vqa/images/input/zh_val_0.jpg
save_res_path: ./output/ser/predicts_layoutlm.txt
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained_model: *pretrained_model
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: Const
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
pretrained_model: &pretrained_model layoutxlm-base-uncased
save_inference_dir:
use_visualdl: False
infer_img: ppstructure/vqa/images/input/zh_val_42.jpg
save_res_path: ./output/ser
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained_model: *pretrained_model
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: Const
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -86,13 +86,19 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -86,13 +86,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
if 'collate_fn' in loader_config:
from . import collate_fn
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
else:
collate_fn = None
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
places=device, places=device,
num_workers=num_workers, num_workers=num_workers,
return_list=True, return_list=True,
use_shared_memory=use_shared_memory) use_shared_memory=use_shared_memory,
collate_fn=collate_fn)
# support exit using ctrl+c # support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGINT, term_mp)
......
...@@ -15,20 +15,19 @@ ...@@ -15,20 +15,19 @@
import paddle import paddle
import numbers import numbers
import numpy as np import numpy as np
from collections import defaultdict
class DataCollator: class DictCollator(object):
""" """
data batch data batch
""" """
def __call__(self, batch): def __call__(self, batch):
data_dict = {} data_dict = defaultdict(list)
to_tensor_keys = [] to_tensor_keys = []
for sample in batch: for sample in batch:
for k, v in sample.items(): for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys: if k not in to_tensor_keys:
to_tensor_keys.append(k) to_tensor_keys.append(k)
...@@ -36,3 +35,22 @@ class DataCollator: ...@@ -36,3 +35,22 @@ class DataCollator:
for k in to_tensor_keys: for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k]) data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict return data_dict
class ListCollator(object):
"""
data batch
"""
def __call__(self, batch):
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
for idx, v in enumerate(sample):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
...@@ -34,6 +34,8 @@ from .sast_process import * ...@@ -34,6 +34,8 @@ from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import * from .gen_table_mask import *
from .vqa import *
def transform(data, ops=None): def transform(data, ops=None):
""" transform """ """ transform """
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import copy
import numpy as np import numpy as np
import string import string
from shapely.geometry import LineString, Point, Polygon from shapely.geometry import LineString, Point, Polygon
...@@ -736,7 +737,7 @@ class TableLabelEncode(object): ...@@ -736,7 +737,7 @@ class TableLabelEncode(object):
% beg_or_end % beg_or_end
else: else:
assert False, "Unsupport type %s in char_or_elem" \ assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem % char_or_elem
return idx return idx
...@@ -782,3 +783,208 @@ class SARLabelEncode(BaseRecLabelEncode): ...@@ -782,3 +783,208 @@ class SARLabelEncode(BaseRecLabelEncode):
def get_ignored_tokens(self): def get_ignored_tokens(self):
return [self.padding_idx] return [self.padding_idx]
class VQATokenLabelEncode(object):
"""
基于NLP的标签编码
"""
def __init__(self,
class_path,
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
infer_mode=False,
ocr_engine=None,
**kwargs):
super(VQATokenLabelEncode, self).__init__()
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
from ppocr.utils.utility import load_vqa_bio_label_maps
tokenizer_dict = {
'LayoutXLM': {
'class': LayoutXLMTokenizer,
'pretrained_model': 'layoutxlm-base-uncased'
},
'LayoutLM': {
'class': LayoutLMTokenizer,
'pretrained_model': 'layoutlm-base-uncased'
}
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config['class'].from_pretrained(
tokenizer_config['pretrained_model'])
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
def __call__(self, data):
if self.infer_mode == False:
return self._train(data)
else:
return self._infer(data)
def _train(self, data):
info = data['label']
# read text info
info_dict = json.loads(info)
height = info_dict["height"]
width = info_dict["width"]
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
gt_label_list = []
if self.contains_re:
# for re
entities = []
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
for info in info_dict["ocr_info"]:
if self.contains_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# x1, y1, x2, y2
bbox = info["bbox"]
label = info["label"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
gt_label = []
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
if self.contains_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
gt_label_list.extend(gt_label)
words_list.append(text)
encoded_inputs = {
"input_ids": input_ids_list,
"labels": gt_label_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
}
data.update(encoded_inputs)
data['tokenizer_params'] = dict(
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
if self.contains_re:
data['entities'] = entities
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _infer(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
height, width, _ = data['image'].shape
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
entities = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = copy.deepcopy(info["bbox"])
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
# for re
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": "O",
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
"entities": entities,
'labels': None,
'segment_offset_id': segment_offset_id,
'ocr_info': ocr_info
}
data.update(encoded_inputs)
return data
...@@ -170,17 +170,19 @@ class Resize(object): ...@@ -170,17 +170,19 @@ class Resize(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
text_polys = data['polys'] if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img) img_resize, [ratio_h, ratio_w] = self.resize_image(img)
new_boxes = [] if 'polys' in data:
for box in text_polys: new_boxes = []
new_box = [] for box in text_polys:
for cord in box: new_box = []
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) for cord in box:
new_boxes.append(new_box) new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize data['image'] = img_resize
data['polys'] = np.array(new_boxes, dtype=np.float32)
return data return data
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
encoded_inputs_all = []
seq_len = len(data['input_ids'])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
encoded_inputs_example[key] = data[key]
else:
encoded_inputs_example[key] = data[key][chunk_beg:
chunk_end]
else:
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all[0]
class VQAReTokenChunk(object):
def __init__(self,
max_seq_len=512,
entities_labels=None,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.entities_labels = {
'HEADER': 0,
'QUESTION': 1,
'ANSWER': 2
} if entities_labels is None else entities_labels
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
entities = data.pop('entities')
relations = data.pop('relations')
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
item[key] = data[key]
else:
item[key] = data[key][index:index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + self.max_seq_len and
index <= entity["end"] < index + self.max_seq_len):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + self.max_seq_len
and index <= relation["end_index"] <
index + self.max_seq_len):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all[0]
def reformat(self, data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
class VQATokenPad(object):
def __init__(self,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids
self.truncation_strategy = truncation_strategy
self.return_overflowing_tokens = return_overflowing_tokens
self.return_special_tokens_mask = return_special_tokens_mask
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.infer_mode = infer_mode
def __call__(self, data):
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
"input_ids"]) < self.max_seq_len
if needs_to_be_padded:
if 'tokenizer_params' in data:
tokenizer_params = data.pop('tokenizer_params')
else:
tokenizer_params = dict(
padding_side='right', pad_token_type_id=0, pad_token_id=1)
difference = self.max_seq_len - len(data["input_ids"])
if tokenizer_params['padding_side'] == 'right':
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data[
"input_ids"]) + [0] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
data["token_type_ids"] +
[tokenizer_params['pad_token_type_id']] * difference)
if self.return_special_tokens_mask:
data["special_tokens_mask"] = data[
"special_tokens_mask"] + [1] * difference
data["input_ids"] = data["input_ids"] + [
tokenizer_params['pad_token_id']
] * difference
if not self.infer_mode:
data["labels"] = data[
"labels"] + [self.pad_token_label_id] * difference
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
elif tokenizer_params['padding_side'] == 'left':
if self.return_attention_mask:
data["attention_mask"] = [0] * difference + [
1
] * len(data["input_ids"])
if self.return_token_type_ids:
data["token_type_ids"] = (
[tokenizer_params['pad_token_type_id']] * difference +
data["token_type_ids"])
if self.return_special_tokens_mask:
data["special_tokens_mask"] = [
1
] * difference + data["special_tokens_mask"]
data["input_ids"] = [tokenizer_params['pad_token_id']
] * difference + data["input_ids"]
if not self.infer_mode:
data["labels"] = [self.pad_token_label_id
] * difference + data["labels"]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data["input_ids"])
for key in data:
if key in [
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode and key == 'labels':
continue
length = min(len(data[key]), self.max_seq_len)
data[key] = np.array(data[key][:length], dtype='int64')
return data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQAReTokenRelation(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
"""
build relations
"""
entities = data['entities']
relations = data['relations']
id2label = data.pop('id2label')
empty_entity = data.pop('empty_entity')
entity_id_to_index_map = data.pop('entity_id_to_index_map')
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": self.get_relation_span(rel, entities)[0],
"end_index": self.get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
data['relations'] = relations
return data
def get_relation_span(self, rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
...@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset): ...@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
) == data_source_num, "The length of ratio_list should be the same as the file_list." ) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
...@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset): ...@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list): def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
...@@ -95,7 +96,7 @@ class SimpleDataSet(Dataset): ...@@ -95,7 +96,7 @@ class SimpleDataSet(Dataset):
data['image'] = img data['image'] = img
data = transform(data, load_data_ops) data = transform(data, load_data_ops)
if data is None or data['polys'].shape[1]!=4: if data is None or data['polys'].shape[1] != 4:
continue continue
ext_data.append(data) ext_data.append(data)
return ext_data return ext_data
...@@ -121,7 +122,7 @@ class SimpleDataSet(Dataset): ...@@ -121,7 +122,7 @@ class SimpleDataSet(Dataset):
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data_line, traceback.format_exc())) data_line, traceback.format_exc()))
outs = None # outs = None
if outs is None: if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation. # during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__( rnd_idx = np.random.randint(self.__len__(
......
...@@ -16,6 +16,9 @@ import copy ...@@ -16,6 +16,9 @@ import copy
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
# basic_loss
from .basic_loss import LossFromOutput
# det loss # det loss
from .det_db_loss import DBLoss from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss from .det_east_loss import EASTLoss
...@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss ...@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss from .table_att_loss import TableAttentionLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config): def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss' 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer): ...@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def forward(self, x, y): def forward(self, x, y):
return self.loss_func(x, y) return self.loss_func(x, y)
class LossFromOutput(nn.Layer):
def __init__(self, key='loss', reduction='none'):
super().__init__()
self.key = key
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,24 +12,31 @@ ...@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn from paddle import nn
class SERLoss(nn.Layer): class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes): def __init__(self, num_classes):
super().__init__() super().__init__()
self.loss_class = nn.CrossEntropyLoss() self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index self.ignore_index = self.loss_class.ignore_index
def forward(self, labels, outputs, attention_mask): def forward(self, predicts, batch):
labels = batch[1]
attention_mask = batch[4]
if attention_mask is not None: if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1 active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = outputs.reshape( active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss] [-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss] active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels) loss = self.loss_class(active_outputs, active_labels)
else: else:
loss = self.loss_class( loss = self.loss_class(
outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ])) predicts.reshape([-1, self.num_classes]),
return loss labels.reshape([-1, ]))
return {'loss': loss}
...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric ...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric from .distillation_metric import DistillationMetric
from .table_metric import TableMetric from .table_metric import TableMetric
from .kie_metric import KIEMetric from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric' "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from seqeval.metrics import f1_score, precision_score, recall_score
__all__ = ['KIEMetric']
class VQAReTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
pred_relations, relations, entities = preds
self.pred_relations_list.extend(pred_relations)
self.relations_list.extend(relations)
self.entities_list.extend(entities)
def get_metric(self):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
for head, tail in zip(self.relations_list[b]["head"],
self.relations_list[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries")
metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"hmean": re_metrics["ALL"]["f1"],
}
self.reset()
return metrics
def reset(self):
self.pred_relations_list = []
self.relations_list = []
self.entities_list = []
def re_score(self, pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
return scores
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from seqeval.metrics import f1_score, precision_score, recall_score
__all__ = ['KIEMetric']
class VQASerTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
preds, labels = preds
self.pred_list.extend(preds)
self.gt_list.extend(labels)
def get_metric(self):
metircs = {
"precision": precision_score(self.gt_list, self.pred_list),
"recall": recall_score(self.gt_list, self.pred_list),
"hmean": f1_score(self.gt_list, self.pred_list),
}
self.reset()
return metircs
def reset(self):
self.pred_list = []
self.gt_list = []
...@@ -63,8 +63,12 @@ class BaseModel(nn.Layer): ...@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
in_channels = self.neck.out_channels in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls # # build head, head is need for det, rec and cls
config["Head"]['in_channels'] = in_channels if 'Head' not in config or config['Head'] is None:
self.head = build_head(config["Head"]) self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False) self.return_all_feats = config.get("return_all_feats", False)
...@@ -77,7 +81,8 @@ class BaseModel(nn.Layer): ...@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
x = self.head(x, targets=data) if self.use_head:
x = self.head(x, targets=data)
if isinstance(x, dict): if isinstance(x, dict):
y.update(x) y.update(x)
else: else:
......
...@@ -43,6 +43,9 @@ def build_backbone(config, model_type): ...@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3 from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"] support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
else: else:
raise NotImplementedError raise NotImplementedError
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
type='ser',
pretrained_model=None,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None"
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
else:
base_model = base_model_class.from_pretrained(pretrained_model)
if type == 'ser':
self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
class LayoutXLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
pretrained_model,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
image=x[3],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
head_mask=None,
labels=None)
return x[0]
class LayoutLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
pretrained_model,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
output_hidden_states=False)
return x
class LayoutXLMForRe(NLPBaseModel):
def __init__(self,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutXLMForRe, self).__init__(
LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
pretrained_model, checkpoints)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[1],
labels=None,
image=x[2],
attention_mask=x[3],
token_type_ids=x[4],
position_ids=None,
head_mask=None,
entities=x[5],
relations=x[6])
return x
...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization # step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None: if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer') reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay' reg_name = reg_config.pop('name')
if not hasattr(regularizer, reg_name):
reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)() reg = getattr(regularizer, reg_name)(**reg_config)()
else: else:
reg = None reg = None
......
...@@ -158,3 +158,38 @@ class Adadelta(object): ...@@ -158,3 +158,38 @@ class Adadelta(object):
name=self.name, name=self.name,
parameters=parameters) parameters=parameters)
return opt return opt
class AdamW(object):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
weight_decay=0.01,
grad_clip=None,
name=None,
lazy_mode=False,
**kwargs):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.learning_rate = learning_rate
self.weight_decay = 0.01 if weight_decay is None else weight_decay
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
def __call__(self, parameters):
opt = optim.AdamW(
learning_rate=self.learning_rate,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
parameters=parameters)
return opt
...@@ -50,3 +50,18 @@ class L2Decay(object): ...@@ -50,3 +50,18 @@ class L2Decay(object):
def __call__(self): def __call__(self):
reg = paddle.regularizer.L2Decay(self.regularization_coeff) reg = paddle.regularizer.L2Decay(self.regularization_coeff)
return reg return reg
class ConstDecay(object):
"""
Const L2 Weight Decay Regularization, which encourages the weights to be sparse.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def __init__(self, factor=0.0):
super(ConstDecay, self).__init__()
self.regularization_coeff = factor
def __call__(self):
return self.regularization_coeff
...@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di ...@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None): ...@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode' 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class VQAReTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, **kwargs):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
if label is not None:
return self._metric(preds, label)
else:
return self._infer(preds, *args, **kwargs)
def _metric(self, preds, label):
return preds['pred_relations'], label[6], label[5]
def _infer(self, preds, *args, **kwargs):
ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
pred_relations = preds['pred_relations']
# 进行 relations 到 ocr信息的转换
results = []
for pred_relation, ser_result, entity_idx_dict in zip(
pred_relations, ser_results, entity_idx_dict_batch):
result = []
used_tail_id = []
for relation in pred_relation:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps
class VQASerTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, class_path, **kwargs):
super(VQASerTokenLayoutLMPostProcess, self).__init__()
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
self.id2label_map_for_show = dict()
for key in self.label2id_map_for_draw:
val = self.label2id_map_for_draw[key]
if key == "O":
self.id2label_map_for_show[val] = key
if key.startswith("B-") or key.startswith("I-"):
self.id2label_map_for_show[val] = key[2:]
else:
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
return self._metric(preds, batch[1])
else:
return self._infer(preds, **kwargs)
def _metric(self, preds, label):
pred_idxs = preds.argmax(axis=2)
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
for i in range(pred_idxs.shape[0]):
for j in range(pred_idxs.shape[1]):
if label[i, j] != -100:
label_decode_out_list[i].append(self.id2label_map[label[i,
j]])
decode_out_list[i].append(self.id2label_map[pred_idxs[i,
j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip(
preds, attention_masks, segment_offset_ids, ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = pred[start_id:end_id]
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): ...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def load_model(config, model, optimizer=None): def load_model(config, model, optimizer=None, model_type='det'):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
...@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None): ...@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
checkpoints = global_config.get('checkpoints') checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if model_type == 'vqa':
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
if os.path.exists(os.path.join(checkpoints, 'metric.states')):
with open(os.path.join(checkpoints, 'metric.states'),
'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
if optimizer is not None:
if checkpoints[-1] in ['/', '\\']:
checkpoints = checkpoints[:-1]
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
return best_model_dict
if checkpoints: if checkpoints:
if checkpoints.endswith('.pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
...@@ -127,6 +154,7 @@ def save_model(model, ...@@ -127,6 +154,7 @@ def save_model(model,
optimizer, optimizer,
model_path, model_path,
logger, logger,
config,
is_best=False, is_best=False,
prefix='ppocr', prefix='ppocr',
**kwargs): **kwargs):
...@@ -135,13 +163,20 @@ def save_model(model, ...@@ -135,13 +163,20 @@ def save_model(model,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
paddle.save(model.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
else:
if config['Global']['distributed']:
model._layers.backbone.model.save_pretrained(model_prefix)
else:
model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config # save metric and config
with open(model_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
if is_best: if is_best:
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix)) logger.info('save best model is to {}'.format(model_prefix))
else: else:
logger.info("save model in {}".format(model_prefix)) logger.info("save model in {}".format(model_prefix))
...@@ -77,4 +77,22 @@ def check_and_read_gif(img_path): ...@@ -77,4 +77,22 @@ def check_and_read_gif(img_path):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1] imgvalue = frame[:, :, ::-1]
return imgvalue, True return imgvalue, True
return None, False return None, False
\ No newline at end of file
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
color_map = {
idx: (color[0][idx], color[1][idx], color[2][idx])
for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
|模型名称|模型简介|推理模型大小|下载地址| |模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- | | --- | --- | --- | --- |
|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) | |PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) | |PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
## 3. KIE模型 ## 3. KIE模型
......
...@@ -20,11 +20,11 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进 ...@@ -20,11 +20,11 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下 我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
| 模型 | 任务 | f1 | 模型下载地址 | | 模型 | 任务 | hmean | 模型下载地址 |
|:---:|:---:|:---:| :---:| |:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7113 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) | | LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | SER | 0.9056 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) | | LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutLM | SER | 0.78 | [链接](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) | | LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
...@@ -65,10 +65,10 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进 ...@@ -65,10 +65,10 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
pip3 install --upgrade pip pip3 install --upgrade pip
# GPU安装 # GPU安装
python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# CPU安装 # CPU安装
python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
``` ```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
...@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple ...@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
- **(1)pip快速安装PaddleOCR whl包(仅预测)** - **(1)pip快速安装PaddleOCR whl包(仅预测)**
```bash ```bash
pip install paddleocr python3 -m pip install paddleocr
``` ```
- **(2)下载VQA源码(预测+训练)** - **(2)下载VQA源码(预测+训练)**
...@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR ...@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。 # 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
``` ```
- **(3)安装PaddleNLP** - **(3)安装VQA的`requirements`**
```bash ```bash
pip3 install "paddlenlp>=2.2.1" python3 -m pip install -r ppstructure/vqa/requirements.txt
```
- **(4)安装VQA的`requirements`**
```bash
cd ppstructure/vqa
pip install -r requirements.txt
``` ```
## 4. 使用 ## 4. 使用
...@@ -112,6 +104,10 @@ pip install -r requirements.txt ...@@ -112,6 +104,10 @@ pip install -r requirements.txt
### 4.1 数据和预训练模型准备 ### 4.1 数据和预训练模型准备
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
* 下载处理好的数据集
处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar) 处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
...@@ -121,98 +117,62 @@ pip install -r requirements.txt ...@@ -121,98 +117,62 @@ pip install -r requirements.txt
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
``` ```
如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py) * 转换数据集
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。 若需进行其他XFUN数据集的训练,可使用下面的命令进行数据集的转换
```bash
python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
```
### 4.2 SER任务 ### 4.2 SER任务
* 启动训练 启动训练之前,需要修改下面的四个字段
1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
* 启动训练
```shell ```shell
python3.7 train_ser.py \ CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
--model_name_or_path "layoutxlm-base-uncased" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--evaluate_during_training \
--seed 2048
``` ```
最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/ser/`文件夹中。 最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/ser_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
* 恢复训练 * 恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell ```shell
python3.7 train_ser.py \ CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
--model_name_or_path "model_path" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--evaluate_during_training \
--num_workers 8 \
--seed 2048 \
--resume
``` ```
* 评估 * 评估
```shell
export CUDA_VISIBLE_DEVICES=0
python3 eval_ser.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--output_dir "output/ser/" \
--seed 2048
```
最终会打印出`precision`, `recall`, `f1`等指标
* 使用评估集合中提供的OCR识别结果进行预测 评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
python3.7 infer_ser.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--output_dir "output/ser/" \
--infer_imgs "XFUND/zh_val/image/" \
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标
最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt` * 使用`OCR引擎 + SER`串联预测
* 使用`OCR引擎 + SER`串联结果 使用如下命令即可完成`OCR引擎 + SER`的串联预测
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_42.jpg
python3.7 infer_ser_e2e.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output/ser_e2e/" \
--infer_imgs "images/input/zh_val_0.jpg"
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
*`OCR引擎 + SER`预测系统进行端到端评估 *`OCR引擎 + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
...@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor ...@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
* 启动训练 * 启动训练
```shell 启动训练之前,需要修改下面的四个字段
export CUDA_VISIBLE_DEVICES=0
python3 train_re.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \
--seed 2048
```
* 恢复训练 1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
python3 train_re.py \
--model_name_or_path "model_path" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--num_train_epochs 2 \
--eval_steps 10 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \
--seed 2048 \
--resume
``` ```
最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/re/`文件夹中。 最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/re_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
* 恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
* 评估
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
python3 eval_re.py \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--seed 2048
``` ```
最终会打印出`precision`, `recall`, `f1`等指标
* 评估
* 使用评估集合中提供的OCR识别结果进行预测 评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
python3 infer_re.py \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 1 \
--seed 2048
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标
最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt` * 使用`OCR引擎 + SER + RE`串联预测
* 使用`OCR引擎 + SER + RE`串联结果
使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_re_e2e.py \ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output/ser_re_e2e/" \
--infer_imgs "images/input/zh_val_21.jpg"
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
## 参考链接 ## 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
from ppocr.utils.logging import get_logger
def cal_metric(re_preds, re_labels, entities):
gt_relations = []
for b in range(len(re_labels)):
rel_sent = []
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (entities[b]["start"][rel["head_id"]],
entities[b]["end"][rel["head_id"]])
rel["head_type"] = entities[b]["label"][rel["head_id"]]
rel["tail_id"] = tail
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
entities[b]["end"][rel["tail_id"]])
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
return re_metrics
def evaluate(model, eval_dataloader, logger, prefix=""):
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
re_preds = []
re_labels = []
entities = []
eval_loss = 0.0
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
outputs = model(**batch)
loss = outputs['loss'].mean().item()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss))
eval_loss += loss
re_preds.extend(outputs['pred_relations'])
re_labels.extend(batch['relations'])
entities.extend(batch['entities'])
re_metrics = cal_metric(re_preds, re_labels, entities)
re_metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"f1": re_metrics["ALL"]["f1"],
}
model.train()
return re_metrics
def eval(args):
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
results = evaluate(model, eval_dataloader, logger)
logger.info("eval results: {}".format(results))
if __name__ == "__main__":
args = parse_args()
eval(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from losses import SERLoss
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def eval(args):
logger = get_logger()
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
loss_class = SERLoss(len(label2id_map))
results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
label2id_map, id2label_map, pad_token_label_id,
logger)
logger.info(results)
def evaluate(args,
model,
tokenizer,
loss_class,
eval_dataloader,
label2id_map,
id2label_map,
pad_token_label_id,
logger,
prefix=""):
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
outputs = model(**batch)
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
loss = loss.mean()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss.numpy()[0]))
eval_loss += loss.item()
nb_eval_steps += 1
if preds is None:
preds = outputs.numpy()
out_label_ids = labels.numpy()
else:
preds = np.append(preds, outputs.numpy(), axis=0)
out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = np.argmax(preds, axis=2)
# label_map = {i: label.upper() for i, label in enumerate(labels)}
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]
for i in range(out_label_ids.shape[0]):
for j in range(out_label_ids.shape[1]):
if out_label_ids[i, j] != pad_token_label_id:
out_label_list[i].append(id2label_map[out_label_ids[i][j]])
preds_list[i].append(id2label_map[preds[i][j]])
results = {
"loss": eval_loss,
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
}
with open(
os.path.join(args.output_dir, "test_gt.txt"), "w",
encoding='utf-8') as fout:
for lbl in out_label_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
with open(
os.path.join(args.output_dir, "test_pred.txt"), "w",
encoding='utf-8') as fout:
for lbl in preds_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
report = classification_report(out_label_list, preds_list)
logger.info("\n" + report)
logger.info("***** Eval results %s *****", prefix)
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))
model.train()
return results, preds_list
if __name__ == "__main__":
args = parse_args()
eval(args)
...@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None): ...@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
print("===ok====") print("===ok====")
transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json") def parser_args():
import argparse
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument(
"--ori_gt_path", type=str, required=True, help='origin xfun gt path')
parser.add_argument(
"--output_path", type=str, required=True, help='path to save')
args = parser.parse_args()
return args
args = parser_args()
transfer_xfun_data(args.ori_gt_path, args.output_path)
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
python3.7 infer_ser.py \
--model_name_or_path "output/ser_new/best_model" \
--ser_model_type "LayoutXLM" \
--output_dir "ser_new/" \
--infer_imgs "images/input/zh_val_21.jpg" \
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
# 读取gt的oct数据
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
logger.info("[Infer] process: {}/{}, save result to {}".format(
idx, len(eval_dataloader), save_img_path))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
# 进行 relations 到 ocr信息的转换
result = []
used_tail_id = []
for relations in pred_relations:
for relation in relations:
if relation['tail_id'] in used_tail_id:
continue
if relation['head_id'] not in ocr_info or relation[
'tail_id'] not in ocr_info:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
cv2.imwrite(save_img_path, img_show)
def load_ocr(img_folder, json_path):
import json
d = []
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
info_dict = json.loads(info_str)
info_dict['image_path'] = os.path.join(img_folder, image_name)
d.append(info_dict)
return d
def filter_bg_by_txt(ocr_info, batch, tokenizer):
entities = batch['entities'][0]
input_ids = batch['input_ids'][0]
new_info_dict = {}
for i in range(len(entities['start'])):
entitie_head = entities['start'][i]
entitie_tail = entities['end'][i]
word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
txt = tokenizer.convert_ids_to_tokens(word_input_ids)
txt = tokenizer.convert_tokens_to_string(txt)
for i, info in enumerate(ocr_info):
if info['text'] == txt:
new_info_dict[i] = info
return new_info_dict
def post_process(pred_relations, ocr_info, img):
result = []
for relations in pred_relations:
for relation in relations:
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
return result
def draw_re(result, image_path, output_folder):
img = cv2.imread(image_path)
from matplotlib import pyplot as plt
for ocr_info_head, ocr_info_tail in result:
cv2.rectangle(
img,
tuple(ocr_info_head['bbox'][:2]),
tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
thickness=2)
cv2.rectangle(
img,
tuple(ocr_info_tail['bbox'][:2]),
tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
thickness=2)
center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
cv2.line(
img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
plt.imshow(img)
plt.savefig(
os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
# plt.show()
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
import paddle
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding with larger size, reshape is carried out
max_seq_len = (
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
tokenizer.padding_side)
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img,
(224, 224)).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, label_map_path):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
_, label_map = get_bio_label_maps(label_map_path)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
preds_list):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
label2id_map, _ = get_bio_label_maps(label_map_path)
for key in label2id_map:
if key.startswith("I-"):
label2id_map[key] = label2id_map["B" + key[1:]]
id2label_map = dict()
for key in label2id_map:
val = label2id_map[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[pred_id]
return ocr_info
@paddle.no_grad()
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.eval()
# load ocr results json
ocr_results = dict()
with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
ocr_results[os.path.basename(img_name)] = json.loads(json_info)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(args.output_dir,
os.path.basename(img_path))
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
inputs = preprocess(
tokenizer=tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
if args.ser_model_type == 'LayoutLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif args.ser_model_type == 'LayoutXLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
args.label_map_path, ocr_info, inputs["segment_offset_id"],
preds)
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": ocr_info,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(save_img_path, img_res)
return
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def parse_ocr_info_for_ser(ocr_result):
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
class SerPredictor(object):
def __init__(self, args):
self.args = args
self.max_seq_length = args.max_seq_length
# init ser token and model
tokenizer_class, base_model_class, model_class = MODELS[
args.ser_model_type]
self.tokenizer = tokenizer_class.from_pretrained(
args.model_name_or_path)
self.model = model_class.from_pretrained(args.model_name_or_path)
self.model.eval()
# init ocr_engine
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(
rec_model_dir=args.rec_model_dir,
det_model_dir=args.det_model_dir,
use_angle_cls=False,
show_log=False)
# init dict
label2id_map, self.id2label_map = get_bio_label_maps(
args.label_map_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
def __call__(self, img):
ocr_result = self.ocr_engine.ocr(img, cls=False)
ocr_info = parse_ocr_info_for_ser(ocr_result)
inputs = preprocess(
tokenizer=self.tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=self.max_seq_length)
if self.args.ser_model_type == 'LayoutLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif self.args.ser_model_type == 'LayoutXLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
ocr_info = merge_preds_list_with_ocr_info(
ocr_info, inputs["segment_offset_id"], preds,
self.label2id_map_for_draw)
return ocr_info, inputs
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_engine = SerPredictor(args)
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
result, _ = ser_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"ser_resule": result,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, result)
cv2.imwrite(save_img_path, img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
def make_input(ser_input, ser_result, max_seq_len=512):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
entities = ser_input['entities'][0]
assert len(entities) == len(ser_result)
# entities
start = []
end = []
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_result, entities)):
if res['pred'] == 'O':
continue
entity_idx_dict[len(start)] = i
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
# relations
head = []
tail = []
for i in range(len(entities["label"])):
for j in range(len(entities["label"])):
if entities["label"][i] == 1 and entities["label"][j] == 2:
head.append(i)
tail.append(j)
relations = dict(head=head, tail=tail)
batch_size = ser_input["input_ids"].shape[0]
entities_batch = []
relations_batch = []
for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
ser_input['entities'] = entities_batch
ser_input['relations'] = relations_batch
ser_input.pop('segment_offset_id')
return ser_input, entity_idx_dict
class SerReSystem(object):
def __init__(self, args):
self.ser_engine = SerPredictor(args)
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
args.re_model_name_or_path)
self.model = LayoutXLMForRelationExtraction.from_pretrained(
args.re_model_name_or_path)
self.model.eval()
def __call__(self, img):
ser_result, ser_inputs = self.ser_engine(img)
re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
re_result = self.model(**re_input)
pred_relations = re_result['pred_relations'][0]
# 进行 relations 到 ocr信息的转换
result = []
used_tail_id = []
for relation in pred_relations:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
return result
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_re_engine = SerReSystem(args)
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
result = ser_re_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"result": result,
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
cv2.imwrite(save_img_path, img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import numpy as np
import logging
logger = logging.getLogger(__name__)
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path for path in content
if _re_checkpoint.search(path) is not None and os.path.isdir(
os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(
folder,
max(checkpoints,
key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
def re_score(pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
# logger.info(
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
# n_sents, n_rels, n_found, tp
# )
# )
# logger.info(
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
# )
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
# logger.info(
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
# )
# )
# for rel_type in relation_types:
# logger.info(
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
# rel_type,
# scores[rel_type]["tp"],
# scores[rel_type]["fp"],
# scores[rel_type]["fn"],
# scores[rel_type]["p"],
# scores[rel_type]["r"],
# scores[rel_type]["f1"],
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
# )
# )
return scores
sentencepiece sentencepiece
yacs yacs
seqeval seqeval
\ No newline at end of file paddlenlp>=2.2.1
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
from eval_re import evaluate
from ppocr.utils.logging import get_logger
def train(args):
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
rank = paddle.distributed.get_rank()
distributed = paddle.distributed.get_world_size() > 1
print_arguments(args, logger)
# Added here for reproducibility (even between python 2 and 3)
set_seed(args.seed)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
# dist mode
if distributed:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
if not args.resume:
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction(model, dropout=None)
logger.info('train from scratch')
else:
logger.info('resume from {}'.format(args.model_name_or_path))
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
# dist mode
if distributed:
model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
grad_clip=grad_clip,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = {}".format(len(train_dataset)))
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
logger.info(" Instantaneous batch size per GPU = {}".format(
args.per_gpu_train_batch_size))
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = {}".
format(args.per_gpu_train_batch_size *
paddle.distributed.get_world_size()))
logger.info(" Total optimization steps = {}".format(t_total))
global_step = 0
model.clear_gradients()
train_dataloader_len = len(train_dataloader)
best_metirc = {'f1': 0}
model.train()
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
print_step = 1
for epoch in range(int(args.num_train_epochs)):
for step, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - reader_start
train_start = time.time()
outputs = model(**batch)
train_run_cost += time.time() - train_start
# model outputs are always tuple in ppnlp (see doc)
loss = outputs['loss']
loss = loss.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
# lr_scheduler.step() # Update learning rate schedule
global_step += 1
total_samples += batch['image'].shape[0]
if rank == 0 and step % print_step == 0:
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch, args.num_train_epochs, step,
train_dataloader_len, global_step,
np.mean(loss.numpy()),
optimizer.get_lr(), train_reader_cost / print_step, (
train_reader_cost + train_run_cost) / print_step,
total_samples / print_step, total_samples / (
train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(model, eval_dataloader, logger)
if results['f1'] >= best_metirc['f1']:
best_metirc = results
output_dir = os.path.join(args.output_dir, "best_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("eval results: {}".format(results))
logger.info("best_metirc: {}".format(best_metirc))
reader_start = time.time()
if rank == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "latest_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(output_dir))
logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
train(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate
from losses import SERLoss
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def train(args):
os.makedirs(args.output_dir, exist_ok=True)
rank = paddle.distributed.get_rank()
distributed = paddle.distributed.get_world_size() > 1
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
loss_class = SERLoss(len(label2id_map))
pad_token_label_id = loss_class.ignore_index
# dist mode
if distributed:
paddle.distributed.init_parallel_env()
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
if not args.resume:
base_model = base_model_class.from_pretrained(args.model_name_or_path)
model = model_class(
base_model, num_classes=len(label2id_map), dropout=None)
logger.info('train from scratch')
else:
logger.info('resume from {}'.format(args.model_name_or_path))
model = model_class.from_pretrained(args.model_name_or_path)
# dist mode
if distributed:
model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d",
args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed) = %d",
args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss = 0.0
set_seed(args.seed)
best_metrics = None
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
print_step = 1
model.train()
for epoch_id in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - reader_start
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
train_start = time.time()
outputs = model(**batch)
train_run_cost += time.time() - train_start
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
# model outputs are always tuple in ppnlp (see doc)
loss = loss.mean()
loss.backward()
tr_loss += loss.item()
optimizer.step()
lr_scheduler.step() # Update learning rate schedule
optimizer.clear_grad()
global_step += 1
total_samples += batch['input_ids'].shape[0]
if rank == 0 and step % print_step == 0:
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch_id, args.num_train_epochs, step,
len(train_dataloader), global_step,
loss.numpy()[0],
lr_scheduler.get_lr(), train_reader_cost /
print_step, (train_reader_cost + train_run_cost) /
print_step, total_samples / print_step, total_samples
/ (train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results, _ = evaluate(args, model, tokenizer, loss_class,
eval_dataloader, label2id_map,
id2label_map, pad_token_label_id, logger)
if best_metrics is None or results["f1"] >= best_metrics["f1"]:
best_metrics = copy.deepcopy(results)
output_dir = os.path.join(args.output_dir, "best_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
epoch_id, args.num_train_epochs, step,
len(train_dataloader), results))
if best_metrics is not None:
logger.info("best metrics: {}".format(best_metrics))
reader_start = time.time()
if rank == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "latest_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(output_dir))
return global_step, tr_loss / global_step
if __name__ == "__main__":
args = parse_args()
train(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import cv2
import random
import numpy as np
import imghdr
from copy import deepcopy
import paddle
from PIL import Image, ImageDraw, ImageFont
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def get_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
def draw_ser_results(image,
ocr_results,
font_path="../../doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
color_map = {
idx: (color[0][idx], color[1][idx], color[2][idx])
for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="../../doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
# pad sentences
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding with larger size, reshape is carried out
max_seq_len = (
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
if key == 'entities':
encoded_inputs[key] = [encoded_inputs[key]]
continue
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
entities = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
# for re
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": "O",
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
"entities": entities
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, id2label_map):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(id2label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
label2id_map_for_draw):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
id2label_map = dict()
for key in label2id_map_for_draw:
val = label2id_map_for_draw[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
return ocr_info
def print_arguments(args, logger=None):
print_func = logger.info if logger is not None else print
"""print arguments"""
print_func('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print_func('%s: %s' % (arg, value))
print_func('------------------------------------------------')
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
# yapf: disable
parser.add_argument("--model_name_or_path",
default=None, type=str, required=True,)
parser.add_argument("--ser_model_type",
default='LayoutXLM', type=str)
parser.add_argument("--re_model_name_or_path",
default=None, type=str, required=False,)
parser.add_argument("--train_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--train_label_path", default=None,
type=str, required=False,)
parser.add_argument("--eval_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--eval_label_path", default=None,
type=str, required=False,)
parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",)
parser.add_argument("--num_workers", default=8, type=int,)
parser.add_argument("--per_gpu_train_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for training.",)
parser.add_argument("--per_gpu_eval_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for eval.",)
parser.add_argument("--learning_rate", default=5e-5,
type=float, help="The initial learning rate for Adam.",)
parser.add_argument("--weight_decay", default=0.0,
type=float, help="Weight decay if we apply some.",)
parser.add_argument("--adam_epsilon", default=1e-8,
type=float, help="Epsilon for Adam optimizer.",)
parser.add_argument("--max_grad_norm", default=1.0,
type=float, help="Max gradient norm.",)
parser.add_argument("--num_train_epochs", default=3, type=int,
help="Total number of training epochs to perform.",)
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.",)
parser.add_argument("--eval_steps", type=int, default=10,
help="eval every X updates steps.",)
parser.add_argument("--seed", type=int, default=2048,
help="random seed for initialization",)
parser.add_argument("--rec_model_dir", default=None, type=str, )
parser.add_argument("--det_model_dir", default=None, type=str, )
parser.add_argument(
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
parser.add_argument("--resume", action='store_true')
parser.add_argument("--ocr_json_path", default=None,
type=str, required=False, help="ocr prediction results")
# yapf: enable
args = parser.parse_args()
return args
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import cv2
import numpy as np
import paddle
import copy
from paddle.io import Dataset
__all__ = ["XFUNDataset"]
class XFUNDataset(Dataset):
"""
Example:
print("=====begin to build dataset=====")
from paddlenlp.transformers import LayoutXLMTokenizer
tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
tok_res = tokenizer.tokenize("Maribyrnong")
# res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
dataset = XfunDatasetForSer(
tokenizer,
data_dir="./zh.val/",
label_path="zh.val/xfun_normalize_val.json",
img_size=(224,224))
print(len(dataset))
data = dataset[0]
print(data.keys())
print("input_ids: ", data["input_ids"])
print("labels: ", data["labels"])
print("token_type_ids: ", data["token_type_ids"])
print("words_list: ", data["words_list"])
print("image shape: ", data["image"].shape)
"""
def __init__(self,
tokenizer,
data_dir,
label_path,
contains_re=False,
label2id_map=None,
img_size=(224, 224),
pad_token_label_id=None,
add_special_ids=False,
return_attention_mask=True,
load_mode='all',
max_seq_len=512):
super().__init__()
self.tokenizer = tokenizer
self.data_dir = data_dir
self.label_path = label_path
self.contains_re = contains_re
self.label2id_map = label2id_map
self.img_size = img_size
self.pad_token_label_id = pad_token_label_id
self.add_special_ids = add_special_ids
self.return_attention_mask = return_attention_mask
self.load_mode = load_mode
self.max_seq_len = max_seq_len
if self.pad_token_label_id is None:
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.all_lines = self.read_all_lines()
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
self.return_keys = {
'bbox': {
'type': 'np',
'dtype': 'int64'
},
'input_ids': {
'type': 'np',
'dtype': 'int64'
},
'labels': {
'type': 'np',
'dtype': 'int64'
},
'attention_mask': {
'type': 'np',
'dtype': 'int64'
},
'image': {
'type': 'np',
'dtype': 'float32'
},
'token_type_ids': {
'type': 'np',
'dtype': 'int64'
},
'entities': {
'type': 'dict'
},
'relations': {
'type': 'dict'
}
}
if load_mode == "all":
self.encoded_inputs_all = self._parse_label_file_all()
def pad_sentences(self,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if self.tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[self.tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [self.tokenizer.pad_token_id] * difference
encoded_inputs["labels"] = encoded_inputs[
"labels"] + [self.pad_token_label_id] * difference
encoded_inputs["bbox"] = encoded_inputs[
"bbox"] + [[0, 0, 0, 0]] * difference
elif self.tokenizer.padding_side == 'left':
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [
1
] * len(encoded_inputs["input_ids"])
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
[self.tokenizer.pad_token_type_id] * difference +
encoded_inputs["token_type_ids"])
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [
1
] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [
self.tokenizer.pad_token_id
] * difference + encoded_inputs["input_ids"]
encoded_inputs["labels"] = [
self.pad_token_label_id
] * difference + encoded_inputs["labels"]
encoded_inputs["bbox"] = [
[0, 0, 0, 0]
] * difference + encoded_inputs["bbox"]
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def truncate_inputs(self, encoded_inputs, max_seq_len=512):
for key in encoded_inputs:
if key == "sample_id":
continue
length = min(len(encoded_inputs[key]), max_seq_len)
encoded_inputs[key] = encoded_inputs[key][:length]
return encoded_inputs
def read_all_lines(self, ):
with open(self.label_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
return lines
def _parse_label_file_all(self):
"""
parse all samples
"""
encoded_inputs_all = []
for line in self.all_lines:
encoded_inputs_all.extend(self._parse_label_file(line))
return encoded_inputs_all
def _parse_label_file(self, line):
"""
parse single sample
"""
image_name, info_str = line.split("\t")
image_path = os.path.join(self.data_dir, image_name)
def add_imgge_path(x):
x['image_path'] = image_path
return x
encoded_inputs = self._read_encoded_inputs_sample(info_str)
if self.contains_re:
encoded_inputs = self._chunk_re(encoded_inputs)
else:
encoded_inputs = self._chunk_ser(encoded_inputs)
encoded_inputs = list(map(add_imgge_path, encoded_inputs))
return encoded_inputs
def _read_encoded_inputs_sample(self, info_str):
"""
parse label info
"""
# read text info
info_dict = json.loads(info_str)
height = info_dict["height"]
width = info_dict["width"]
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
gt_label_list = []
if self.contains_re:
# for re
entities = []
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
for info in info_dict["ocr_info"]:
if self.contains_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# x1, y1, x2, y2
bbox = info["bbox"]
label = info["label"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
gt_label = []
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
if self.contains_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
gt_label_list.extend(gt_label)
words_list.append(text)
encoded_inputs = {
"input_ids": input_ids_list,
"labels": gt_label_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
# "words_list": words_list,
}
encoded_inputs = self.pad_sentences(
encoded_inputs,
max_seq_len=self.max_seq_len,
return_attention_mask=self.return_attention_mask)
encoded_inputs = self.truncate_inputs(encoded_inputs)
if self.contains_re:
relations = self._relations(entities, relations, id2label,
empty_entity, entity_id_to_index_map)
encoded_inputs['relations'] = relations
encoded_inputs['entities'] = entities
return encoded_inputs
def _chunk_ser(self, encoded_inputs):
encoded_inputs_all = []
seq_len = len(encoded_inputs['input_ids'])
chunk_size = 512
for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
chunk_beg = index
chunk_end = min(index + chunk_size, seq_len)
encoded_inputs_example = {}
for key in encoded_inputs:
encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
chunk_end]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all
def _chunk_re(self, encoded_inputs):
# prepare data
entities = encoded_inputs.pop('entities')
relations = encoded_inputs.pop('relations')
encoded_inputs_all = []
chunk_size = 512
for chunk_id, index in enumerate(
range(0, len(encoded_inputs["input_ids"]), chunk_size)):
item = {}
for k in encoded_inputs:
item[k] = encoded_inputs[k][index:index + chunk_size]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + chunk_size and
index <= entity["end"] < index + chunk_size):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + chunk_size and
index <= relation["end_index"] < index + chunk_size):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": reformat(entities_in_this_span),
"relations": reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all
def _relations(self, entities, relations, id2label, empty_entity,
entity_id_to_index_map):
"""
build relations
"""
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": get_relation_span(rel, entities)[0],
"end_index": get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
return relations
def load_img(self, image_path):
# read img
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
resize_h, resize_w = self.img_size
im_shape = img.shape[0:2]
im_scale_y = resize_h / im_shape[0]
im_scale_x = resize_w / im_shape[1]
img_new = cv2.resize(
img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
img_new = img_new / 255.0
img_new -= mean
img_new /= std
img = img_new.transpose((2, 0, 1))
return img
def __getitem__(self, idx):
if self.load_mode == "all":
data = copy.deepcopy(self.encoded_inputs_all[idx])
else:
data = self._parse_label_file(self.all_lines[idx])[0]
image_path = data.pop('image_path')
data["image"] = self.load_img(image_path)
return_data = {}
for k, v in data.items():
if k in self.return_keys:
if self.return_keys[k]['type'] == 'np':
v = np.array(v, dtype=self.return_keys[k]['dtype'])
return_data[k] = v
return return_data
def __len__(self, ):
if self.load_mode == "all":
return len(self.encoded_inputs_all)
else:
return len(self.all_lines)
def get_relation_span(rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
def reformat(data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
...@@ -13,4 +13,3 @@ lxml ...@@ -13,4 +13,3 @@ lxml
premailer premailer
openpyxl openpyxl
fasttext==0.9.1 fasttext==0.9.1
paddlenlp>=2.2.1
...@@ -61,7 +61,8 @@ def main(): ...@@ -61,7 +61,8 @@ def main():
else: else:
model_type = None model_type = None
best_model_dict = load_model(config, model) best_model_dict = load_model(
config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_model(config, model, optimizer) pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册