diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index 8d117fdeb16e1c0e90bf6ec89924e414fc764249..23fe28f8494ce84e774c3dd21811003f772c41f8 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -1,42 +1,62 @@ -# 视觉问答(VQA) +# 文档视觉问答(DOC-VQA) -VQA主要特性如下: +VQA指视觉问答,主要针对图像内容进行提问和回答,DOC-VQA是VQA任务中的一种,DOC-VQA主要针对文本图像的文字内容提出问题。 + +PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进行开发。 + +主要特性如下: - 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。 -- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取(比如判断问题对) -- 支持SER任务与OCR引擎联合的端到端系统预测与评估。 -- 支持SER任务和RE任务的自定义训练 +- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。 +- 支持SER任务和RE任务的自定义训练。 +- 支持OCR+SER的端到端系统预测与评估。 +- 支持OCR+SER+RE的端到端系统预测。 本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现, 包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。 -## 1. 效果演示 +## 1 性能 + +我们在 [XFUN](https://github.com/doc-analysis/XFUND) 评估数据集上对算法进行了评估,性能如下 + +|任务| f1 | 模型下载地址| +|:---:|:---:| :---:| +|SER|0.9056| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)| +|RE|0.7113| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)| + + + +## 2. 效果演示 **注意:** 测试图片来源于XFUN数据集。 -### 1.1 SER +### 2.1 SER -
- -
+![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg) +---|--- -
- -
+图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别 -其中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别,在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 +* 深紫色:HEADER +* 浅紫色:QUESTION +* 军绿色:ANSWER +在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 -### 1.2 RE -* Coming soon! +### 2.2 RE +![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg) +---|--- -## 2. 安装 +图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 -### 2.1 安装依赖 + +## 3. 安装 + +### 3.1 安装依赖 - **(1) 安装PaddlePaddle** @@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 -### 2.2 安装PaddleOCR(包含 PP-OCR 和 VQA ) +### 3.2 安装PaddleOCR(包含 PP-OCR 和 VQA ) - **(1)pip快速安装PaddleOCR whl包(仅预测)** ```bash -pip install "paddleocr>=2.2" # 推荐使用2.2+版本 +pip install paddleocr ``` - **(2)下载VQA源码(预测+训练)** @@ -85,13 +105,14 @@ pip install -e . - **(4)安装VQA的`requirements`** ```bash +cd ppstructure/vqa pip install -r requirements.txt ``` -## 3. 使用 +## 4. 使用 -### 3.1 数据和预训练模型准备 +### 4.1 数据和预训练模型准备 处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。 @@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar 如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)。 -如果希望直接体验预测过程,可以下载我们提供的SER预训练模型,跳过训练过程,直接预测即可。 - -* SER任务预训练模型下载链接:[链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) -* RE任务预训练模型下载链接:coming soon! +如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。 -### 3.2 SER任务 +### 4.2 SER任务 * 启动训练 ```shell -python train_ser.py \ +python3.7 train_ser.py \ --model_name_or_path "layoutxlm-base-uncased" \ --train_data_dir "XFUND/zh_train/image" \ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ @@ -131,13 +149,7 @@ python train_ser.py \ --seed 2048 ``` -最终会打印出`precision`, `recall`, `f1`等指标,如下所示。 - -``` -best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063} -``` - -模型和训练日志会保存在`./output/ser/`文件夹中。 +最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/ser/`文件夹中。 * 使用评估集合中提供的OCR识别结果进行预测 @@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0 python3.7 infer_ser_e2e.py \ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \ --max_seq_length 512 \ - --output_dir "output_res_e2e/" + --output_dir "output_res_e2e/" \ + --infer_imgs "images/input/zh_val_0.jpg" ``` * 对`OCR引擎 + SER`预测系统进行端到端评估 ```shell export CUDA_VISIBLE_DEVICES=0 -python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt +python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt ``` -3.3 RE任务 +### 3.3 RE任务 -coming soon! +* 启动训练 +```shell +python3 train_re.py \ + --model_name_or_path "layoutxlm-base-uncased" \ + --train_data_dir "XFUND/zh_train/image" \ + --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ + --eval_data_dir "XFUND/zh_val/image" \ + --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ + --label_map_path 'labels/labels_ser.txt' \ + --num_train_epochs 2 \ + --eval_steps 10 \ + --save_steps 500 \ + --output_dir "output/re/" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --per_gpu_train_batch_size 8 \ + --per_gpu_eval_batch_size 8 \ + --evaluate_during_training \ + --seed 2048 + +``` + +最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/re/`文件夹中。 + +* 使用评估集合中提供的OCR识别结果进行预测 + +```shell +export CUDA_VISIBLE_DEVICES=0 +python3 infer_re.py \ + --model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \ + --max_seq_length 512 \ + --eval_data_dir "XFUND/zh_val/image" \ + --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ + --label_map_path 'labels/labels_ser.txt' \ + --output_dir "output_res" \ + --per_gpu_eval_batch_size 1 \ + --seed 2048 +``` + +最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`。 + +* 使用`OCR引擎 + SER + RE`串联结果 + +```shell +export CUDA_VISIBLE_DEVICES=0 +# python3.7 infer_ser_re_e2e.py \ + --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \ + --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \ + --max_seq_length 512 \ + --output_dir "output_ser_re_e2e_train/" \ + --infer_imgs "images/input/zh_val_21.jpg" +``` ## 参考链接 diff --git a/ppstructure/vqa/data_collator.py b/ppstructure/vqa/data_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..a969935b487e3d22ea5c4a3527028aa2cfe1a797 --- /dev/null +++ b/ppstructure/vqa/data_collator.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numbers +import numpy as np + + +class DataCollator: + """ + data batch + """ + + def __call__(self, batch): + data_dict = {} + to_tensor_keys = [] + for sample in batch: + for k, v in sample.items(): + if k not in data_dict: + data_dict[k] = [] + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if k not in to_tensor_keys: + to_tensor_keys.append(k) + data_dict[k].append(v) + for k in to_tensor_keys: + data_dict[k] = paddle.to_tensor(data_dict[k]) + return data_dict diff --git a/ppstructure/vqa/images/input/zh_val_21.jpg b/ppstructure/vqa/images/input/zh_val_21.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35b572d7dd6a6b42cf43a8a4b33567c0af527d30 Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_21.jpg differ diff --git a/ppstructure/vqa/images/input/zh_val_40.jpg b/ppstructure/vqa/images/input/zh_val_40.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2a858cc33d54831335c209146853b6c302c734f8 Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_40.jpg differ diff --git a/ppstructure/vqa/images/result_re/zh_val_21_re.jpg b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7bf248dd0e69057c4775ff9c205317044e94ee65 Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg differ diff --git a/ppstructure/vqa/images/result_re/zh_val_40_re.jpg b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg new file mode 100644 index 0000000000000000000000000000000000000000..242f9d6e80be39c595d98b57d59d48673ce62f20 Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg differ diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg index 22ba9a6f1b7652ca9ce6848093c7a39affb4886b..4605c3a7f395e9868ba55cd31a99367694c78f5c 100644 Binary files a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg index 951864e5f35a987ff241f276c8da523d8c8eeaf3..13bc7272e49a03115085d4a7420a7acfb92d3260 100644 Binary files a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2f52550294b072179c3bdba28c3572369e11a3 --- /dev/null +++ b/ppstructure/vqa/infer_re.py @@ -0,0 +1,162 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import random + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import paddle + +from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction + +from xfun import XFUNDataset +from utils import parse_args, get_bio_label_maps, draw_re_results +from data_collator import DataCollator + +from ppocr.utils.logging import get_logger + + +def infer(args): + os.makedirs(args.output_dir, exist_ok=True) + logger = get_logger() + label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) + pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + + tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) + + model = LayoutXLMForRelationExtraction.from_pretrained( + args.model_name_or_path) + + eval_dataset = XFUNDataset( + tokenizer, + data_dir=args.eval_data_dir, + label_path=args.eval_label_path, + label2id_map=label2id_map, + img_size=(224, 224), + max_seq_len=args.max_seq_length, + pad_token_label_id=pad_token_label_id, + contains_re=True, + add_special_ids=False, + return_attention_mask=True, + load_mode='all') + + eval_dataloader = paddle.io.DataLoader( + eval_dataset, + batch_size=args.per_gpu_eval_batch_size, + num_workers=8, + shuffle=False, + collate_fn=DataCollator()) + + # 读取gt的oct数据 + ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path) + + for idx, batch in enumerate(eval_dataloader): + logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader))) + with paddle.no_grad(): + outputs = model(**batch) + pred_relations = outputs['pred_relations'] + + ocr_info = ocr_info_list[idx] + image_path = ocr_info['image_path'] + ocr_info = ocr_info['ocr_info'] + + # 根据entity里的信息,做token解码后去过滤不要的ocr_info + ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer) + + # 进行 relations 到 ocr信息的转换 + result = [] + used_tail_id = [] + for relations in pred_relations: + for relation in relations: + if relation['tail_id'] in used_tail_id: + continue + if relation['head_id'] not in ocr_info or relation[ + 'tail_id'] not in ocr_info: + continue + used_tail_id.append(relation['tail_id']) + ocr_info_head = ocr_info[relation['head_id']] + ocr_info_tail = ocr_info[relation['tail_id']] + result.append((ocr_info_head, ocr_info_tail)) + + img = cv2.imread(image_path) + img_show = draw_re_results(img, result) + save_path = os.path.join(args.output_dir, os.path.basename(image_path)) + cv2.imwrite(save_path, img_show) + + +def load_ocr(img_folder, json_path): + import json + d = [] + with open(json_path, "r") as fin: + lines = fin.readlines() + for line in lines: + image_name, info_str = line.split("\t") + info_dict = json.loads(info_str) + info_dict['image_path'] = os.path.join(img_folder, image_name) + d.append(info_dict) + return d + + +def filter_bg_by_txt(ocr_info, batch, tokenizer): + entities = batch['entities'][0] + input_ids = batch['input_ids'][0] + + new_info_dict = {} + for i in range(len(entities['start'])): + entitie_head = entities['start'][i] + entitie_tail = entities['end'][i] + word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist() + txt = tokenizer.convert_ids_to_tokens(word_input_ids) + txt = tokenizer.convert_tokens_to_string(txt) + + for i, info in enumerate(ocr_info): + if info['text'] == txt: + new_info_dict[i] = info + return new_info_dict + + +def post_process(pred_relations, ocr_info, img): + result = [] + for relations in pred_relations: + for relation in relations: + ocr_info_head = ocr_info[relation['head_id']] + ocr_info_tail = ocr_info[relation['tail_id']] + result.append((ocr_info_head, ocr_info_tail)) + return result + + +def draw_re(result, image_path, output_folder): + img = cv2.imread(image_path) + + from matplotlib import pyplot as plt + for ocr_info_head, ocr_info_tail in result: + cv2.rectangle( + img, + tuple(ocr_info_head['bbox'][:2]), + tuple(ocr_info_head['bbox'][2:]), (255, 0, 0), + thickness=2) + cv2.rectangle( + img, + tuple(ocr_info_tail['bbox'][:2]), + tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255), + thickness=2) + center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, + (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2] + center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, + (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2] + cv2.line( + img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2) + plt.imshow(img) + plt.savefig( + os.path.join(output_folder, os.path.basename(image_path)), dpi=600) + # plt.show() + + +if __name__ == "__main__": + args = parse_args() + infer(args) diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py index da027a140bdb4fa12a40d423998d94e438a7cd11..1638e78a11105feb1cb037a545005b2384672eb8 100644 --- a/ppstructure/vqa/infer_ser_e2e.py +++ b/ppstructure/vqa/infer_ser_e2e.py @@ -23,8 +23,10 @@ from PIL import Image import paddle from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification +from paddleocr import PaddleOCR + # relative reference -from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine +from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info @@ -48,74 +50,82 @@ def parse_ocr_info_for_ser(ocr_result): return ocr_info -@paddle.no_grad() -def infer(args): - os.makedirs(args.output_dir, exist_ok=True) +class SerPredictor(object): + def __init__(self, args): + self.max_seq_length = args.max_seq_length + + # init ser token and model + self.tokenizer = LayoutXLMTokenizer.from_pretrained( + args.model_name_or_path) + self.model = LayoutXLMForTokenClassification.from_pretrained( + args.model_name_or_path) + self.model.eval() + + # init ocr_engine + self.ocr_engine = PaddleOCR( + rec_model_dir=args.ocr_rec_model_dir, + det_model_dir=args.ocr_det_model_dir, + use_angle_cls=False, + show_log=False) + # init dict + label2id_map, self.id2label_map = get_bio_label_maps( + args.label_map_path) + self.label2id_map_for_draw = dict() + for key in label2id_map: + if key.startswith("I-"): + self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]] + else: + self.label2id_map_for_draw[key] = label2id_map[key] + + def __call__(self, img): + ocr_result = self.ocr_engine.ocr(img, cls=False) + + ocr_info = parse_ocr_info_for_ser(ocr_result) + + inputs = preprocess( + tokenizer=self.tokenizer, + ori_img=img, + ocr_info=ocr_info, + max_seq_len=self.max_seq_length) + + outputs = self.model( + input_ids=inputs["input_ids"], + bbox=inputs["bbox"], + image=inputs["image"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"]) + + preds = outputs[0] + preds = postprocess(inputs["attention_mask"], preds, self.id2label_map) + ocr_info = merge_preds_list_with_ocr_info( + ocr_info, inputs["segment_offset_id"], preds, + self.label2id_map_for_draw) + return ocr_info, inputs - # init token and model - tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) - model = LayoutXLMForTokenClassification.from_pretrained( - args.model_name_or_path) - model.eval() - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - label2id_map_for_draw = dict() - for key in label2id_map: - if key.startswith("I-"): - label2id_map_for_draw[key] = label2id_map["B" + key[1:]] - else: - label2id_map_for_draw[key] = label2id_map[key] +if __name__ == "__main__": + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) # get infer img list infer_imgs = get_image_file_list(args.infer_imgs) - ocr_engine = build_ocr_engine(args.ocr_rec_model_dir, - args.ocr_det_model_dir) - # loop for infer + ser_engine = SerPredictor(args) with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: for idx, img_path in enumerate(infer_imgs): - print("process: [{}/{}]".format(idx, len(infer_imgs), img_path)) + print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path)) img = cv2.imread(img_path) - ocr_result = ocr_engine.ocr(img_path, cls=False) - - ocr_info = parse_ocr_info_for_ser(ocr_result) - - inputs = preprocess( - tokenizer=tokenizer, - ori_img=img, - ocr_info=ocr_info, - max_seq_len=args.max_seq_length) - - outputs = model( - input_ids=inputs["input_ids"], - bbox=inputs["bbox"], - image=inputs["image"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"]) - - preds = outputs[0] - preds = postprocess(inputs["attention_mask"], preds, id2label_map) - ocr_info = merge_preds_list_with_ocr_info( - ocr_info, inputs["segment_offset_id"], preds, - label2id_map_for_draw) - + result, _ = ser_engine(img) fout.write(img_path + "\t" + json.dumps( { - "ocr_info": ocr_info, + "ser_resule": result, }, ensure_ascii=False) + "\n") - img_res = draw_ser_results(img, ocr_info) + img_res = draw_ser_results(img, result) cv2.imwrite( os.path.join(args.output_dir, os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg"), img_res) - - return - - -if __name__ == "__main__": - args = parse_args() - infer(args) diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d0f52eeecbc6c2ceba5964355008f638f371dd --- /dev/null +++ b/ppstructure/vqa/infer_ser_re_e2e.py @@ -0,0 +1,131 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import json +import cv2 +import numpy as np +from copy import deepcopy +from PIL import Image + +import paddle +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction + +# relative reference +from utils import parse_args, get_image_file_list, draw_re_results +from infer_ser_e2e import SerPredictor + + +def make_input(ser_input, ser_result, max_seq_len=512): + entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} + + entities = ser_input['entities'][0] + assert len(entities) == len(ser_result) + + # entities + start = [] + end = [] + label = [] + entity_idx_dict = {} + for i, (res, entity) in enumerate(zip(ser_result, entities)): + if res['pred'] == 'O': + continue + entity_idx_dict[len(start)] = i + start.append(entity['start']) + end.append(entity['end']) + label.append(entities_labels[res['pred']]) + entities = dict(start=start, end=end, label=label) + + # relations + head = [] + tail = [] + for i in range(len(entities["label"])): + for j in range(len(entities["label"])): + if entities["label"][i] == 1 and entities["label"][j] == 2: + head.append(i) + tail.append(j) + + relations = dict(head=head, tail=tail) + + batch_size = ser_input["input_ids"].shape[0] + entities_batch = [] + relations_batch = [] + for b in range(batch_size): + entities_batch.append(entities) + relations_batch.append(relations) + + ser_input['entities'] = entities_batch + ser_input['relations'] = relations_batch + + ser_input.pop('segment_offset_id') + return ser_input, entity_idx_dict + + +class SerReSystem(object): + def __init__(self, args): + self.ser_engine = SerPredictor(args) + self.tokenizer = LayoutXLMTokenizer.from_pretrained( + args.re_model_name_or_path) + self.model = LayoutXLMForRelationExtraction.from_pretrained( + args.re_model_name_or_path) + self.model.eval() + + def __call__(self, img): + ser_result, ser_inputs = self.ser_engine(img) + re_input, entity_idx_dict = make_input(ser_inputs, ser_result) + + re_result = self.model(**re_input) + + pred_relations = re_result['pred_relations'][0] + # 进行 relations 到 ocr信息的转换 + result = [] + used_tail_id = [] + for relation in pred_relations: + if relation['tail_id'] in used_tail_id: + continue + used_tail_id.append(relation['tail_id']) + ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]] + ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]] + result.append((ocr_info_head, ocr_info_tail)) + + return result + + +if __name__ == "__main__": + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + + # get infer img list + infer_imgs = get_image_file_list(args.infer_imgs) + + # loop for infer + ser_re_engine = SerReSystem(args) + with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: + for idx, img_path in enumerate(infer_imgs): + print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path)) + + img = cv2.imread(img_path) + + result = ser_re_engine(img) + fout.write(img_path + "\t" + json.dumps( + { + "result": result, + }, ensure_ascii=False) + "\n") + + img_res = draw_re_results(img, result) + cv2.imwrite( + os.path.join(args.output_dir, + os.path.splitext(os.path.basename(img_path))[0] + + "_re.jpg"), img_res) diff --git a/ppstructure/vqa/metric.py b/ppstructure/vqa/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58370521296886670486982caf1202cf99a489 --- /dev/null +++ b/ppstructure/vqa/metric.py @@ -0,0 +1,175 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +import numpy as np + +import logging + +logger = logging.getLogger(__name__) + +PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [ + path for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir( + os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join( + folder, + max(checkpoints, + key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) + + +def re_score(pred_relations, gt_relations, mode="strict"): + """Evaluate RE predictions + + Args: + pred_relations (list) : list of list of predicted relations (several relations in each sentence) + gt_relations (list) : list of list of ground truth relations + + rel = { "head": (start_idx (inclusive), end_idx (exclusive)), + "tail": (start_idx (inclusive), end_idx (exclusive)), + "head_type": ent_type, + "tail_type": ent_type, + "type": rel_type} + + vocab (Vocab) : dataset vocabulary + mode (str) : in 'strict' or 'boundaries'""" + + assert mode in ["strict", "boundaries"] + + relation_types = [v for v in [0, 1] if not v == 0] + scores = { + rel: { + "tp": 0, + "fp": 0, + "fn": 0 + } + for rel in relation_types + ["ALL"] + } + + # Count GT relations and Predicted relations + n_sents = len(gt_relations) + n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) + n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) + + # Count TP, FP and FN per type + for pred_sent, gt_sent in zip(pred_relations, gt_relations): + for rel_type in relation_types: + # strict mode takes argument types into account + if mode == "strict": + pred_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in pred_sent if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in gt_sent if rel["type"] == rel_type} + + # boundaries mode only takes argument spans into account + elif mode == "boundaries": + pred_rels = {(rel["head"], rel["tail"]) + for rel in pred_sent if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["tail"]) + for rel in gt_sent if rel["type"] == rel_type} + + scores[rel_type]["tp"] += len(pred_rels & gt_rels) + scores[rel_type]["fp"] += len(pred_rels - gt_rels) + scores[rel_type]["fn"] += len(gt_rels - pred_rels) + + # Compute per entity Precision / Recall / F1 + for rel_type in scores.keys(): + if scores[rel_type]["tp"]: + scores[rel_type]["p"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fp"] + scores[rel_type]["tp"]) + scores[rel_type]["r"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fn"] + scores[rel_type]["tp"]) + else: + scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 + + if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: + scores[rel_type]["f1"] = ( + 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / + (scores[rel_type]["p"] + scores[rel_type]["r"])) + else: + scores[rel_type]["f1"] = 0 + + # Compute micro F1 Scores + tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) + fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) + fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) + + if tp: + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * precision * recall / (precision + recall) + + else: + precision, recall, f1 = 0, 0, 0 + + scores["ALL"]["p"] = precision + scores["ALL"]["r"] = recall + scores["ALL"]["f1"] = f1 + scores["ALL"]["tp"] = tp + scores["ALL"]["fp"] = fp + scores["ALL"]["fn"] = fn + + # Compute Macro F1 Scores + scores["ALL"]["Macro_f1"] = np.mean( + [scores[ent_type]["f1"] for ent_type in relation_types]) + scores["ALL"]["Macro_p"] = np.mean( + [scores[ent_type]["p"] for ent_type in relation_types]) + scores["ALL"]["Macro_r"] = np.mean( + [scores[ent_type]["r"] for ent_type in relation_types]) + + # logger.info(f"RE Evaluation in *** {mode.upper()} *** mode") + + # logger.info( + # "processed {} sentences with {} relations; found: {} relations; correct: {}.".format( + # n_sents, n_rels, n_found, tp + # ) + # ) + # logger.info( + # "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"]) + # ) + # logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1)) + # logger.info( + # "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format( + # scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"] + # ) + # ) + + # for rel_type in relation_types: + # logger.info( + # "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format( + # rel_type, + # scores[rel_type]["tp"], + # scores[rel_type]["fp"], + # scores[rel_type]["fn"], + # scores[rel_type]["p"], + # scores[rel_type]["r"], + # scores[rel_type]["f1"], + # scores[rel_type]["tp"] + scores[rel_type]["fp"], + # ) + # ) + + return scores diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py new file mode 100644 index 0000000000000000000000000000000000000000..ed19646cf57e69ac99e417ae27568655a4e00039 --- /dev/null +++ b/ppstructure/vqa/train_re.py @@ -0,0 +1,261 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import random +import numpy as np +import paddle + +from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction + +from xfun import XFUNDataset +from utils import parse_args, get_bio_label_maps, print_arguments +from data_collator import DataCollator +from metric import re_score + +from ppocr.utils.logging import get_logger + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + +def cal_metric(re_preds, re_labels, entities): + gt_relations = [] + for b in range(len(re_labels)): + rel_sent = [] + for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]): + rel = {} + rel["head_id"] = head + rel["head"] = (entities[b]["start"][rel["head_id"]], + entities[b]["end"][rel["head_id"]]) + rel["head_type"] = entities[b]["label"][rel["head_id"]] + + rel["tail_id"] = tail + rel["tail"] = (entities[b]["start"][rel["tail_id"]], + entities[b]["end"][rel["tail_id"]]) + rel["tail_type"] = entities[b]["label"][rel["tail_id"]] + + rel["type"] = 1 + rel_sent.append(rel) + gt_relations.append(rel_sent) + re_metrics = re_score(re_preds, gt_relations, mode="boundaries") + return re_metrics + + +def evaluate(model, eval_dataloader, logger, prefix=""): + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = {}".format(len(eval_dataloader.dataset))) + + re_preds = [] + re_labels = [] + entities = [] + eval_loss = 0.0 + model.eval() + for idx, batch in enumerate(eval_dataloader): + with paddle.no_grad(): + outputs = model(**batch) + loss = outputs['loss'].mean().item() + if paddle.distributed.get_rank() == 0: + logger.info("[Eval] process: {}/{}, loss: {:.5f}".format( + idx, len(eval_dataloader), loss)) + + eval_loss += loss + re_preds.extend(outputs['pred_relations']) + re_labels.extend(batch['relations']) + entities.extend(batch['entities']) + re_metrics = cal_metric(re_preds, re_labels, entities) + re_metrics = { + "precision": re_metrics["ALL"]["p"], + "recall": re_metrics["ALL"]["r"], + "f1": re_metrics["ALL"]["f1"], + } + model.train() + return re_metrics + + +def train(args): + logger = get_logger(log_file=os.path.join(args.output_dir, "train.log")) + print_arguments(args, logger) + + # Added here for reproducibility (even between python 2 and 3) + set_seed(args.seed) + + label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) + pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + + # dist mode + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) + + model = LayoutXLMModel.from_pretrained(args.model_name_or_path) + model = LayoutXLMForRelationExtraction(model, dropout=None) + + # dist mode + if paddle.distributed.get_world_size() > 1: + model = paddle.distributed.DataParallel(model) + + train_dataset = XFUNDataset( + tokenizer, + data_dir=args.train_data_dir, + label_path=args.train_label_path, + label2id_map=label2id_map, + img_size=(224, 224), + max_seq_len=args.max_seq_length, + pad_token_label_id=pad_token_label_id, + contains_re=True, + add_special_ids=False, + return_attention_mask=True, + load_mode='all') + + eval_dataset = XFUNDataset( + tokenizer, + data_dir=args.eval_data_dir, + label_path=args.eval_label_path, + label2id_map=label2id_map, + img_size=(224, 224), + max_seq_len=args.max_seq_length, + pad_token_label_id=pad_token_label_id, + contains_re=True, + add_special_ids=False, + return_attention_mask=True, + load_mode='all') + + train_sampler = paddle.io.DistributedBatchSampler( + train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True) + args.train_batch_size = args.per_gpu_train_batch_size * \ + max(1, paddle.distributed.get_world_size()) + train_dataloader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=8, + use_shared_memory=True, + collate_fn=DataCollator()) + + eval_dataloader = paddle.io.DataLoader( + eval_dataset, + batch_size=args.per_gpu_eval_batch_size, + num_workers=8, + shuffle=False, + collate_fn=DataCollator()) + + t_total = len(train_dataloader) * args.num_train_epochs + + # build linear decay with warmup lr sch + lr_scheduler = paddle.optimizer.lr.PolynomialDecay( + learning_rate=args.learning_rate, + decay_steps=t_total, + end_lr=0.0, + power=1.0) + if args.warmup_steps > 0: + lr_scheduler = paddle.optimizer.lr.LinearWarmup( + lr_scheduler, + args.warmup_steps, + start_lr=0, + end_lr=args.learning_rate, ) + grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10) + optimizer = paddle.optimizer.Adam( + learning_rate=args.learning_rate, + parameters=model.parameters(), + epsilon=args.adam_epsilon, + grad_clip=grad_clip, + weight_decay=args.weight_decay) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = {}".format(len(train_dataset))) + logger.info(" Num Epochs = {}".format(args.num_train_epochs)) + logger.info(" Instantaneous batch size per GPU = {}".format( + args.per_gpu_train_batch_size)) + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = {}". + format(args.train_batch_size * paddle.distributed.get_world_size())) + logger.info(" Total optimization steps = {}".format(t_total)) + + global_step = 0 + model.clear_gradients() + train_dataloader_len = len(train_dataloader) + best_metirc = {'f1': 0} + model.train() + + for epoch in range(int(args.num_train_epochs)): + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + # model outputs are always tuple in ppnlp (see doc) + loss = outputs['loss'] + loss = loss.mean() + + logger.info( + "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}". + format(epoch, args.num_train_epochs, step, train_dataloader_len, + global_step, np.mean(loss.numpy()), optimizer.get_lr())) + + loss.backward() + optimizer.step() + optimizer.clear_grad() + # lr_scheduler.step() # Update learning rate schedule + + global_step += 1 + + if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and + global_step % args.eval_steps == 0): + # Log metrics + if (paddle.distributed.get_rank() == 0 and args. + evaluate_during_training): # Only evaluate when single GPU otherwise metrics may not average well + results = evaluate(model, eval_dataloader, logger) + if results['f1'] > best_metirc['f1']: + best_metirc = results + output_dir = os.path.join(args.output_dir, + "checkpoint-best") + os.makedirs(output_dir, exist_ok=True) + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save(args, + os.path.join(output_dir, + "training_args.bin")) + logger.info("Saving model checkpoint to {}".format( + output_dir)) + logger.info("eval results: {}".format(results)) + logger.info("best_metirc: {}".format(best_metirc)) + + if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and + global_step % args.save_steps == 0): + # Save model checkpoint + output_dir = os.path.join(args.output_dir, "checkpoint-latest") + os.makedirs(output_dir, exist_ok=True) + if paddle.distributed.get_rank() == 0: + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save(args, + os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to {}".format( + output_dir)) + logger.info("best_metirc: {}".format(best_metirc)) + + +if __name__ == "__main__": + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + train(args) diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py index 90ca69d93fd22983533fcacd639bbd64dc3e11ec..d3144e7167c59b5883047a948abaedfd21ba9b1c 100644 --- a/ppstructure/vqa/train_ser.py +++ b/ppstructure/vqa/train_ser.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + import random import copy import logging @@ -26,8 +31,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM from xfun import XFUNDataset from utils import parse_args from utils import get_bio_label_maps +from utils import print_arguments -logger = logging.getLogger(__name__) +from ppocr.utils.logging import get_logger def set_seed(args): @@ -38,17 +44,8 @@ def set_seed(args): def train(args): os.makedirs(args.output_dir, exist_ok=True) - logging.basicConfig( - filename=os.path.join(args.output_dir, "train.log") - if paddle.distributed.get_rank() == 0 else None, - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO - if paddle.distributed.get_rank() == 0 else logging.WARN, ) - - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - logger.addHandler(ch) + logger = get_logger(log_file=os.path.join(args.output_dir, "train.log")) + print_arguments(args, logger) label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index @@ -136,10 +133,10 @@ def train(args): loss = outputs[0] loss = loss.mean() logger.info( - "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ". + "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}". format(epoch_id, args.num_train_epochs, step, - len(train_dataloader), - lr_scheduler.get_lr(), loss.numpy()[0])) + len(train_dataloader), global_step, + loss.numpy()[0], lr_scheduler.get_lr())) loss.backward() tr_loss += loss.item() @@ -154,13 +151,9 @@ def train(args): # Only evaluate when single GPU otherwise metrics may not average well if paddle.distributed.get_rank( ) == 0 and args.evaluate_during_training: - results, _ = evaluate( - args, - model, - tokenizer, - label2id_map, - id2label_map, - pad_token_label_id, ) + results, _ = evaluate(args, model, tokenizer, label2id_map, + id2label_map, pad_token_label_id, + logger) if best_metrics is None or results["f1"] >= best_metrics[ "f1"]: @@ -204,6 +197,7 @@ def evaluate(args, label2id_map, id2label_map, pad_token_label_id, + logger, prefix=""): eval_dataset = XFUNDataset( tokenizer, @@ -299,15 +293,6 @@ def evaluate(args, return results, preds_list -def print_arguments(args): - """print arguments""" - print('----------- Configuration Arguments -----------') - for arg, value in sorted(vars(args).items()): - print('%s: %s' % (arg, value)) - print('------------------------------------------------') - - if __name__ == "__main__": args = parse_args() - print_arguments(args) train(args) diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py index a4ac1e77d37d0a662294480a393c2f67e7f4cc64..0af180ada2eae740c042378c73b884239ddbf7b9 100644 --- a/ppstructure/vqa/utils.py +++ b/ppstructure/vqa/utils.py @@ -24,8 +24,6 @@ import paddle from PIL import Image, ImageDraw, ImageFont -from paddleocr import PaddleOCR - def get_bio_label_maps(label_map_path): with open(label_map_path, "r") as fin: @@ -66,9 +64,9 @@ def get_image_file_list(img_file): def draw_ser_results(image, ocr_results, - font_path="../doc/fonts/simfang.ttf", + font_path="../../doc/fonts/simfang.ttf", font_size=18): - np.random.seed(0) + np.random.seed(2021) color = (np.random.permutation(range(255)), np.random.permutation(range(255)), np.random.permutation(range(255))) @@ -82,38 +80,64 @@ def draw_ser_results(image, draw = ImageDraw.Draw(img_new) font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - for ocr_info in ocr_results: if ocr_info["pred_id"] not in color_map: continue color = color_map[ocr_info["pred_id"]] - - # draw ocr results outline - bbox = ocr_info["bbox"] - bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) - draw.rectangle(bbox, fill=color) - - # draw ocr results text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) - start_y = max(0, bbox[0][1] - font_size) - tw = font.getsize(text)[0] - draw.rectangle( - [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, - start_y + font_size)], - fill=(0, 0, 255)) - draw.text( - (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) + + draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) img_new = Image.blend(image, img_new, 0.5) return np.array(img_new) -def build_ocr_engine(rec_model_dir, det_model_dir): - ocr_engine = PaddleOCR( - rec_model_dir=rec_model_dir, - det_model_dir=det_model_dir, - use_angle_cls=False) - return ocr_engine +def draw_box_txt(bbox, text, draw, font, font_size, color): + # draw ocr results outline + bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) + draw.rectangle(bbox, fill=color) + + # draw ocr results + start_y = max(0, bbox[0][1] - font_size) + tw = font.getsize(text)[0] + draw.rectangle( + [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)], + fill=(0, 0, 255)) + draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) + + +def draw_re_results(image, + result, + font_path="../../doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(0) + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + color_head = (0, 0, 255) + color_tail = (255, 0, 0) + color_line = (0, 255, 0) + + for ocr_info_head, ocr_info_tail in result: + draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, + font_size, color_head) + draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, + font_size, color_tail) + + center_head = ( + (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, + (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2) + center_tail = ( + (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, + (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2) + + draw.line([center_head, center_tail], fill=color_line, width=5) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) # pad sentences @@ -130,7 +154,7 @@ def pad_sentences(tokenizer, len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len needs_to_be_padded = pad_to_max_seq_len and \ - max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len + max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len if needs_to_be_padded: difference = max_seq_len - len(encoded_inputs["input_ids"]) @@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512): truncate is often used in training process """ for key in encoded_inputs: + if key == 'entities': + encoded_inputs[key] = [encoded_inputs[key]] + continue encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) @@ -184,14 +211,14 @@ def preprocess( height = ori_img.shape[0] width = ori_img.shape[1] - img = cv2.resize(ori_img, - (224, 224)).transpose([2, 0, 1]).astype(np.float32) + img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32) segment_offset_id = [] words_list = [] bbox_list = [] input_ids_list = [] token_type_ids_list = [] + entities = [] for info in ocr_info: # x1, y1, x2, y2 @@ -211,6 +238,13 @@ def preprocess( encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] + # for re + entities.append({ + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": "O", + }) + input_ids_list.extend(encode_res["input_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"]) bbox_list.extend([bbox] * len(encode_res["input_ids"])) @@ -222,6 +256,7 @@ def preprocess( "token_type_ids": token_type_ids_list, "bbox": bbox_list, "attention_mask": [1] * len(input_ids_list), + "entities": entities } encoded_inputs = pad_sentences( @@ -294,35 +329,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list, return ocr_info +def print_arguments(args, logger=None): + print_func = logger.info if logger is not None else print + """print arguments""" + print_func('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print_func('%s: %s' % (arg, value)) + print_func('------------------------------------------------') + + def parse_args(): parser = argparse.ArgumentParser() # Required parameters # yapf: disable - parser.add_argument("--model_name_or_path", default=None, type=str, required=True,) - parser.add_argument("--train_data_dir", default=None, type=str, required=False,) - parser.add_argument("--train_label_path", default=None, type=str, required=False,) - parser.add_argument("--eval_data_dir", default=None, type=str, required=False,) - parser.add_argument("--eval_label_path", default=None, type=str, required=False,) + parser.add_argument("--model_name_or_path", + default=None, type=str, required=True,) + parser.add_argument("--re_model_name_or_path", + default=None, type=str, required=False,) + parser.add_argument("--train_data_dir", default=None, + type=str, required=False,) + parser.add_argument("--train_label_path", default=None, + type=str, required=False,) + parser.add_argument("--eval_data_dir", default=None, + type=str, required=False,) + parser.add_argument("--eval_label_path", default=None, + type=str, required=False,) parser.add_argument("--output_dir", default=None, type=str, required=True,) parser.add_argument("--max_seq_length", default=512, type=int,) parser.add_argument("--evaluate_during_training", action="store_true",) - parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",) - parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",) - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",) - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",) - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",) - parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",) - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",) - parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",) - parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",) - parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",) + parser.add_argument("--per_gpu_train_batch_size", default=8, + type=int, help="Batch size per GPU/CPU for training.",) + parser.add_argument("--per_gpu_eval_batch_size", default=8, + type=int, help="Batch size per GPU/CPU for eval.",) + parser.add_argument("--learning_rate", default=5e-5, + type=float, help="The initial learning rate for Adam.",) + parser.add_argument("--weight_decay", default=0.0, + type=float, help="Weight decay if we apply some.",) + parser.add_argument("--adam_epsilon", default=1e-8, + type=float, help="Epsilon for Adam optimizer.",) + parser.add_argument("--max_grad_norm", default=1.0, + type=float, help="Max gradient norm.",) + parser.add_argument("--num_train_epochs", default=3, type=int, + help="Total number of training epochs to perform.",) + parser.add_argument("--warmup_steps", default=0, type=int, + help="Linear warmup over warmup_steps.",) + parser.add_argument("--eval_steps", type=int, default=10, + help="eval every X updates steps.",) + parser.add_argument("--save_steps", type=int, default=50, + help="Save checkpoint every X updates steps.",) + parser.add_argument("--seed", type=int, default=2048, + help="random seed for initialization",) parser.add_argument("--ocr_rec_model_dir", default=None, type=str, ) parser.add_argument("--ocr_det_model_dir", default=None, type=str, ) - parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, ) + parser.add_argument( + "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, ) parser.add_argument("--infer_imgs", default=None, type=str, required=False) - parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results") + parser.add_argument("--ocr_json_path", default=None, + type=str, required=False, help="ocr prediction results") # yapf: enable args = parser.parse_args() return args