diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d117fdeb16e1c0e90bf6ec89924e414fc764249
--- /dev/null
+++ b/ppstructure/vqa/README.md
@@ -0,0 +1,182 @@
+# Visual Question Answering (VQA)
+
+The main features of the VQA module are as follows:
+
+- Integrates the [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and the PP-OCR prediction engine.
+- Supports the multimodal Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks. The SER task recognizes and classifies the text in an image; the RE task extracts relations between the text regions (e.g., identifying question-answer pairs).
+- Supports end-to-end prediction and evaluation of the SER task combined with the OCR engine.
+- Supports custom training for both the SER and RE tasks.
+
+
+This project is an open-source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
+and includes fine-tuning code on the [XFUND dataset](https://github.com/doc-analysis/XFUND).
+
+## 1. Demo
+
+**Note:** The test images come from the XFUND dataset.
+
+### 1.1 SER
+
+<div align="center">
+ +
+ +
+ +
+
+The boxes in different colors denote different categories. For the XFUND dataset there are three categories: `QUESTION`, `ANSWER` and `HEADER`. The predicted category and the OCR recognition result are also shown above the top-left corner of each OCR detection box.
+
+
+### 1.2 RE
+
+* Coming soon!
+
+
+
+## 2. Installation
+
+### 2.1 Install dependencies
+
+- **(1) Install PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# GPU installation
+python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+
+# CPU installation
+python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+
+```
+For more installation options, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
+
+
+### 2.2 Install PaddleOCR (including PP-OCR and VQA)
+
+- **(1) Quick installation of the PaddleOCR whl package via pip (prediction only)**
+
+```bash
+pip install "paddleocr>=2.2" # version 2.2+ is recommended
+```
+
+- **(2) Download the VQA source code (prediction + training)**
+
+```bash
+# [Recommended]
+git clone https://github.com/PaddlePaddle/PaddleOCR
+
+# If you cannot pull the repository because of network problems, you can also use the mirror hosted on Gitee:
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+# Note: the code hosted on Gitee may not be synchronized with this GitHub project in real time and can lag behind by 3~5 days; please prefer the recommended way.
+```
+
+- **(3) Install PaddleNLP**
+
+```bash
+# PaddleNLP needs to be installed from the latest source code
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip install -e .
+```
+
+
+- **(4) Install the VQA `requirements`**
+
+```bash
+pip install -r requirements.txt
+```
+
+## 3. Usage
+
+
+### 3.1 Prepare the data and pretrained models
+
+Download link of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
+
+
+Download and extract the dataset, then place it in the current directory.
+
+```shell
+wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+tar -xf XFUND.tar
+```
+
+To convert XFUND datasets of other languages, you can refer to the [XFUND data conversion script](helper/trans_xfun_data.py).
+
+If you just want to try the prediction process, you can download the SER pretrained model we provide, skip training and run prediction directly.
+
+* Download link of the pretrained model for the SER task: [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
+* Download link of the pretrained model for the RE task: coming soon!
+
+
+### 3.2 SER task
+
+* Start training
+
+```shell
+python train_ser.py \
+    --model_name_or_path "layoutxlm-base-uncased" \
+    --train_data_dir "XFUND/zh_train/image" \
+    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+    --eval_data_dir "XFUND/zh_val/image" \
+    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+    --num_train_epochs 200 \
+    --eval_steps 10 \
+    --save_steps 500 \
+    --output_dir "./output/ser/" \
+    --learning_rate 5e-5 \
+    --warmup_steps 50 \
+    --evaluate_during_training \
+    --seed 2048
+```
+
+Metrics such as `precision`, `recall` and `f1` will be printed at the end, as shown below.
+
+```
+best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
+```
+
+The model and the training log are saved in the `./output/ser/` folder.
+
+* Predict with the OCR results provided in the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser.py \
+    --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+    --output_dir "output_res/" \
+    --infer_imgs "XFUND/zh_val/image/" \
+    --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+```
+
+The visualized prediction images and the prediction text file, named `infer_results.txt`, are saved in the `output_res` directory (a minimal parsing sketch is given in section 3.4 below).
+
+* Run the cascaded `OCR engine + SER` prediction
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_e2e.py \
+    --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+    --max_seq_length 512 \
+    --output_dir "output_res_e2e/"
+```
+
+* End-to-end evaluation of the `OCR engine + SER` prediction system
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+```
+
+
+### 3.3 RE task
+
+coming soon!
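+
+### 3.4 Parsing the prediction results
+
+Both `infer_ser.py` and `infer_ser_e2e.py` write one line per image to `infer_results.txt`: the image path, a tab, and a JSON payload whose `ocr_info` list contains the text, the bounding box and the predicted SER label of every text region. The snippet below is only an illustrative sketch for post-processing this file; it is not part of the released scripts, and the `output_res/infer_results.txt` path simply follows the commands above.
+
+```python
+import json
+
+with open("output_res/infer_results.txt", "r", encoding="utf-8") as fin:
+    for line in fin:
+        img_path, info = line.strip().split("\t", 1)
+        ocr_info = json.loads(info)["ocr_info"]
+        # keep only the regions predicted as QUESTION or ANSWER
+        qa_regions = [x for x in ocr_info if x["pred"] in ("QUESTION", "ANSWER")]
+        print(img_path, [(x["pred"], x["text"]) for x in qa_regions])
+```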
+
+
+## References
+
+- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
+- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
+- XFUND dataset, https://github.com/doc-analysis/XFUND
diff --git a/ppstructure/vqa/helper/eval_with_label_end2end.py b/ppstructure/vqa/helper/eval_with_label_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dd3e0ad437e51e21ebc53daeec9fdf9aa76b63
--- /dev/null
+++ b/ppstructure/vqa/helper/eval_with_label_end2end.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import sys
+import shapely
+from shapely.geometry import Polygon
+import numpy as np
+from collections import defaultdict
+import operator
+import editdistance
+import argparse
+import json
+import copy
+
+
+def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
+    # Each line of the file looks like:
+    # img/zh_val_0.jpg	{
+    #     "height": 3508,
+    #     "width": 2480,
+    #     "ocr_info": [
+    #         {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
+    #         {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
+    #     ]
+    # }
+    assert fp_type in ["gt", "pred"]
+    key = "label" if fp_type == "gt" else "pred"
+    res_dict = dict()
+    with open(fp, "r") as fin:
+        lines = fin.readlines()
+
+    for _, line in enumerate(lines):
+        img_path, info = line.strip().split("\t")
+        # get key
+        image_name = os.path.basename(img_path)
+        res_dict[image_name] = []
+        # get infos
+        json_info = json.loads(info)
+        for single_ocr_info in json_info["ocr_info"]:
+            label = single_ocr_info[key].upper()
+            if label in ["O", "OTHERS", "OTHER"]:
+                label = "O"
+            if ignore_background and label == "O":
+                continue
+            single_ocr_info["label"] = label
+            res_dict[image_name].append(copy.deepcopy(single_ocr_info))
+    return res_dict
+
+
+def polygon_from_str(polygon_points):
+    """
+    Create a shapely polygon object from gt or dt line.
+    """
+    polygon_points = np.array(polygon_points).reshape(4, 2)
+    polygon = Polygon(polygon_points).convex_hull
+    return polygon
+
+
+def polygon_iou(poly1, poly2):
+    """
+    Intersection over union between two shapely polygons.
+ """ + if not poly1.intersects( + poly2): # this test is fast and can accelerate calculation + iou = 0 + else: + try: + inter_area = poly1.intersection(poly2).area + union_area = poly1.area + poly2.area - inter_area + iou = float(inter_area) / union_area + except shapely.geos.TopologicalError: + # except Exception as e: + # print(e) + print('shapely.geos.TopologicalError occured, iou set to 0') + iou = 0 + return iou + + +def ed(args, str1, str2): + if args.ignore_space: + str1 = str1.replace(" ", "") + str2 = str2.replace(" ", "") + if args.ignore_case: + str1 = str1.lower() + str2 = str2.lower() + return editdistance.eval(str1, str2) + + +def convert_bbox_to_polygon(bbox): + """ + bbox : [x1, y1, x2, y2] + output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] + """ + xmin, ymin, xmax, ymax = bbox + poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + return poly + + +def eval_e2e(args): + # gt + gt_results = parse_ser_results_fp(args.gt_json_path, "gt", + args.ignore_background) + # pred + dt_results = parse_ser_results_fp(args.pred_json_path, "pred", + args.ignore_background) + assert set(gt_results.keys()) == set(dt_results.keys()) + + iou_thresh = args.iou_thres + num_gt_chars = 0 + gt_count = 0 + dt_count = 0 + hit = 0 + ed_sum = 0 + + for img_name in gt_results: + gt_info = gt_results[img_name] + gt_count += len(gt_info) + + dt_info = dt_results[img_name] + dt_count += len(dt_info) + + dt_match = [False] * len(dt_info) + gt_match = [False] * len(gt_info) + + all_ious = defaultdict(tuple) + # gt: {text, label, bbox or poly} + for index_gt, gt in enumerate(gt_info): + if "poly" not in gt: + gt["poly"] = convert_bbox_to_polygon(gt["bbox"]) + gt_poly = polygon_from_str(gt["poly"]) + for index_dt, dt in enumerate(dt_info): + if "poly" not in dt: + dt["poly"] = convert_bbox_to_polygon(dt["bbox"]) + dt_poly = polygon_from_str(dt["poly"]) + iou = polygon_iou(dt_poly, gt_poly) + if iou >= iou_thresh: + all_ious[(index_gt, index_dt)] = iou + sorted_ious = sorted( + all_ious.items(), key=operator.itemgetter(1), reverse=True) + sorted_gt_dt_pairs = [item[0] for item in sorted_ious] + + # matched gt and dt + for gt_dt_pair in sorted_gt_dt_pairs: + index_gt, index_dt = gt_dt_pair + if gt_match[index_gt] == False and dt_match[index_dt] == False: + gt_match[index_gt] = True + dt_match[index_dt] = True + # ocr rec results + gt_text = gt_info[index_gt]["text"] + dt_text = dt_info[index_dt]["text"] + + # ser results + gt_label = gt_info[index_gt]["label"] + dt_label = dt_info[index_dt]["pred"] + + if True: # ignore_masks[index_gt] == '0': + ed_sum += ed(args, gt_text, dt_text) + num_gt_chars += len(gt_text) + if gt_text == dt_text: + if args.ignore_ser_prediction or gt_label == dt_label: + hit += 1 + +# unmatched dt + for tindex, dt_match_flag in enumerate(dt_match): + if dt_match_flag == False: + dt_text = dt_info[tindex]["text"] + gt_text = "" + ed_sum += ed(args, dt_text, gt_text) + +# unmatched gt + for tindex, gt_match_flag in enumerate(gt_match): + if gt_match_flag == False: + dt_text = "" + gt_text = gt_info[tindex]["text"] + ed_sum += ed(args, gt_text, dt_text) + num_gt_chars += len(gt_text) + + eps = 1e-9 + print("config: ", args) + print('hit, dt_count, gt_count', hit, dt_count, gt_count) + precision = hit / (dt_count + eps) + recall = hit / (gt_count + eps) + fmeasure = 2.0 * precision * recall / (precision + recall + eps) + avg_edit_dist_img = ed_sum / len(gt_results) + avg_edit_dist_field = ed_sum / (gt_count + eps) + character_acc = 1 - ed_sum / (num_gt_chars + eps) + + 
print('character_acc: %.2f' % (character_acc * 100) + "%") + print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field)) + print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img)) + print('precision: %.2f' % (precision * 100) + "%") + print('recall: %.2f' % (recall * 100) + "%") + print('fmeasure: %.2f' % (fmeasure * 100) + "%") + + return + + +def parse_args(): + """ + """ + + def str2bool(v): + return v.lower() in ("true", "t", "1") + + parser = argparse.ArgumentParser() + ## Required parameters + parser.add_argument( + "--gt_json_path", + default=None, + type=str, + required=True, ) + parser.add_argument( + "--pred_json_path", + default=None, + type=str, + required=True, ) + + parser.add_argument("--iou_thres", default=0.5, type=float) + + parser.add_argument( + "--ignore_case", + default=False, + type=str2bool, + help="whether to do lower case for the strs") + + parser.add_argument( + "--ignore_space", + default=True, + type=str2bool, + help="whether to ignore space") + + parser.add_argument( + "--ignore_background", + default=True, + type=str2bool, + help="whether to ignore other label") + + parser.add_argument( + "--ignore_ser_prediction", + default=False, + type=str2bool, + help="whether to ignore ocr pred results") + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + eval_e2e(args) diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ebd5dfbd8addda0701a7cfd2387133f7a8776b --- /dev/null +++ b/ppstructure/vqa/helper/trans_xfun_data.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
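+
+# This helper converts an original XFUND annotation file (a single JSON object
+# with a "documents" list) into the normalized line-based format used by the
+# VQA module: each output line is "<image fname>\t<JSON with height/width/ocr_info>".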
+ +import json + + +def transfer_xfun_data(json_path=None, output_file=None): + with open(json_path, "r") as fin: + lines = fin.readlines() + + json_info = json.loads(lines[0]) + documents = json_info["documents"] + label_info = {} + with open(output_file, "w") as fout: + for idx, document in enumerate(documents): + img_info = document["img"] + document = document["document"] + image_path = img_info["fname"] + + label_info["height"] = img_info["height"] + label_info["width"] = img_info["width"] + + label_info["ocr_info"] = [] + + for doc in document: + label_info["ocr_info"].append({ + "text": doc["text"], + "label": doc["label"], + "bbox": doc["box"], + "id": doc["id"], + "linking": doc["linking"], + "words": doc["words"] + }) + + fout.write(image_path + "\t" + json.dumps( + label_info, ensure_ascii=False) + "\n") + + print("===ok====") + + +transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json") diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/ppstructure/vqa/images/input/zh_val_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..479b60bcd3a859b187ce5325dfc381c1b87ee27f Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_0.jpg differ diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/ppstructure/vqa/images/input/zh_val_42.jpg new file mode 100644 index 0000000000000000000000000000000000000000..42151bdd94929ede9da1a63ce8d9339971094a46 Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_42.jpg differ diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22ba9a6f1b7652ca9ce6848093c7a39affb4886b Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..951864e5f35a987ff241f276c8da523d8c8eeaf3 Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad220094a26b330555fbe9122a46fb56e64fe1e --- /dev/null +++ b/ppstructure/vqa/infer_ser.py @@ -0,0 +1,279 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
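+
+# SER inference with LayoutXLM, using OCR results read from the JSON file passed
+# via --ocr_json_path (no OCR engine is run here). Predictions are written to
+# <output_dir>/infer_results.txt and the visualizations are saved alongside it.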
+ +import os +import sys +import json +import cv2 +import numpy as np +from copy import deepcopy + +import paddle + +# relative reference +from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification + + +def pad_sentences(tokenizer, + encoded_inputs, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False): + # Padding with larger size, reshape is carried out + max_seq_len = ( + len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len + + needs_to_be_padded = pad_to_max_seq_len and \ + max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len + + if needs_to_be_padded: + difference = max_seq_len - len(encoded_inputs["input_ids"]) + if tokenizer.padding_side == 'right': + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + [0] * difference + if return_token_type_ids: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + + [tokenizer.pad_token_type_id] * difference) + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = encoded_inputs[ + "special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs[ + "input_ids"] + [tokenizer.pad_token_id] * difference + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0] + ] * difference + else: + assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}" + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + + return encoded_inputs + + +def split_page(encoded_inputs, max_seq_len=512): + """ + truncate is often used in training process + """ + for key in encoded_inputs: + encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) + if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on + encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) + else: # for bbox + encoded_inputs[key] = encoded_inputs[key].reshape( + [-1, max_seq_len, 4]) + return encoded_inputs + + +def preprocess( + tokenizer, + ori_img, + ocr_info, + img_size=(224, 224), + pad_token_label_id=-100, + max_seq_len=512, + add_special_ids=False, + return_attention_mask=True, ): + ocr_info = deepcopy(ocr_info) + height = ori_img.shape[0] + width = ori_img.shape[1] + + img = cv2.resize(ori_img, + (224, 224)).transpose([2, 0, 1]).astype(np.float32) + + segment_offset_id = [] + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + + for info in ocr_info: + # x1, y1, x2, y2 + bbox = info["bbox"] + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + + text = info["text"] + encode_res = tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True) + + if not add_special_ids: + # TODO: use tok.all_special_ids to remove + encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] + encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] + + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + words_list.append(text) + 
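+        # record the cumulative token count so that merge_preds_list_with_ocr_info
+        # can map the flat token predictions back to this OCR segment later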
segment_offset_id.append(len(input_ids_list)) + + encoded_inputs = { + "input_ids": input_ids_list, + "token_type_ids": token_type_ids_list, + "bbox": bbox_list, + "attention_mask": [1] * len(input_ids_list), + } + + encoded_inputs = pad_sentences( + tokenizer, + encoded_inputs, + max_seq_len=max_seq_len, + return_attention_mask=return_attention_mask) + + encoded_inputs = split_page(encoded_inputs) + + fake_bs = encoded_inputs["input_ids"].shape[0] + + encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand( + [fake_bs] + list(img.shape)) + + encoded_inputs["segment_offset_id"] = segment_offset_id + + return encoded_inputs + + +def postprocess(attention_mask, preds, label_map_path): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds = np.argmax(preds, axis=2) + + _, label_map = get_bio_label_maps(label_map_path) + + preds_list = [[] for _ in range(preds.shape[0])] + + # keep batch info + for i in range(preds.shape[0]): + for j in range(preds.shape[1]): + if attention_mask[i][j] == 1: + preds_list[i].append(label_map[preds[i][j]]) + + return preds_list + + +def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id, + preds_list): + # must ensure the preds_list is generated from the same image + preds = [p for pred in preds_list for p in pred] + label2id_map, _ = get_bio_label_maps(label_map_path) + for key in label2id_map: + if key.startswith("I-"): + label2id_map[key] = label2id_map["B" + key[1:]] + + id2label_map = dict() + for key in label2id_map: + val = label2id_map[key] + if key == "O": + id2label_map[val] = key + if key.startswith("B-") or key.startswith("I-"): + id2label_map[val] = key[2:] + else: + id2label_map[val] = key + + for idx in range(len(segment_offset_id)): + if idx == 0: + start_id = 0 + else: + start_id = segment_offset_id[idx - 1] + + end_id = segment_offset_id[idx] + + curr_pred = preds[start_id:end_id] + curr_pred = [label2id_map[p] for p in curr_pred] + + if len(curr_pred) <= 0: + pred_id = 0 + else: + counts = np.bincount(curr_pred) + pred_id = np.argmax(counts) + ocr_info[idx]["pred_id"] = int(pred_id) + ocr_info[idx]["pred"] = id2label_map[pred_id] + return ocr_info + + +@paddle.no_grad() +def infer(args): + os.makedirs(args.output_dir, exist_ok=True) + + # init token and model + tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) + # model = LayoutXLMModel.from_pretrained(args.model_name_or_path) + model = LayoutXLMForTokenClassification.from_pretrained( + args.model_name_or_path) + model.eval() + + # load ocr results json + ocr_results = dict() + with open(args.ocr_json_path, "r") as fin: + lines = fin.readlines() + for line in lines: + img_name, json_info = line.split("\t") + ocr_results[os.path.basename(img_name)] = json.loads(json_info) + + # get infer img list + infer_imgs = get_image_file_list(args.infer_imgs) + + # loop for infer + with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: + for idx, img_path in enumerate(infer_imgs): + print("process: [{}/{}]".format(idx, len(infer_imgs), img_path)) + + img = cv2.imread(img_path) + + ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"] + inputs = preprocess( + tokenizer=tokenizer, + ori_img=img, + ocr_info=ocr_info, + max_seq_len=args.max_seq_length) + + outputs = model( + input_ids=inputs["input_ids"], + bbox=inputs["bbox"], + image=inputs["image"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"]) + + preds = outputs[0] + preds = postprocess(inputs["attention_mask"], 
preds, + args.label_map_path) + ocr_info = merge_preds_list_with_ocr_info( + args.label_map_path, ocr_info, inputs["segment_offset_id"], + preds) + + fout.write(img_path + "\t" + json.dumps( + { + "ocr_info": ocr_info, + }, ensure_ascii=False) + "\n") + + img_res = draw_ser_results(img, ocr_info) + cv2.imwrite( + os.path.join(args.output_dir, os.path.basename(img_path)), + img_res) + + return + + +if __name__ == "__main__": + args = parse_args() + infer(args) diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..da027a140bdb4fa12a40d423998d94e438a7cd11 --- /dev/null +++ b/ppstructure/vqa/infer_ser_e2e.py @@ -0,0 +1,121 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import json +import cv2 +import numpy as np +from copy import deepcopy +from PIL import Image + +import paddle +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification + +# relative reference +from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine + +from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info + + +def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + + +def parse_ocr_info_for_ser(ocr_result): + ocr_info = [] + for res in ocr_result: + ocr_info.append({ + "text": res[1][0], + "bbox": trans_poly_to_bbox(res[0]), + "poly": res[0], + }) + return ocr_info + + +@paddle.no_grad() +def infer(args): + os.makedirs(args.output_dir, exist_ok=True) + + # init token and model + tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) + model = LayoutXLMForTokenClassification.from_pretrained( + args.model_name_or_path) + model.eval() + + label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) + label2id_map_for_draw = dict() + for key in label2id_map: + if key.startswith("I-"): + label2id_map_for_draw[key] = label2id_map["B" + key[1:]] + else: + label2id_map_for_draw[key] = label2id_map[key] + + # get infer img list + infer_imgs = get_image_file_list(args.infer_imgs) + + ocr_engine = build_ocr_engine(args.ocr_rec_model_dir, + args.ocr_det_model_dir) + + # loop for infer + with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: + for idx, img_path in enumerate(infer_imgs): + print("process: [{}/{}]".format(idx, len(infer_imgs), img_path)) + + img = cv2.imread(img_path) + + ocr_result = ocr_engine.ocr(img_path, cls=False) + + ocr_info = parse_ocr_info_for_ser(ocr_result) + + inputs = preprocess( + tokenizer=tokenizer, + ori_img=img, + ocr_info=ocr_info, + max_seq_len=args.max_seq_length) + + outputs = model( + input_ids=inputs["input_ids"], + bbox=inputs["bbox"], + image=inputs["image"], + 
token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"]) + + preds = outputs[0] + preds = postprocess(inputs["attention_mask"], preds, id2label_map) + ocr_info = merge_preds_list_with_ocr_info( + ocr_info, inputs["segment_offset_id"], preds, + label2id_map_for_draw) + + fout.write(img_path + "\t" + json.dumps( + { + "ocr_info": ocr_info, + }, ensure_ascii=False) + "\n") + + img_res = draw_ser_results(img, ocr_info) + cv2.imwrite( + os.path.join(args.output_dir, + os.path.splitext(os.path.basename(img_path))[0] + + "_ser.jpg"), img_res) + + return + + +if __name__ == "__main__": + args = parse_args() + infer(args) diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt new file mode 100644 index 0000000000000000000000000000000000000000..508e48112412f62538baf0c78bcf99ec8945196e --- /dev/null +++ b/ppstructure/vqa/labels/labels_ser.txt @@ -0,0 +1,3 @@ +QUESTION +ANSWER +HEADER diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c259fadc395335b336cb0ecdb5aa6bca48631987 --- /dev/null +++ b/ppstructure/vqa/requirements.txt @@ -0,0 +1,2 @@ +sentencepiece +yacs diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca69d93fd22983533fcacd639bbd64dc3e11ec --- /dev/null +++ b/ppstructure/vqa/train_ser.py @@ -0,0 +1,313 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
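+
+# Fine-tunes LayoutXLM for the SER (semantic entity recognition) task on the
+# normalized XFUND-style data produced by helper/trans_xfun_data.py; the best
+# checkpoint (by eval f1) is saved to <output_dir>/best_model.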
+ +import sys +import os +import random +import copy +import logging + +import argparse +import paddle +import numpy as np +from seqeval.metrics import classification_report, f1_score, precision_score, recall_score +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification +from xfun import XFUNDataset +from utils import parse_args +from utils import get_bio_label_maps + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + paddle.seed(args.seed) + + +def train(args): + os.makedirs(args.output_dir, exist_ok=True) + logging.basicConfig( + filename=os.path.join(args.output_dir, "train.log") + if paddle.distributed.get_rank() == 0 else None, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO + if paddle.distributed.get_rank() == 0 else logging.WARN, ) + + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + logger.addHandler(ch) + + label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) + pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + + # dist mode + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) + base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path) + model = LayoutXLMForTokenClassification( + base_model, num_classes=len(label2id_map), dropout=None) + + # dist mode + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + train_dataset = XFUNDataset( + tokenizer, + data_dir=args.train_data_dir, + label_path=args.train_label_path, + label2id_map=label2id_map, + img_size=(224, 224), + pad_token_label_id=pad_token_label_id, + contains_re=False, + add_special_ids=False, + return_attention_mask=True, + load_mode='all') + + train_sampler = paddle.io.DistributedBatchSampler( + train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True) + + args.train_batch_size = args.per_gpu_train_batch_size * max( + 1, paddle.distributed.get_world_size()) + + train_dataloader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=0, + use_shared_memory=True, + collate_fn=None, ) + + t_total = len(train_dataloader) * args.num_train_epochs + + # build linear decay with warmup lr sch + lr_scheduler = paddle.optimizer.lr.PolynomialDecay( + learning_rate=args.learning_rate, + decay_steps=t_total, + end_lr=0.0, + power=1.0) + if args.warmup_steps > 0: + lr_scheduler = paddle.optimizer.lr.LinearWarmup( + lr_scheduler, + args.warmup_steps, + start_lr=0, + end_lr=args.learning_rate, ) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + parameters=model.parameters(), + epsilon=args.adam_epsilon, + weight_decay=args.weight_decay) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", + args.per_gpu_train_batch_size) + logger.info( + " Total train batch size (w. 
parallel, distributed) = %d", + args.train_batch_size * paddle.distributed.get_world_size(), ) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 0 + tr_loss = 0.0 + set_seed(args) + best_metrics = None + + for epoch_id in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + model.train() + outputs = model(**batch) + # model outputs are always tuple in ppnlp (see doc) + loss = outputs[0] + loss = loss.mean() + logger.info( + "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ". + format(epoch_id, args.num_train_epochs, step, + len(train_dataloader), + lr_scheduler.get_lr(), loss.numpy()[0])) + + loss.backward() + tr_loss += loss.item() + optimizer.step() + lr_scheduler.step() # Update learning rate schedule + optimizer.clear_grad() + global_step += 1 + + if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and + global_step % args.eval_steps == 0): + # Log metrics + # Only evaluate when single GPU otherwise metrics may not average well + if paddle.distributed.get_rank( + ) == 0 and args.evaluate_during_training: + results, _ = evaluate( + args, + model, + tokenizer, + label2id_map, + id2label_map, + pad_token_label_id, ) + + if best_metrics is None or results["f1"] >= best_metrics[ + "f1"]: + best_metrics = copy.deepcopy(results) + output_dir = os.path.join(args.output_dir, "best_model") + os.makedirs(output_dir, exist_ok=True) + if paddle.distributed.get_rank() == 0: + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save( + args, + os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to %s", + output_dir) + + logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format( + epoch_id, args.num_train_epochs, step, + len(train_dataloader), results)) + if best_metrics is not None: + logger.info("best metrics: {}".format(best_metrics)) + + if paddle.distributed.get_rank( + ) == 0 and args.save_steps > 0 and global_step % args.save_steps == 0: + # Save model checkpoint + output_dir = os.path.join(args.output_dir, + "checkpoint-{}".format(global_step)) + os.makedirs(output_dir, exist_ok=True) + if paddle.distributed.get_rank() == 0: + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save(args, + os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to %s", output_dir) + + return global_step, tr_loss / global_step + + +def evaluate(args, + model, + tokenizer, + label2id_map, + id2label_map, + pad_token_label_id, + prefix=""): + eval_dataset = XFUNDataset( + tokenizer, + data_dir=args.eval_data_dir, + label_path=args.eval_label_path, + label2id_map=label2id_map, + img_size=(224, 224), + pad_token_label_id=pad_token_label_id, + contains_re=False, + add_special_ids=False, + return_attention_mask=True, + load_mode='all') + + args.eval_batch_size = args.per_gpu_eval_batch_size * max( + 1, paddle.distributed.get_world_size()) + + eval_dataloader = paddle.io.DataLoader( + eval_dataset, + batch_size=args.eval_batch_size, + num_workers=0, + use_shared_memory=True, + collate_fn=None, ) + + # Eval! 
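+    # precision / recall / f1 below are entity-level scores computed by seqeval
+    # on the BIO label sequences (padding positions are skipped).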
+ logger.info("***** Running evaluation %s *****", prefix) + logger.info(" Num examples = %d", len(eval_dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + eval_loss = 0.0 + nb_eval_steps = 0 + preds = None + out_label_ids = None + model.eval() + for idx, batch in enumerate(eval_dataloader): + with paddle.no_grad(): + outputs = model(**batch) + tmp_eval_loss, logits = outputs[:2] + + tmp_eval_loss = tmp_eval_loss.mean() + + if paddle.distributed.get_rank() == 0: + logger.info("[Eval]process: {}/{}, loss: {:.5f}".format( + idx, len(eval_dataloader), tmp_eval_loss.numpy()[0])) + + eval_loss += tmp_eval_loss.item() + nb_eval_steps += 1 + if preds is None: + preds = logits.numpy() + out_label_ids = batch["labels"].numpy() + else: + preds = np.append(preds, logits.numpy(), axis=0) + out_label_ids = np.append( + out_label_ids, batch["labels"].numpy(), axis=0) + + eval_loss = eval_loss / nb_eval_steps + preds = np.argmax(preds, axis=2) + + # label_map = {i: label.upper() for i, label in enumerate(labels)} + + out_label_list = [[] for _ in range(out_label_ids.shape[0])] + preds_list = [[] for _ in range(out_label_ids.shape[0])] + + for i in range(out_label_ids.shape[0]): + for j in range(out_label_ids.shape[1]): + if out_label_ids[i, j] != pad_token_label_id: + out_label_list[i].append(id2label_map[out_label_ids[i][j]]) + preds_list[i].append(id2label_map[preds[i][j]]) + + results = { + "loss": eval_loss, + "precision": precision_score(out_label_list, preds_list), + "recall": recall_score(out_label_list, preds_list), + "f1": f1_score(out_label_list, preds_list), + } + + with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout: + for lbl in out_label_list: + for l in lbl: + fout.write(l + "\t") + fout.write("\n") + with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout: + for lbl in preds_list: + for l in lbl: + fout.write(l + "\t") + fout.write("\n") + + report = classification_report(out_label_list, preds_list) + logger.info("\n" + report) + + logger.info("***** Eval results %s *****", prefix) + for key in sorted(results.keys()): + logger.info(" %s = %s", key, str(results[key])) + + return results, preds_list + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + train(args) diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ac1e77d37d0a662294480a393c2f67e7f4cc64 --- /dev/null +++ b/ppstructure/vqa/utils.py @@ -0,0 +1,328 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
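+
+# Shared helpers for the VQA module: BIO label maps, image listing, result
+# visualization, the PaddleOCR engine wrapper, and the tokenization / padding /
+# post-processing utilities reused by infer_ser_e2e.py.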
+ +import os +import argparse +import cv2 +import random +import numpy as np +import imghdr +from copy import deepcopy + +import paddle + +from PIL import Image, ImageDraw, ImageFont + +from paddleocr import PaddleOCR + + +def get_bio_label_maps(label_map_path): + with open(label_map_path, "r") as fin: + lines = fin.readlines() + lines = [line.strip() for line in lines] + if "O" not in lines: + lines.insert(0, "O") + labels = [] + for line in lines: + if line == "O": + labels.append("O") + else: + labels.append("B-" + line) + labels.append("I-" + line) + label2id_map = {label: idx for idx, label in enumerate(labels)} + id2label_map = {idx: label for idx, label in enumerate(labels)} + return label2id_map, id2label_map + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'} + if os.path.isfile(img_file) and imghdr.what(img_file) in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and imghdr.what(file_path) in img_end: + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +def draw_ser_results(image, + ocr_results, + font_path="../doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(0) + color = (np.random.permutation(range(255)), + np.random.permutation(range(255)), + np.random.permutation(range(255))) + color_map = { + idx: (color[0][idx], color[1][idx], color[2][idx]) + for idx in range(1, 255) + } + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + + for ocr_info in ocr_results: + if ocr_info["pred_id"] not in color_map: + continue + color = color_map[ocr_info["pred_id"]] + + # draw ocr results outline + bbox = ocr_info["bbox"] + bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) + draw.rectangle(bbox, fill=color) + + # draw ocr results + text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) + start_y = max(0, bbox[0][1] - font_size) + tw = font.getsize(text)[0] + draw.rectangle( + [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, + start_y + font_size)], + fill=(0, 0, 255)) + draw.text( + (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) + + +def build_ocr_engine(rec_model_dir, det_model_dir): + ocr_engine = PaddleOCR( + rec_model_dir=rec_model_dir, + det_model_dir=det_model_dir, + use_angle_cls=False) + return ocr_engine + + +# pad sentences +def pad_sentences(tokenizer, + encoded_inputs, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False): + # Padding with larger size, reshape is carried out + max_seq_len = ( + len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len + + needs_to_be_padded = pad_to_max_seq_len and \ + max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len + + if needs_to_be_padded: + difference = max_seq_len - len(encoded_inputs["input_ids"]) + if tokenizer.padding_side == 'right': + if return_attention_mask: + 
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + [0] * difference + if return_token_type_ids: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + + [tokenizer.pad_token_type_id] * difference) + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = encoded_inputs[ + "special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs[ + "input_ids"] + [tokenizer.pad_token_id] * difference + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0] + ] * difference + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + + return encoded_inputs + + +def split_page(encoded_inputs, max_seq_len=512): + """ + truncate is often used in training process + """ + for key in encoded_inputs: + encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) + if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on + encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) + else: # for bbox + encoded_inputs[key] = encoded_inputs[key].reshape( + [-1, max_seq_len, 4]) + return encoded_inputs + + +def preprocess( + tokenizer, + ori_img, + ocr_info, + img_size=(224, 224), + pad_token_label_id=-100, + max_seq_len=512, + add_special_ids=False, + return_attention_mask=True, ): + ocr_info = deepcopy(ocr_info) + height = ori_img.shape[0] + width = ori_img.shape[1] + + img = cv2.resize(ori_img, + (224, 224)).transpose([2, 0, 1]).astype(np.float32) + + segment_offset_id = [] + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + + for info in ocr_info: + # x1, y1, x2, y2 + bbox = info["bbox"] + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + + text = info["text"] + encode_res = tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True) + + if not add_special_ids: + # TODO: use tok.all_special_ids to remove + encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] + encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] + + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + words_list.append(text) + segment_offset_id.append(len(input_ids_list)) + + encoded_inputs = { + "input_ids": input_ids_list, + "token_type_ids": token_type_ids_list, + "bbox": bbox_list, + "attention_mask": [1] * len(input_ids_list), + } + + encoded_inputs = pad_sentences( + tokenizer, + encoded_inputs, + max_seq_len=max_seq_len, + return_attention_mask=return_attention_mask) + + encoded_inputs = split_page(encoded_inputs) + + fake_bs = encoded_inputs["input_ids"].shape[0] + + encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand( + [fake_bs] + list(img.shape)) + + encoded_inputs["segment_offset_id"] = segment_offset_id + + return encoded_inputs + + +def postprocess(attention_mask, preds, id2label_map): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds = np.argmax(preds, axis=2) + + preds_list = [[] for _ in range(preds.shape[0])] + + # keep batch info + for i in range(preds.shape[0]): + for j in range(preds.shape[1]): + if attention_mask[i][j] == 1: + preds_list[i].append(id2label_map[preds[i][j]]) + + return preds_list + + +def merge_preds_list_with_ocr_info(ocr_info, 
segment_offset_id, preds_list, + label2id_map_for_draw): + # must ensure the preds_list is generated from the same image + preds = [p for pred in preds_list for p in pred] + + id2label_map = dict() + for key in label2id_map_for_draw: + val = label2id_map_for_draw[key] + if key == "O": + id2label_map[val] = key + if key.startswith("B-") or key.startswith("I-"): + id2label_map[val] = key[2:] + else: + id2label_map[val] = key + + for idx in range(len(segment_offset_id)): + if idx == 0: + start_id = 0 + else: + start_id = segment_offset_id[idx - 1] + + end_id = segment_offset_id[idx] + + curr_pred = preds[start_id:end_id] + curr_pred = [label2id_map_for_draw[p] for p in curr_pred] + + if len(curr_pred) <= 0: + pred_id = 0 + else: + counts = np.bincount(curr_pred) + pred_id = np.argmax(counts) + ocr_info[idx]["pred_id"] = int(pred_id) + ocr_info[idx]["pred"] = id2label_map[int(pred_id)] + return ocr_info + + +def parse_args(): + parser = argparse.ArgumentParser() + # Required parameters + # yapf: disable + parser.add_argument("--model_name_or_path", default=None, type=str, required=True,) + parser.add_argument("--train_data_dir", default=None, type=str, required=False,) + parser.add_argument("--train_label_path", default=None, type=str, required=False,) + parser.add_argument("--eval_data_dir", default=None, type=str, required=False,) + parser.add_argument("--eval_label_path", default=None, type=str, required=False,) + parser.add_argument("--output_dir", default=None, type=str, required=True,) + parser.add_argument("--max_seq_length", default=512, type=int,) + parser.add_argument("--evaluate_during_training", action="store_true",) + parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",) + parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",) + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",) + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",) + parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",) + parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",) + parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",) + parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",) + + parser.add_argument("--ocr_rec_model_dir", default=None, type=str, ) + parser.add_argument("--ocr_det_model_dir", default=None, type=str, ) + parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, ) + parser.add_argument("--infer_imgs", default=None, type=str, required=False) + parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results") + # yapf: enable + args = parser.parse_args() + return args diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py new file mode 100644 index 0000000000000000000000000000000000000000..d62cdb5da5514280b62687d80d345ede9484ee90 --- /dev/null +++ b/ppstructure/vqa/xfun.py @@ -0,0 +1,442 
@@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import cv2 +import numpy as np +import paddle +import copy +from paddle.io import Dataset + +__all__ = ["XFUNDataset"] + + +class XFUNDataset(Dataset): + """ + Example: + print("=====begin to build dataset=====") + from paddlenlp.transformers import LayoutXLMTokenizer + tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/") + tok_res = tokenizer.tokenize("Maribyrnong") + # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0]) + dataset = XfunDatasetForSer( + tokenizer, + data_dir="./zh.val/", + label_path="zh.val/xfun_normalize_val.json", + img_size=(224,224)) + print(len(dataset)) + + data = dataset[0] + print(data.keys()) + print("input_ids: ", data["input_ids"]) + print("labels: ", data["labels"]) + print("token_type_ids: ", data["token_type_ids"]) + print("words_list: ", data["words_list"]) + print("image shape: ", data["image"].shape) + """ + + def __init__(self, + tokenizer, + data_dir, + label_path, + contains_re=False, + label2id_map=None, + img_size=(224, 224), + pad_token_label_id=None, + add_special_ids=False, + return_attention_mask=True, + load_mode='all', + max_seq_len=512): + super().__init__() + self.tokenizer = tokenizer + self.data_dir = data_dir + self.label_path = label_path + self.contains_re = contains_re + self.label2id_map = label2id_map + self.img_size = img_size + self.pad_token_label_id = pad_token_label_id + self.add_special_ids = add_special_ids + self.return_attention_mask = return_attention_mask + self.load_mode = load_mode + self.max_seq_len = max_seq_len + + if self.pad_token_label_id is None: + self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + + self.all_lines = self.read_all_lines() + + self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} + self.return_keys = { + 'bbox': 'np', + 'input_ids': 'np', + 'labels': 'np', + 'attention_mask': 'np', + 'image': 'np', + 'token_type_ids': 'np', + 'entities': 'dict', + 'relations': 'dict', + } + + if load_mode == "all": + self.encoded_inputs_all = self._parse_label_file_all() + + def pad_sentences(self, + encoded_inputs, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + truncation_strategy="longest_first", + return_overflowing_tokens=False, + return_special_tokens_mask=False): + # Padding + needs_to_be_padded = pad_to_max_seq_len and \ + max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len + + if needs_to_be_padded: + difference = max_seq_len - len(encoded_inputs["input_ids"]) + if self.tokenizer.padding_side == 'right': + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + [0] * difference + if return_token_type_ids: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + + [self.tokenizer.pad_token_type_id] * difference) + if return_special_tokens_mask: 
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[ + "special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs[ + "input_ids"] + [self.tokenizer.pad_token_id] * difference + encoded_inputs["labels"] = encoded_inputs[ + "labels"] + [self.pad_token_label_id] * difference + encoded_inputs["bbox"] = encoded_inputs[ + "bbox"] + [[0, 0, 0, 0]] * difference + elif self.tokenizer.padding_side == 'left': + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + [ + 1 + ] * len(encoded_inputs["input_ids"]) + if return_token_type_ids: + encoded_inputs["token_type_ids"] = ( + [self.tokenizer.pad_token_type_id] * difference + + encoded_inputs["token_type_ids"]) + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = [ + 1 + ] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [ + self.tokenizer.pad_token_id + ] * difference + encoded_inputs["input_ids"] + encoded_inputs["labels"] = [ + self.pad_token_label_id + ] * difference + encoded_inputs["labels"] + encoded_inputs["bbox"] = [ + [0, 0, 0, 0] + ] * difference + encoded_inputs["bbox"] + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ + "input_ids"]) + + return encoded_inputs + + def truncate_inputs(self, encoded_inputs, max_seq_len=512): + for key in encoded_inputs: + if key == "sample_id": + continue + length = min(len(encoded_inputs[key]), max_seq_len) + encoded_inputs[key] = encoded_inputs[key][:length] + return encoded_inputs + + def read_all_lines(self, ): + with open(self.label_path, "r") as fin: + lines = fin.readlines() + return lines + + def _parse_label_file_all(self): + """ + parse all samples + """ + encoded_inputs_all = [] + for line in self.all_lines: + encoded_inputs_all.extend(self._parse_label_file(line)) + return encoded_inputs_all + + def _parse_label_file(self, line): + """ + parse single sample + """ + + image_name, info_str = line.split("\t") + image_path = os.path.join(self.data_dir, image_name) + + def add_imgge_path(x): + x['image_path'] = image_path + return x + + encoded_inputs = self._read_encoded_inputs_sample(info_str) + if self.contains_re: + encoded_inputs = self._chunk_re(encoded_inputs) + else: + encoded_inputs = self._chunk_ser(encoded_inputs) + encoded_inputs = list(map(add_imgge_path, encoded_inputs)) + return encoded_inputs + + def _read_encoded_inputs_sample(self, info_str): + """ + parse label info + """ + # read text info + info_dict = json.loads(info_str) + height = info_dict["height"] + width = info_dict["width"] + + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + gt_label_list = [] + + if self.contains_re: + # for re + entities = [] + relations = [] + id2label = {} + entity_id_to_index_map = {} + empty_entity = set() + for info in info_dict["ocr_info"]: + if self.contains_re: + # for re + if len(info["text"]) == 0: + empty_entity.add(info["id"]) + continue + id2label[info["id"]] = info["label"] + relations.extend([tuple(sorted(l)) for l in info["linking"]]) + + # x1, y1, x2, y2 + bbox = info["bbox"] + label = info["label"] + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + + text = info["text"] + encode_res = self.tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True) + + gt_label = [] + if not self.add_special_ids: + # TODO: use tok.all_special_ids to remove + 
encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1: + -1] + encode_res["attention_mask"] = encode_res["attention_mask"][1: + -1] + if label.lower() == "other": + gt_label.extend([0] * len(encode_res["input_ids"])) + else: + gt_label.append(self.label2id_map[("b-" + label).upper()]) + gt_label.extend([self.label2id_map[("i-" + label).upper()]] * + (len(encode_res["input_ids"]) - 1)) + if self.contains_re: + if gt_label[0] != self.label2id_map["O"]: + entity_id_to_index_map[info["id"]] = len(entities) + entities.append({ + "start": len(input_ids_list), + "end": + len(input_ids_list) + len(encode_res["input_ids"]), + "label": label.upper(), + }) + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + gt_label_list.extend(gt_label) + words_list.append(text) + + encoded_inputs = { + "input_ids": input_ids_list, + "labels": gt_label_list, + "token_type_ids": token_type_ids_list, + "bbox": bbox_list, + "attention_mask": [1] * len(input_ids_list), + # "words_list": words_list, + } + encoded_inputs = self.pad_sentences( + encoded_inputs, + max_seq_len=self.max_seq_len, + return_attention_mask=self.return_attention_mask) + encoded_inputs = self.truncate_inputs(encoded_inputs) + + if self.contains_re: + relations = self._relations(entities, relations, id2label, + empty_entity, entity_id_to_index_map) + encoded_inputs['relations'] = relations + encoded_inputs['entities'] = entities + return encoded_inputs + + def _chunk_ser(self, encoded_inputs): + encoded_inputs_all = [] + seq_len = len(encoded_inputs['input_ids']) + chunk_size = 512 + for chunk_id, index in enumerate(range(0, seq_len, chunk_size)): + chunk_beg = index + chunk_end = min(index + chunk_size, seq_len) + encoded_inputs_example = {} + for key in encoded_inputs: + encoded_inputs_example[key] = encoded_inputs[key][chunk_beg: + chunk_end] + + encoded_inputs_all.append(encoded_inputs_example) + return encoded_inputs_all + + def _chunk_re(self, encoded_inputs): + # prepare data + entities = encoded_inputs.pop('entities') + relations = encoded_inputs.pop('relations') + encoded_inputs_all = [] + chunk_size = 512 + for chunk_id, index in enumerate( + range(0, len(encoded_inputs["input_ids"]), chunk_size)): + item = {} + for k in encoded_inputs: + item[k] = encoded_inputs[k][index:index + chunk_size] + + # select entity in current chunk + entities_in_this_span = [] + global_to_local_map = {} # + for entity_id, entity in enumerate(entities): + if (index <= entity["start"] < index + chunk_size and + index <= entity["end"] < index + chunk_size): + entity["start"] = entity["start"] - index + entity["end"] = entity["end"] - index + global_to_local_map[entity_id] = len(entities_in_this_span) + entities_in_this_span.append(entity) + + # select relations in current chunk + relations_in_this_span = [] + for relation in relations: + if (index <= relation["start_index"] < index + chunk_size and + index <= relation["end_index"] < index + chunk_size): + relations_in_this_span.append({ + "head": global_to_local_map[relation["head"]], + "tail": global_to_local_map[relation["tail"]], + "start_index": relation["start_index"] - index, + "end_index": relation["end_index"] - index, + }) + item.update({ + "entities": reformat(entities_in_this_span), + "relations": reformat(relations_in_this_span), + }) + item['entities']['label'] = [ + self.entities_labels[x] for x in 
item['entities']['label'] + ] + encoded_inputs_all.append(item) + return encoded_inputs_all + + def _relations(self, entities, relations, id2label, empty_entity, + entity_id_to_index_map): + """ + build relations + """ + relations = list(set(relations)) + relations = [ + rel for rel in relations + if rel[0] not in empty_entity and rel[1] not in empty_entity + ] + kv_relations = [] + for rel in relations: + pair = [id2label[rel[0]], id2label[rel[1]]] + if pair == ["question", "answer"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[0]], + "tail": entity_id_to_index_map[rel[1]] + }) + elif pair == ["answer", "question"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[1]], + "tail": entity_id_to_index_map[rel[0]] + }) + else: + continue + relations = sorted( + [{ + "head": rel["head"], + "tail": rel["tail"], + "start_index": get_relation_span(rel, entities)[0], + "end_index": get_relation_span(rel, entities)[1], + } for rel in kv_relations], + key=lambda x: x["head"], ) + return relations + + def load_img(self, image_path): + # read img + img = cv2.imread(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + resize_h, resize_w = self.img_size + im_shape = img.shape[0:2] + im_scale_y = resize_h / im_shape[0] + im_scale_x = resize_w / im_shape[1] + img_new = cv2.resize( + img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2) + mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :] + std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :] + img_new = img_new / 255.0 + img_new -= mean + img_new /= std + img = img_new.transpose((2, 0, 1)) + return img + + def __getitem__(self, idx): + if self.load_mode == "all": + data = copy.deepcopy(self.encoded_inputs_all[idx]) + else: + data = self._parse_label_file(self.all_lines[idx])[0] + + image_path = data.pop('image_path') + data["image"] = self.load_img(image_path) + + return_data = {} + for k, v in data.items(): + if k in self.return_keys: + if self.return_keys[k] == 'np': + v = np.array(v) + return_data[k] = v + return return_data + + def __len__(self, ): + if self.load_mode == "all": + return len(self.encoded_inputs_all) + else: + return len(self.all_lines) + + +def get_relation_span(rel, entities): + bound = [] + for entity_index in [rel["head"], rel["tail"]]: + bound.append(entities[entity_index]["start"]) + bound.append(entities[entity_index]["end"]) + return min(bound), max(bound) + + +def reformat(data): + new_data = {} + for item in data: + for k, v in item.items(): + if k not in new_data: + new_data[k] = [] + new_data[k].append(v) + return new_data