diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d117fdeb16e1c0e90bf6ec89924e414fc764249
--- /dev/null
+++ b/ppstructure/vqa/README.md
@@ -0,0 +1,182 @@
+# Visual Question Answering (VQA)
+
+The main features of the VQA module are as follows:
+
+- Integrates the [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and the PP-OCR inference engine.
+- Supports the multimodal Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks. SER recognizes and classifies the text in an image; RE extracts relations between the text segments (for example, matching question-answer pairs).
+- Supports end-to-end prediction and evaluation of a system that combines the SER task with the OCR engine.
+- Supports custom training for both the SER and RE tasks.
+
+
+This project is an open-source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
+and includes fine-tuning code on the [XFUND dataset](https://github.com/doc-analysis/XFUND).
+
+## 1. Demo
+
+**Note:** The test images come from the XFUND dataset.
+
+### 1.1 SER
+
+![](images/result_ser/zh_val_0_ser.jpg)
+
+![](images/result_ser/zh_val_42_ser.jpg)
+
+In the images above, boxes of different colors correspond to different categories. For the XFUND dataset there are 3 categories: `QUESTION`, `ANSWER` and `HEADER`. The predicted category and the OCR recognition result are also drawn at the top left of each OCR detection box.
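+
+Internally, these categories (listed in [`labels/labels_ser.txt`](labels/labels_ser.txt)) are expanded into BIO tags for token classification, as done by `get_bio_label_maps` in [`utils.py`](utils.py). The snippet below is only a minimal sketch of that mapping:
+
+```python
+# Minimal sketch: expand the SER categories into BIO tags,
+# mirroring get_bio_label_maps() in utils.py.
+def build_bio_labels(category_file="labels/labels_ser.txt"):
+    with open(category_file, "r") as fin:
+        categories = [line.strip() for line in fin if line.strip()]
+    labels = ["O"]
+    for cat in categories:
+        labels.extend(["B-" + cat, "I-" + cat])
+    label2id = {label: idx for idx, label in enumerate(labels)}
+    id2label = {idx: label for idx, label in enumerate(labels)}
+    return label2id, id2label
+
+
+# {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, ...}
+print(build_bio_labels()[0])
+```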
+
+
+### 1.2 RE
+
+* Coming soon!
+
+
+
+## 2. Installation
+
+### 2.1 Install dependencies
+
+- **(1) Install PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# install with GPU support
+python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+
+# CPU-only installation
+python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+
+```
+For more installation options, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/install/quick).
+
+
+### 2.2 Install PaddleOCR (including PP-OCR and VQA)
+
+- **(1) Quick installation of the PaddleOCR whl package via pip (inference only)**
+
+```bash
+pip install "paddleocr>=2.2" # 推荐使用2.2+版本
+```
+
+- **(2) Clone the VQA source code (inference + training)**
+
+```bash
+# Recommended:
+git clone https://github.com/PaddlePaddle/PaddleOCR
+
+# If cloning from GitHub fails because of network issues, you can also use the Gitee mirror:
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+# Note: the Gitee mirror may lag 3 to 5 days behind this GitHub repository, so please prefer the recommended way above.
+```
+
+- **(3) Install PaddleNLP**
+
+```bash
+# the latest PaddleNLP code (develop branch) is required
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip install -e .
+```
+
+
+- **(4) Install the VQA `requirements`**
+
+```bash
+pip install -r requirements.txt
+```
+
+## 3. Usage
+
+
+### 3.1 Prepare the data and pretrained models
+
+The processed Chinese subset of the XFUND dataset can be downloaded here: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
+
+
+Download and extract the dataset, then place it in the current directory:
+
+```shell
+wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+tar -xf XFUND.tar
+```
+
+To convert the XFUND data of other languages, you can refer to the [XFUND data conversion script](helper/trans_xfun_data.py).
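+
+The converted annotation file stores one sample per line: the image name, a tab, and a JSON dict with `height`, `width` and `ocr_info` (each `ocr_info` entry contains `text`, `label`, `bbox`, `id`, `linking` and `words`, see [`helper/trans_xfun_data.py`](helper/trans_xfun_data.py)). A minimal sketch for inspecting the format, assuming the Chinese validation split downloaded above:
+
+```python
+# Minimal sketch: inspect one line of the converted XFUND annotation file.
+# Each line is "<image name>\t<json>".
+import json
+
+with open("XFUND/zh_val/xfun_normalize_val.json", "r") as fin:
+    line = fin.readline()
+
+image_name, info = line.strip().split("\t")
+info = json.loads(info)
+print(image_name, info["height"], info["width"])
+for item in info["ocr_info"][:3]:
+    # each entry holds text, label, bbox ([x1, y1, x2, y2]), id, linking, words
+    print(item["label"], item["text"], item["bbox"])
+```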
+
+If you only want to try the inference step, you can download the SER pretrained model we provide, skip training, and run prediction directly. After downloading, extract the tar archive into the current directory.
+
+* Pretrained model for the SER task: [download link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
+* Pretrained model for the RE task: coming soon!
+
+
+### 3.2 SER task
+
+* Start training
+
+```shell
+python train_ser.py \
+ --model_name_or_path "layoutxlm-base-uncased" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --num_train_epochs 200 \
+ --eval_steps 10 \
+ --save_steps 500 \
+ --output_dir "./output/ser/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --evaluate_during_training \
+ --seed 2048
+```
+
+Metrics such as `precision`, `recall` and `f1` are printed during training, as shown below.
+
+```
+best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
+```
+
+The model and the training logs are saved in the `./output/ser/` directory.
+
+* Run prediction with the OCR results provided in the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser.py \
+ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+ --output_dir "output_res/" \
+ --infer_imgs "XFUND/zh_val/image/" \
+ --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+```
+
+The visualized prediction images and the prediction text file, named `infer_results.txt`, are saved in the `output_res` directory.
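+
+Each line of `infer_results.txt` follows the same `image path + tab + JSON` convention as the annotation file, and every `ocr_info` entry additionally carries the predicted category in the `pred` field (with its id in `pred_id`). A minimal sketch for reading the predictions back:
+
+```python
+# Minimal sketch: read the SER predictions written by infer_ser.py.
+import json
+
+with open("output_res/infer_results.txt", "r") as fin:
+    for line in fin:
+        img_path, info = line.strip().split("\t")
+        for item in json.loads(info)["ocr_info"]:
+            # "pred" is the predicted SER category, "text" is the OCR result
+            print(img_path, item["pred"], item["text"])
+```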
+
+* Run the cascaded `OCR engine + SER` pipeline
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_e2e.py \
+ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+ --max_seq_length 512 \
+ --output_dir "output_res_e2e/"
+```
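+
+The visualized images and the prediction file `infer_results.txt` are saved in the `output_res_e2e` directory.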
+
+* Run end-to-end evaluation of the `OCR engine + SER` prediction system
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res_e2e/infer_results.txt
+```
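+
+The script prints the `precision`, `recall` and `fmeasure` of the end-to-end system, together with `character_acc` and the average edit distance between the predicted and ground-truth texts.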
+
+
+### 3.3 RE task
+
+* Coming soon!
+
+
+## References
+
+- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
+- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
+- XFUND dataset, https://github.com/doc-analysis/XFUND
diff --git a/ppstructure/vqa/helper/eval_with_label_end2end.py b/ppstructure/vqa/helper/eval_with_label_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dd3e0ad437e51e21ebc53daeec9fdf9aa76b63
--- /dev/null
+++ b/ppstructure/vqa/helper/eval_with_label_end2end.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import sys
+import shapely
+from shapely.geometry import Polygon
+import numpy as np
+from collections import defaultdict
+import operator
+import editdistance
+import argparse
+import json
+import copy
+
+
+def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
+ # img/zh_val_0.jpg {
+ # "height": 3508,
+ # "width": 2480,
+ # "ocr_info": [
+ # {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
+ # {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
+ # ]
+ assert fp_type in ["gt", "pred"]
+ key = "label" if fp_type == "gt" else "pred"
+ res_dict = dict()
+ with open(fp, "r") as fin:
+ lines = fin.readlines()
+
+ for _, line in enumerate(lines):
+ img_path, info = line.strip().split("\t")
+ # get key
+ image_name = os.path.basename(img_path)
+ res_dict[image_name] = []
+ # get infos
+ json_info = json.loads(info)
+ for single_ocr_info in json_info["ocr_info"]:
+ label = single_ocr_info[key].upper()
+ if label in ["O", "OTHERS", "OTHER"]:
+ label = "O"
+ if ignore_background and label == "O":
+ continue
+ single_ocr_info["label"] = label
+ res_dict[image_name].append(copy.deepcopy(single_ocr_info))
+ return res_dict
+
+
+def polygon_from_str(polygon_points):
+ """
+ Create a shapely polygon object from gt or dt line.
+ """
+ polygon_points = np.array(polygon_points).reshape(4, 2)
+ polygon = Polygon(polygon_points).convex_hull
+ return polygon
+
+
+def polygon_iou(poly1, poly2):
+ """
+ Intersection over union between two shapely polygons.
+ """
+ if not poly1.intersects(
+ poly2): # this test is fast and can accelerate calculation
+ iou = 0
+ else:
+ try:
+ inter_area = poly1.intersection(poly2).area
+ union_area = poly1.area + poly2.area - inter_area
+ iou = float(inter_area) / union_area
+ except shapely.geos.TopologicalError:
+ # except Exception as e:
+ # print(e)
+            print('shapely.geos.TopologicalError occurred, iou set to 0')
+ iou = 0
+ return iou
+
+
+def ed(args, str1, str2):
+ if args.ignore_space:
+ str1 = str1.replace(" ", "")
+ str2 = str2.replace(" ", "")
+ if args.ignore_case:
+ str1 = str1.lower()
+ str2 = str2.lower()
+ return editdistance.eval(str1, str2)
+
+
+def convert_bbox_to_polygon(bbox):
+ """
+ bbox : [x1, y1, x2, y2]
+ output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+ """
+ xmin, ymin, xmax, ymax = bbox
+ poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+ return poly
+
+
+def eval_e2e(args):
+ # gt
+ gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
+ args.ignore_background)
+ # pred
+ dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
+ args.ignore_background)
+ assert set(gt_results.keys()) == set(dt_results.keys())
+
+ iou_thresh = args.iou_thres
+ num_gt_chars = 0
+ gt_count = 0
+ dt_count = 0
+ hit = 0
+ ed_sum = 0
+
+ for img_name in gt_results:
+ gt_info = gt_results[img_name]
+ gt_count += len(gt_info)
+
+ dt_info = dt_results[img_name]
+ dt_count += len(dt_info)
+
+ dt_match = [False] * len(dt_info)
+ gt_match = [False] * len(gt_info)
+
+ all_ious = defaultdict(tuple)
+ # gt: {text, label, bbox or poly}
+ for index_gt, gt in enumerate(gt_info):
+ if "poly" not in gt:
+ gt["poly"] = convert_bbox_to_polygon(gt["bbox"])
+ gt_poly = polygon_from_str(gt["poly"])
+ for index_dt, dt in enumerate(dt_info):
+ if "poly" not in dt:
+ dt["poly"] = convert_bbox_to_polygon(dt["bbox"])
+ dt_poly = polygon_from_str(dt["poly"])
+ iou = polygon_iou(dt_poly, gt_poly)
+ if iou >= iou_thresh:
+ all_ious[(index_gt, index_dt)] = iou
+ sorted_ious = sorted(
+ all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
+
+ # matched gt and dt
+ for gt_dt_pair in sorted_gt_dt_pairs:
+ index_gt, index_dt = gt_dt_pair
+            if not gt_match[index_gt] and not dt_match[index_dt]:
+ gt_match[index_gt] = True
+ dt_match[index_dt] = True
+ # ocr rec results
+ gt_text = gt_info[index_gt]["text"]
+ dt_text = dt_info[index_dt]["text"]
+
+ # ser results
+ gt_label = gt_info[index_gt]["label"]
+ dt_label = dt_info[index_dt]["pred"]
+
+ if True: # ignore_masks[index_gt] == '0':
+ ed_sum += ed(args, gt_text, dt_text)
+ num_gt_chars += len(gt_text)
+ if gt_text == dt_text:
+ if args.ignore_ser_prediction or gt_label == dt_label:
+ hit += 1
+
+        # unmatched dt
+        for tindex, dt_match_flag in enumerate(dt_match):
+            if not dt_match_flag:
+ dt_text = dt_info[tindex]["text"]
+ gt_text = ""
+ ed_sum += ed(args, dt_text, gt_text)
+
+        # unmatched gt
+        for tindex, gt_match_flag in enumerate(gt_match):
+            if not gt_match_flag:
+ dt_text = ""
+ gt_text = gt_info[tindex]["text"]
+ ed_sum += ed(args, gt_text, dt_text)
+ num_gt_chars += len(gt_text)
+
+ eps = 1e-9
+ print("config: ", args)
+ print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ precision = hit / (dt_count + eps)
+ recall = hit / (gt_count + eps)
+ fmeasure = 2.0 * precision * recall / (precision + recall + eps)
+ avg_edit_dist_img = ed_sum / len(gt_results)
+ avg_edit_dist_field = ed_sum / (gt_count + eps)
+ character_acc = 1 - ed_sum / (num_gt_chars + eps)
+
+ print('character_acc: %.2f' % (character_acc * 100) + "%")
+ print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
+ print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
+ print('precision: %.2f' % (precision * 100) + "%")
+ print('recall: %.2f' % (recall * 100) + "%")
+ print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+
+ return
+
+
+def parse_args():
+ """
+ """
+
+ def str2bool(v):
+ return v.lower() in ("true", "t", "1")
+
+ parser = argparse.ArgumentParser()
+ ## Required parameters
+ parser.add_argument(
+ "--gt_json_path",
+ default=None,
+ type=str,
+ required=True, )
+ parser.add_argument(
+ "--pred_json_path",
+ default=None,
+ type=str,
+ required=True, )
+
+ parser.add_argument("--iou_thres", default=0.5, type=float)
+
+ parser.add_argument(
+ "--ignore_case",
+ default=False,
+ type=str2bool,
+ help="whether to do lower case for the strs")
+
+ parser.add_argument(
+ "--ignore_space",
+ default=True,
+ type=str2bool,
+ help="whether to ignore space")
+
+ parser.add_argument(
+ "--ignore_background",
+ default=True,
+ type=str2bool,
+ help="whether to ignore other label")
+
+ parser.add_argument(
+ "--ignore_ser_prediction",
+ default=False,
+ type=str2bool,
+ help="whether to ignore ocr pred results")
+
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ eval_e2e(args)
diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ebd5dfbd8addda0701a7cfd2387133f7a8776b
--- /dev/null
+++ b/ppstructure/vqa/helper/trans_xfun_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+
+def transfer_xfun_data(json_path=None, output_file=None):
+ with open(json_path, "r") as fin:
+ lines = fin.readlines()
+
+ json_info = json.loads(lines[0])
+ documents = json_info["documents"]
+ label_info = {}
+ with open(output_file, "w") as fout:
+ for idx, document in enumerate(documents):
+ img_info = document["img"]
+ document = document["document"]
+ image_path = img_info["fname"]
+
+ label_info["height"] = img_info["height"]
+ label_info["width"] = img_info["width"]
+
+ label_info["ocr_info"] = []
+
+ for doc in document:
+ label_info["ocr_info"].append({
+ "text": doc["text"],
+ "label": doc["label"],
+ "bbox": doc["box"],
+ "id": doc["id"],
+ "linking": doc["linking"],
+ "words": doc["words"]
+ })
+
+ fout.write(image_path + "\t" + json.dumps(
+ label_info, ensure_ascii=False) + "\n")
+
+ print("===ok====")
+
+
+if __name__ == "__main__":
+    transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/ppstructure/vqa/images/input/zh_val_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..479b60bcd3a859b187ce5325dfc381c1b87ee27f
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_0.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/ppstructure/vqa/images/input/zh_val_42.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..42151bdd94929ede9da1a63ce8d9339971094a46
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_42.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..22ba9a6f1b7652ca9ce6848093c7a39affb4886b
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..951864e5f35a987ff241f276c8da523d8c8eeaf3
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ
diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad220094a26b330555fbe9122a46fb56e64fe1e
--- /dev/null
+++ b/ppstructure/vqa/infer_ser.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+
+import paddle
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+ # Padding with larger size, reshape is carried out
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+ assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}"
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+ truncate is often used in training process
+ """
+ for key in encoded_inputs:
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img,
+ (224, 224)).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, label_map_path):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ _, label_map = get_bio_label_maps(label_map_path)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
+ preds_list):
+ # must ensure the preds_list is generated from the same image
+ preds = [p for pred in preds_list for p in pred]
+ label2id_map, _ = get_bio_label_maps(label_map_path)
+ for key in label2id_map:
+ if key.startswith("I-"):
+ label2id_map[key] = label2id_map["B" + key[1:]]
+
+ id2label_map = dict()
+ for key in label2id_map:
+ val = label2id_map[key]
+ if key == "O":
+ id2label_map[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ id2label_map[val] = key[2:]
+ else:
+ id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[pred_id]
+ return ocr_info
+
+
+@paddle.no_grad()
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # init token and model
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ # model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ model.eval()
+
+ # load ocr results json
+ ocr_results = dict()
+ with open(args.ocr_json_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ img_name, json_info = line.split("\t")
+ ocr_results[os.path.basename(img_name)] = json.loads(json_info)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
+ inputs = preprocess(
+ tokenizer=tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=args.max_seq_length)
+
+ outputs = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds,
+ args.label_map_path)
+ ocr_info = merge_preds_list_with_ocr_info(
+ args.label_map_path, ocr_info, inputs["segment_offset_id"],
+ preds)
+
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": ocr_info,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, ocr_info)
+ cv2.imwrite(
+ os.path.join(args.output_dir, os.path.basename(img_path)),
+ img_res)
+
+ return
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..da027a140bdb4fa12a40d423998d94e438a7cd11
--- /dev/null
+++ b/ppstructure/vqa/infer_ser_e2e.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+
+import paddle
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
+
+from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+
+
+def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
+def parse_ocr_info_for_ser(ocr_result):
+ ocr_info = []
+ for res in ocr_result:
+ ocr_info.append({
+ "text": res[1][0],
+ "bbox": trans_poly_to_bbox(res[0]),
+ "poly": res[0],
+ })
+ return ocr_info
+
+
+@paddle.no_grad()
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # init token and model
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ model.eval()
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ label2id_map_for_draw[key] = label2id_map[key]
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
+ args.ocr_det_model_dir)
+
+ # loop for infer
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ ocr_result = ocr_engine.ocr(img_path, cls=False)
+
+ ocr_info = parse_ocr_info_for_ser(ocr_result)
+
+ inputs = preprocess(
+ tokenizer=tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=args.max_seq_length)
+
+ outputs = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds, id2label_map)
+ ocr_info = merge_preds_list_with_ocr_info(
+ ocr_info, inputs["segment_offset_id"], preds,
+ label2id_map_for_draw)
+
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": ocr_info,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, ocr_info)
+ cv2.imwrite(
+ os.path.join(args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] +
+ "_ser.jpg"), img_res)
+
+ return
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt
new file mode 100644
index 0000000000000000000000000000000000000000..508e48112412f62538baf0c78bcf99ec8945196e
--- /dev/null
+++ b/ppstructure/vqa/labels/labels_ser.txt
@@ -0,0 +1,3 @@
+QUESTION
+ANSWER
+HEADER
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c259fadc395335b336cb0ecdb5aa6bca48631987
--- /dev/null
+++ b/ppstructure/vqa/requirements.txt
@@ -0,0 +1,2 @@
+sentencepiece
+yacs
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ca69d93fd22983533fcacd639bbd64dc3e11ec
--- /dev/null
+++ b/ppstructure/vqa/train_ser.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import random
+import copy
+import logging
+
+import argparse
+import paddle
+import numpy as np
+from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from xfun import XFUNDataset
+from utils import parse_args
+from utils import get_bio_label_maps
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ paddle.seed(args.seed)
+
+
+def train(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logging.basicConfig(
+ filename=os.path.join(args.output_dir, "train.log")
+ if paddle.distributed.get_rank() == 0 else None,
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO
+ if paddle.distributed.get_rank() == 0 else logging.WARN, )
+
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ logger.addHandler(ch)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification(
+ base_model, num_classes=len(label2id_map), dropout=None)
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ train_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.train_data_dir,
+ label_path=args.train_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ train_sampler = paddle.io.DistributedBatchSampler(
+ train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
+
+ args.train_batch_size = args.per_gpu_train_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ train_dataloader = paddle.io.DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ t_total = len(train_dataloader) * args.num_train_epochs
+
+ # build linear decay with warmup lr sch
+ lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+ learning_rate=args.learning_rate,
+ decay_steps=t_total,
+ end_lr=0.0,
+ power=1.0)
+ if args.warmup_steps > 0:
+ lr_scheduler = paddle.optimizer.lr.LinearWarmup(
+ lr_scheduler,
+ args.warmup_steps,
+ start_lr=0,
+ end_lr=args.learning_rate, )
+
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ parameters=model.parameters(),
+ epsilon=args.adam_epsilon,
+ weight_decay=args.weight_decay)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = %d", len(train_dataset))
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
+ logger.info(" Instantaneous batch size per GPU = %d",
+ args.per_gpu_train_batch_size)
+ logger.info(
+ " Total train batch size (w. parallel, distributed) = %d",
+ args.train_batch_size * paddle.distributed.get_world_size(), )
+ logger.info(" Total optimization steps = %d", t_total)
+
+ global_step = 0
+ tr_loss = 0.0
+ set_seed(args)
+ best_metrics = None
+
+ for epoch_id in range(args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ model.train()
+ outputs = model(**batch)
+ # model outputs are always tuple in ppnlp (see doc)
+ loss = outputs[0]
+ loss = loss.mean()
+ logger.info(
+ "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
+ format(epoch_id, args.num_train_epochs, step,
+ len(train_dataloader),
+ lr_scheduler.get_lr(), loss.numpy()[0]))
+
+ loss.backward()
+ tr_loss += loss.item()
+ optimizer.step()
+ lr_scheduler.step() # Update learning rate schedule
+ optimizer.clear_grad()
+ global_step += 1
+
+ if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
+ global_step % args.eval_steps == 0):
+ # Log metrics
+ # Only evaluate when single GPU otherwise metrics may not average well
+ if paddle.distributed.get_rank(
+ ) == 0 and args.evaluate_during_training:
+ results, _ = evaluate(
+ args,
+ model,
+ tokenizer,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id, )
+
+ if best_metrics is None or results["f1"] >= best_metrics[
+ "f1"]:
+ best_metrics = copy.deepcopy(results)
+ output_dir = os.path.join(args.output_dir, "best_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(
+ args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s",
+ output_dir)
+
+ logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
+ epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), results))
+ if best_metrics is not None:
+ logger.info("best metrics: {}".format(best_metrics))
+
+ if paddle.distributed.get_rank(
+ ) == 0 and args.save_steps > 0 and global_step % args.save_steps == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir,
+ "checkpoint-{}".format(global_step))
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s", output_dir)
+
+ return global_step, tr_loss / global_step
+
+
+def evaluate(args,
+ model,
+ tokenizer,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id,
+ prefix=""):
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.eval_batch_size,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ # Eval!
+ logger.info("***** Running evaluation %s *****", prefix)
+ logger.info(" Num examples = %d", len(eval_dataset))
+ logger.info(" Batch size = %d", args.eval_batch_size)
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ preds = None
+ out_label_ids = None
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ tmp_eval_loss, logits = outputs[:2]
+
+ tmp_eval_loss = tmp_eval_loss.mean()
+
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
+
+ eval_loss += tmp_eval_loss.item()
+ nb_eval_steps += 1
+ if preds is None:
+ preds = logits.numpy()
+ out_label_ids = batch["labels"].numpy()
+ else:
+ preds = np.append(preds, logits.numpy(), axis=0)
+ out_label_ids = np.append(
+ out_label_ids, batch["labels"].numpy(), axis=0)
+
+ eval_loss = eval_loss / nb_eval_steps
+ preds = np.argmax(preds, axis=2)
+
+ out_label_list = [[] for _ in range(out_label_ids.shape[0])]
+ preds_list = [[] for _ in range(out_label_ids.shape[0])]
+
+ for i in range(out_label_ids.shape[0]):
+ for j in range(out_label_ids.shape[1]):
+ if out_label_ids[i, j] != pad_token_label_id:
+ out_label_list[i].append(id2label_map[out_label_ids[i][j]])
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ results = {
+ "loss": eval_loss,
+ "precision": precision_score(out_label_list, preds_list),
+ "recall": recall_score(out_label_list, preds_list),
+ "f1": f1_score(out_label_list, preds_list),
+ }
+
+ with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout:
+ for lbl in out_label_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+ with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout:
+ for lbl in preds_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+
+ report = classification_report(out_label_list, preds_list)
+ logger.info("\n" + report)
+
+ logger.info("***** Eval results %s *****", prefix)
+ for key in sorted(results.keys()):
+ logger.info(" %s = %s", key, str(results[key]))
+
+ return results, preds_list
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ train(args)
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ac1e77d37d0a662294480a393c2f67e7f4cc64
--- /dev/null
+++ b/ppstructure/vqa/utils.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import cv2
+import random
+import numpy as np
+import imghdr
+from copy import deepcopy
+
+import paddle
+
+from PIL import Image, ImageDraw, ImageFont
+
+from paddleocr import PaddleOCR
+
+
+def get_bio_label_maps(label_map_path):
+ with open(label_map_path, "r") as fin:
+ lines = fin.readlines()
+ lines = [line.strip() for line in lines]
+ if "O" not in lines:
+ lines.insert(0, "O")
+ labels = []
+ for line in lines:
+ if line == "O":
+ labels.append("O")
+ else:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+ if img_file is None or not os.path.exists(img_file):
+ raise Exception("not found any img file in {}".format(img_file))
+
+ img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
+ if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
+ imgs_lists.append(file_path)
+ if len(imgs_lists) == 0:
+ raise Exception("not found any img file in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+def draw_ser_results(image,
+ ocr_results,
+ font_path="../doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+
+ for ocr_info in ocr_results:
+ if ocr_info["pred_id"] not in color_map:
+ continue
+ color = color_map[ocr_info["pred_id"]]
+
+ # draw ocr results outline
+ bbox = ocr_info["bbox"]
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
+ start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text(
+ (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+def build_ocr_engine(rec_model_dir, det_model_dir):
+ ocr_engine = PaddleOCR(
+ rec_model_dir=rec_model_dir,
+ det_model_dir=det_model_dir,
+ use_angle_cls=False)
+ return ocr_engine
+
+
+# pad sentences
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+ # Padding with larger size, reshape is carried out
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+ truncate is often used in training process
+ """
+ for key in encoded_inputs:
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img,
+ (224, 224)).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, id2label_map):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
+ label2id_map_for_draw):
+ # must ensure the preds_list is generated from the same image
+ preds = [p for pred in preds_list for p in pred]
+
+ id2label_map = dict()
+ for key in label2id_map_for_draw:
+ val = label2id_map_for_draw[key]
+ if key == "O":
+ id2label_map[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ id2label_map[val] = key[2:]
+ else:
+ id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
+ return ocr_info
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ # yapf: disable
+ parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
+ parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
+ parser.add_argument("--train_label_path", default=None, type=str, required=False,)
+ parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
+ parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
+ parser.add_argument("--output_dir", default=None, type=str, required=True,)
+ parser.add_argument("--max_seq_length", default=512, type=int,)
+ parser.add_argument("--evaluate_during_training", action="store_true",)
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",)
+ parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",)
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",)
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",)
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",)
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",)
+ parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",)
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",)
+ parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",)
+ parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",)
+ parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",)
+
+ parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
+ parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
+ parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
+ parser.add_argument("--infer_imgs", default=None, type=str, required=False)
+ parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results")
+ # yapf: enable
+ args = parser.parse_args()
+ return args
diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62cdb5da5514280b62687d80d345ede9484ee90
--- /dev/null
+++ b/ppstructure/vqa/xfun.py
@@ -0,0 +1,442 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import cv2
+import numpy as np
+import paddle
+import copy
+from paddle.io import Dataset
+
+__all__ = ["XFUNDataset"]
+
+
+class XFUNDataset(Dataset):
+ """
+ Example:
+ print("=====begin to build dataset=====")
+ from paddlenlp.transformers import LayoutXLMTokenizer
+ tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
+ tok_res = tokenizer.tokenize("Maribyrnong")
+ # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
+ dataset = XfunDatasetForSer(
+ tokenizer,
+ data_dir="./zh.val/",
+ label_path="zh.val/xfun_normalize_val.json",
+ img_size=(224,224))
+ print(len(dataset))
+
+ data = dataset[0]
+ print(data.keys())
+ print("input_ids: ", data["input_ids"])
+ print("labels: ", data["labels"])
+ print("token_type_ids: ", data["token_type_ids"])
+ print("words_list: ", data["words_list"])
+ print("image shape: ", data["image"].shape)
+ """
+
+ def __init__(self,
+ tokenizer,
+ data_dir,
+ label_path,
+ contains_re=False,
+ label2id_map=None,
+ img_size=(224, 224),
+ pad_token_label_id=None,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all',
+ max_seq_len=512):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_dir = data_dir
+ self.label_path = label_path
+ self.contains_re = contains_re
+ self.label2id_map = label2id_map
+ self.img_size = img_size
+ self.pad_token_label_id = pad_token_label_id
+ self.add_special_ids = add_special_ids
+ self.return_attention_mask = return_attention_mask
+ self.load_mode = load_mode
+ self.max_seq_len = max_seq_len
+
+ if self.pad_token_label_id is None:
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ self.all_lines = self.read_all_lines()
+
+ self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+ self.return_keys = {
+ 'bbox': 'np',
+ 'input_ids': 'np',
+ 'labels': 'np',
+ 'attention_mask': 'np',
+ 'image': 'np',
+ 'token_type_ids': 'np',
+ 'entities': 'dict',
+ 'relations': 'dict',
+ }
+
+ if load_mode == "all":
+ self.encoded_inputs_all = self._parse_label_file_all()
+
+ def pad_sentences(self,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+ # Padding
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if self.tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [self.tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [self.tokenizer.pad_token_id] * difference
+ encoded_inputs["labels"] = encoded_inputs[
+ "labels"] + [self.pad_token_label_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs[
+ "bbox"] + [[0, 0, 0, 0]] * difference
+ elif self.tokenizer.padding_side == 'left':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(encoded_inputs["input_ids"])
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ [self.tokenizer.pad_token_type_id] * difference +
+ encoded_inputs["token_type_ids"])
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = [
+ 1
+ ] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs["input_ids"] = [
+ self.tokenizer.pad_token_id
+ ] * difference + encoded_inputs["input_ids"]
+ encoded_inputs["labels"] = [
+ self.pad_token_label_id
+ ] * difference + encoded_inputs["labels"]
+ encoded_inputs["bbox"] = [
+ [0, 0, 0, 0]
+ ] * difference + encoded_inputs["bbox"]
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+ def truncate_inputs(self, encoded_inputs, max_seq_len=512):
+ for key in encoded_inputs:
+ if key == "sample_id":
+ continue
+ length = min(len(encoded_inputs[key]), max_seq_len)
+ encoded_inputs[key] = encoded_inputs[key][:length]
+ return encoded_inputs
+
+ def read_all_lines(self, ):
+ with open(self.label_path, "r") as fin:
+ lines = fin.readlines()
+ return lines
+
+ def _parse_label_file_all(self):
+ """
+ parse all samples
+ """
+ encoded_inputs_all = []
+ for line in self.all_lines:
+ encoded_inputs_all.extend(self._parse_label_file(line))
+ return encoded_inputs_all
+
+ def _parse_label_file(self, line):
+ """
+ parse single sample
+ """
+
+ image_name, info_str = line.split("\t")
+ image_path = os.path.join(self.data_dir, image_name)
+
+        def add_image_path(x):
+ x['image_path'] = image_path
+ return x
+
+ encoded_inputs = self._read_encoded_inputs_sample(info_str)
+ if self.contains_re:
+ encoded_inputs = self._chunk_re(encoded_inputs)
+ else:
+ encoded_inputs = self._chunk_ser(encoded_inputs)
+        encoded_inputs = list(map(add_image_path, encoded_inputs))
+ return encoded_inputs
+
+ def _read_encoded_inputs_sample(self, info_str):
+ """
+ parse label info
+ """
+ # read text info
+ info_dict = json.loads(info_str)
+ height = info_dict["height"]
+ width = info_dict["width"]
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ gt_label_list = []
+
+ if self.contains_re:
+ # for re
+ entities = []
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+ for info in info_dict["ocr_info"]:
+ if self.contains_re:
+ # for re
+ if len(info["text"]) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ label = info["label"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = self.tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ gt_label = []
+ if not self.add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+ if label.lower() == "other":
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+ if self.contains_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ gt_label_list.extend(gt_label)
+ words_list.append(text)
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "labels": gt_label_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ # "words_list": words_list,
+ }
+ encoded_inputs = self.pad_sentences(
+ encoded_inputs,
+ max_seq_len=self.max_seq_len,
+ return_attention_mask=self.return_attention_mask)
+ encoded_inputs = self.truncate_inputs(encoded_inputs)
+
+ if self.contains_re:
+ relations = self._relations(entities, relations, id2label,
+ empty_entity, entity_id_to_index_map)
+ encoded_inputs['relations'] = relations
+ encoded_inputs['entities'] = entities
+ return encoded_inputs
+
+ def _chunk_ser(self, encoded_inputs):
+ encoded_inputs_all = []
+ seq_len = len(encoded_inputs['input_ids'])
+ chunk_size = 512
+ for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
+ chunk_beg = index
+ chunk_end = min(index + chunk_size, seq_len)
+ encoded_inputs_example = {}
+ for key in encoded_inputs:
+ encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
+ chunk_end]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ return encoded_inputs_all
+
+ def _chunk_re(self, encoded_inputs):
+ # prepare data
+ entities = encoded_inputs.pop('entities')
+ relations = encoded_inputs.pop('relations')
+ encoded_inputs_all = []
+ chunk_size = 512
+ for chunk_id, index in enumerate(
+ range(0, len(encoded_inputs["input_ids"]), chunk_size)):
+ item = {}
+ for k in encoded_inputs:
+ item[k] = encoded_inputs[k][index:index + chunk_size]
+
+ # select entity in current chunk
+ entities_in_this_span = []
+            global_to_local_map = {}
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + chunk_size and
+ index <= entity["end"] < index + chunk_size):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + chunk_size and
+ index <= relation["end_index"] < index + chunk_size):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": reformat(entities_in_this_span),
+ "relations": reformat(relations_in_this_span),
+ })
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ return encoded_inputs_all
+
+ def _relations(self, entities, relations, id2label, empty_entity,
+ entity_id_to_index_map):
+ """
+ build relations
+ """
+ relations = list(set(relations))
+ relations = [
+ rel for rel in relations
+ if rel[0] not in empty_entity and rel[1] not in empty_entity
+ ]
+ kv_relations = []
+ for rel in relations:
+ pair = [id2label[rel[0]], id2label[rel[1]]]
+ if pair == ["question", "answer"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]]
+ })
+ elif pair == ["answer", "question"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]]
+ })
+ else:
+ continue
+ relations = sorted(
+ [{
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": get_relation_span(rel, entities)[0],
+ "end_index": get_relation_span(rel, entities)[1],
+ } for rel in kv_relations],
+ key=lambda x: x["head"], )
+ return relations
+
+ def load_img(self, image_path):
+ # read img
+ img = cv2.imread(image_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ resize_h, resize_w = self.img_size
+ im_shape = img.shape[0:2]
+ im_scale_y = resize_h / im_shape[0]
+ im_scale_x = resize_w / im_shape[1]
+ img_new = cv2.resize(
+ img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
+ mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
+ std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
+ img_new = img_new / 255.0
+ img_new -= mean
+ img_new /= std
+ img = img_new.transpose((2, 0, 1))
+ return img
+
+ def __getitem__(self, idx):
+ if self.load_mode == "all":
+ data = copy.deepcopy(self.encoded_inputs_all[idx])
+ else:
+ data = self._parse_label_file(self.all_lines[idx])[0]
+
+ image_path = data.pop('image_path')
+ data["image"] = self.load_img(image_path)
+
+ return_data = {}
+ for k, v in data.items():
+ if k in self.return_keys:
+ if self.return_keys[k] == 'np':
+ v = np.array(v)
+ return_data[k] = v
+ return return_data
+
+ def __len__(self, ):
+ if self.load_mode == "all":
+ return len(self.encoded_inputs_all)
+ else:
+ return len(self.all_lines)
+
+
+def get_relation_span(rel, entities):
+ bound = []
+ for entity_index in [rel["head"], rel["tail"]]:
+ bound.append(entities[entity_index]["start"])
+ bound.append(entities[entity_index]["end"])
+ return min(bound), max(bound)
+
+
+def reformat(data):
+ new_data = {}
+ for item in data:
+ for k, v in item.items():
+ if k not in new_data:
+ new_data[k] = []
+ new_data[k].append(v)
+ return new_data