diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index 51a616c4399695bfb598ea2ef4524683592b687d..975139c7937605972dfb7051918b74c66f480555 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -195,7 +195,7 @@ export CUDA_VISIBLE_DEVICES=0 python3.7 infer_ser.py \ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \ --ser_model_type "LayoutXLM" \ - --output_dir "output_res/" \ + --output_dir "output/ser/" \ --infer_imgs "XFUND/zh_val/image/" \ --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json" ``` @@ -210,7 +210,7 @@ python3.7 infer_ser_e2e.py \ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \ --ser_model_type "LayoutXLM" \ --max_seq_length 512 \ - --output_dir "output_res_e2e/" \ + --output_dir "output/ser_e2e/" \ --infer_imgs "images/input/zh_val_0.jpg" ``` @@ -284,7 +284,7 @@ python3 eval_re.py \ --eval_data_dir "XFUND/zh_val/image" \ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ --label_map_path 'labels/labels_ser.txt' \ - --output_dir "output/re_test/" \ + --output_dir "output/re/" \ --per_gpu_eval_batch_size 8 \ --num_workers 8 \ --seed 2048 @@ -302,7 +302,7 @@ python3 infer_re.py \ --eval_data_dir "XFUND/zh_val/image" \ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ --label_map_path 'labels/labels_ser.txt' \ - --output_dir "output_res" \ + --output_dir "output/re/" \ --per_gpu_eval_batch_size 1 \ --seed 2048 ``` @@ -317,7 +317,7 @@ python3.7 infer_ser_re_e2e.py \ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \ --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \ --max_seq_length 512 \ - --output_dir "output_ser_re_e2e_train/" \ + --output_dir "output/ser_re_e2e/" \ --infer_imgs "images/input/zh_val_21.jpg" ``` diff --git a/ppstructure/vqa/infer.sh b/ppstructure/vqa/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..2cd1cea4476672732b3a7f9ad97a3e42172dbb92 --- /dev/null +++ b/ppstructure/vqa/infer.sh @@ -0,0 +1,61 @@ +export CUDA_VISIBLE_DEVICES=6 +# python3.7 infer_ser_e2e.py \ +# --model_name_or_path "output/ser_distributed/best_model" \ +# --max_seq_length 512 \ +# --output_dir "output_res_e2e/" \ +# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg" + + +# python3.7 infer_ser_re_e2e.py \ +# --model_name_or_path "output/ser_distributed/best_model" \ +# --re_model_name_or_path "output/re_test/best_model" \ +# --max_seq_length 512 \ +# --output_dir "output_ser_re_e2e_train/" \ +# --infer_imgs "images/input/zh_val_21.jpg" + +# python3.7 infer_ser.py \ +# --model_name_or_path "output/ser_LayoutLM/best_model" \ +# --ser_model_type "LayoutLM" \ +# --output_dir "ser_LayoutLM/" \ +# --infer_imgs "images/input/zh_val_21.jpg" \ +# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" + +python3.7 infer_ser.py \ + --model_name_or_path "output/ser_new/best_model" \ + --ser_model_type "LayoutXLM" \ + --output_dir "ser_new/" \ + --infer_imgs "images/input/zh_val_21.jpg" \ + --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" + +# python3.7 infer_ser_e2e.py \ +# --model_name_or_path "output/ser_new/best_model" \ +# --ser_model_type "LayoutXLM" \ +# --max_seq_length 512 \ +# --output_dir "output/ser_new/" \ +# --infer_imgs "images/input/zh_val_0.jpg" + + +# python3.7 infer_ser_e2e.py \ +# --model_name_or_path "output/ser_LayoutLM/best_model" \ +# --ser_model_type "LayoutLM" \ +# --max_seq_length 512 \ +# --output_dir "output/ser_LayoutLM/" \ +# --infer_imgs "images/input/zh_val_0.jpg" + +# python3 infer_re.py \ +# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \ +# --max_seq_length 512 \ +# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \ +# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \ +# --label_map_path 'labels/labels_ser.txt' \ +# --output_dir "output_res" \ +# --per_gpu_eval_batch_size 1 \ +# --seed 2048 + +# python3.7 infer_ser_re_e2e.py \ +# --model_name_or_path "output/ser_LayoutLM/best_model" \ +# --ser_model_type "LayoutLM" \ +# --re_model_name_or_path "output/re_new/best_model" \ +# --max_seq_length 512 \ +# --output_dir "output_ser_re_e2e/" \ +# --infer_imgs "images/input/zh_val_21.jpg" \ No newline at end of file diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py index 2ffa60f5d62c9e73e66861fd78701f164b55a9e5..98c61bacceb9e8a6e5f6da55cc7719b46ab308d0 100644 --- a/ppstructure/vqa/infer_re.py +++ b/ppstructure/vqa/infer_re.py @@ -56,19 +56,19 @@ def infer(args): ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path) for idx, batch in enumerate(eval_dataloader): + ocr_info = ocr_info_list[idx] + image_path = ocr_info['image_path'] + ocr_info = ocr_info['ocr_info'] + save_img_path = os.path.join( args.output_dir, - os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg") + os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg") logger.info("[Infer] process: {}/{}, save_result to {}".format( idx, len(eval_dataloader), save_img_path)) with paddle.no_grad(): outputs = model(**batch) pred_relations = outputs['pred_relations'] - ocr_info = ocr_info_list[idx] - image_path = ocr_info['image_path'] - ocr_info = ocr_info['ocr_info'] - # 根据entity里的信息,做token解码后去过滤不要的ocr_info ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer) diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py index 05a029822b580c204e95d85c3918c7786b05a4cf..bceb3434b5299b190568cb0368243d19f555c630 100644 --- a/ppstructure/vqa/infer_ser_e2e.py +++ b/ppstructure/vqa/infer_ser_e2e.py @@ -98,13 +98,13 @@ class SerPredictor(object): ocr_info=ocr_info, max_seq_len=self.max_seq_length) - if args.ser_model_type == 'LayoutLM': + if self.args.ser_model_type == 'LayoutLM': preds = self.model( input_ids=inputs["input_ids"], bbox=inputs["bbox"], token_type_ids=inputs["token_type_ids"], attention_mask=inputs["attention_mask"]) - elif args.ser_model_type == 'LayoutXLM': + elif self.args.ser_model_type == 'LayoutXLM': preds = self.model( input_ids=inputs["input_ids"], bbox=inputs["bbox"], diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py index 23737406d1d2f31f79df9ddb1a9a6bcc5976aabe..a6316b625c4ec213dcb81f2e4f1e8f8a422f1cc6 100644 --- a/ppstructure/vqa/infer_ser_re_e2e.py +++ b/ppstructure/vqa/infer_ser_re_e2e.py @@ -117,7 +117,11 @@ if __name__ == "__main__": "w", encoding='utf-8') as fout: for idx, img_path in enumerate(infer_imgs): - print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path)) + save_img_path = os.path.join( + args.output_dir, + os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg") + print("process: [{}/{}], save_result to {}".format( + idx, len(infer_imgs), save_img_path)) img = cv2.imread(img_path) @@ -128,7 +132,4 @@ if __name__ == "__main__": }, ensure_ascii=False) + "\n") img_res = draw_re_results(img, result) - cv2.imwrite( - os.path.join(args.output_dir, - os.path.splitext(os.path.basename(img_path))[0] + - "_re.jpg"), img_res) + cv2.imwrite(save_img_path, img_res)