Commit 9131c4a7 authored by 文幕地方

add LayoutLM ser

Parent f01dbb56
@@ -195,7 +195,7 @@ export CUDA_VISIBLE_DEVICES=0
 python3.7 infer_ser.py \
     --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
     --ser_model_type "LayoutXLM" \
-    --output_dir "output_res/" \
+    --output_dir "output/ser/" \
     --infer_imgs "XFUND/zh_val/image/" \
     --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
 ```
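
Since this commit adds LayoutLM support to SER, the same script should also run with a LayoutLM checkpoint by switching `--ser_model_type`. A minimal sketch, assuming a locally trained LayoutLM model directory (the path is a placeholder taken from the debug script later in this commit):

```
python3.7 infer_ser.py \
    --model_name_or_path "output/ser_LayoutLM/best_model" \
    --ser_model_type "LayoutLM" \
    --output_dir "output/ser/" \
    --infer_imgs "XFUND/zh_val/image/" \
    --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
```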
@@ -210,7 +210,7 @@ python3.7 infer_ser_e2e.py \
     --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
     --ser_model_type "LayoutXLM" \
     --max_seq_length 512 \
-    --output_dir "output_res_e2e/" \
+    --output_dir "output/ser_e2e/" \
     --infer_imgs "images/input/zh_val_0.jpg"
 ```
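
The end-to-end variant runs OCR internally, so it only needs input images. By analogy with the commented-out commands in the debug script below, a LayoutLM checkpoint should presumably work here as well (model path is again a placeholder):

```
python3.7 infer_ser_e2e.py \
    --model_name_or_path "output/ser_LayoutLM/best_model" \
    --ser_model_type "LayoutLM" \
    --max_seq_length 512 \
    --output_dir "output/ser_e2e/" \
    --infer_imgs "images/input/zh_val_0.jpg"
```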
@@ -284,7 +284,7 @@ python3 eval_re.py \
     --eval_data_dir "XFUND/zh_val/image" \
     --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
     --label_map_path 'labels/labels_ser.txt' \
-    --output_dir "output/re_test/" \
+    --output_dir "output/re/" \
     --per_gpu_eval_batch_size 8 \
     --num_workers 8 \
     --seed 2048
@@ -302,7 +302,7 @@ python3 infer_re.py \
     --eval_data_dir "XFUND/zh_val/image" \
     --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
     --label_map_path 'labels/labels_ser.txt' \
-    --output_dir "output_res" \
+    --output_dir "output/re/" \
     --per_gpu_eval_batch_size 1 \
     --seed 2048
 ```
@@ -317,7 +317,7 @@ python3.7 infer_ser_re_e2e.py \
     --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
     --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
     --max_seq_length 512 \
-    --output_dir "output_ser_re_e2e_train/" \
+    --output_dir "output/ser_re_e2e/" \
     --infer_imgs "images/input/zh_val_21.jpg"
 ```
......
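
Taken together, the documentation changes above consolidate all demo results under a single `output/` directory. Based only on the `--output_dir` values in the hunks above, the resulting layout would presumably be:

```
output/
├── ser/          # infer_ser.py results
├── ser_e2e/      # infer_ser_e2e.py results
├── re/           # eval_re.py / infer_re.py results
└── ser_re_e2e/   # infer_ser_re_e2e.py results
```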
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
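# Active command: SER inference with the LayoutXLM model; the surrounding blocks are commented-out debug variants.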
python3.7 infer_ser.py \
--model_name_or_path "output/ser_new/best_model" \
--ser_model_type "LayoutXLM" \
--output_dir "ser_new/" \
--infer_imgs "images/input/zh_val_21.jpg" \
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
@@ -56,19 +56,19 @@ def infer(args):
     ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
     for idx, batch in enumerate(eval_dataloader):
+        ocr_info = ocr_info_list[idx]
+        image_path = ocr_info['image_path']
+        ocr_info = ocr_info['ocr_info']
         save_img_path = os.path.join(
             args.output_dir,
-            os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
+            os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
         logger.info("[Infer] process: {}/{}, save_result to {}".format(
             idx, len(eval_dataloader), save_img_path))
         with paddle.no_grad():
             outputs = model(**batch)
         pred_relations = outputs['pred_relations']
-        ocr_info = ocr_info_list[idx]
-        image_path = ocr_info['image_path']
-        ocr_info = ocr_info['ocr_info']
         # decode tokens based on the entity info, then filter out unneeded ocr_info
         ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
......
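
The hunk above reorders the loop in infer_re.py: the per-sample OCR annotation is now looked up before the visualization path is built, so `save_img_path` is derived from the defined `image_path` rather than the previously undefined `img_path`. A simplified sketch of the loop after this change (logging omitted):

```python
for idx, batch in enumerate(eval_dataloader):
    # fetch the OCR annotation for this sample first
    ocr_info = ocr_info_list[idx]
    image_path = ocr_info['image_path']
    ocr_info = ocr_info['ocr_info']

    # the visualization filename can now use image_path
    save_img_path = os.path.join(
        args.output_dir,
        os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")

    with paddle.no_grad():
        outputs = model(**batch)
    pred_relations = outputs['pred_relations']

    # decode tokens from the entity info, then drop unwanted ocr_info
    ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
```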
@@ -98,13 +98,13 @@ class SerPredictor(object):
             ocr_info=ocr_info,
             max_seq_len=self.max_seq_length)
-        if args.ser_model_type == 'LayoutLM':
+        if self.args.ser_model_type == 'LayoutLM':
             preds = self.model(
                 input_ids=inputs["input_ids"],
                 bbox=inputs["bbox"],
                 token_type_ids=inputs["token_type_ids"],
                 attention_mask=inputs["attention_mask"])
-        elif args.ser_model_type == 'LayoutXLM':
+        elif self.args.ser_model_type == 'LayoutXLM':
             preds = self.model(
                 input_ids=inputs["input_ids"],
                 bbox=inputs["bbox"],
......
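
This hunk fixes a stray reference to the module-level `args` inside SerPredictor, which now consistently reads `self.args`. For context, a minimal sketch of the resulting model-type dispatch; the LayoutXLM branch is truncated in the diff, so the extra `image` input shown for it is an assumption:

```python
if self.args.ser_model_type == 'LayoutLM':
    # LayoutLM consumes only text tokens, layout boxes and masks
    preds = self.model(
        input_ids=inputs["input_ids"],
        bbox=inputs["bbox"],
        token_type_ids=inputs["token_type_ids"],
        attention_mask=inputs["attention_mask"])
elif self.args.ser_model_type == 'LayoutXLM':
    # assumption: LayoutXLM additionally takes the preprocessed image tensor;
    # the original diff is cut off after the bbox argument
    preds = self.model(
        input_ids=inputs["input_ids"],
        bbox=inputs["bbox"],
        image=inputs["image"],
        token_type_ids=inputs["token_type_ids"],
        attention_mask=inputs["attention_mask"])
```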
@@ -117,7 +117,11 @@ if __name__ == "__main__":
             "w",
             encoding='utf-8') as fout:
         for idx, img_path in enumerate(infer_imgs):
-            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+            save_img_path = os.path.join(
+                args.output_dir,
+                os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
+            print("process: [{}/{}], save_result to {}".format(
+                idx, len(infer_imgs), save_img_path))
             img = cv2.imread(img_path)
@@ -128,7 +132,4 @@ if __name__ == "__main__":
                 }, ensure_ascii=False) + "\n")
             img_res = draw_re_results(img, result)
-            cv2.imwrite(
-                os.path.join(args.output_dir,
-                             os.path.splitext(os.path.basename(img_path))[0] +
-                             "_re.jpg"), img_res)
+            cv2.imwrite(save_img_path, img_res)