diff --git a/tools/end2end/convert_ppocr_label.py b/tools/end2end/convert_ppocr_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..a230a14a9392a555db058dce57cf32fb2fdf4ece
--- /dev/null
+++ b/tools/end2end/convert_ppocr_label.py
@@ -0,0 +1,81 @@
+import numpy as np
+import json
+import os
+
+
+def poly_to_string(poly):
+ if len(poly.shape) > 1:
+ poly = np.array(poly).flatten()
+
+ string = "\t".join(str(i) for i in poly)
+ return string
+
+
+def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
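+    # Convert a PPOCR-format label file (each line: "image_path\t[json annotations]")
+    # into one txt file per image under `save_dir`, in the format expected by
+    # tools/end2end/eval_end2end.py.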
+ if not os.path.exists(label_dir):
+ raise ValueError(f"The file {label_dir} does not exist!")
+
+    assert label_dir != save_dir, "The save_dir must be different from the input label_dir."
+
+    with open(label_dir, 'r') as label_file:
+        data = label_file.readlines()
+
+ gt_dict = {}
+
+ for line in data:
+        tmp = line.split('\t')
+        if len(tmp) != 2:
+            # fall back to space-separated labels
+            tmp = line.strip().split(' ')
+
+ gt_lists = []
+
+        if len(tmp) >= 2:  # expect "image_path<SEP>annotations"
+ img_path = tmp[0]
+ anno = json.loads(tmp[1])
+ gt_collect = []
+ for dic in anno:
+ #txt = dic['transcription'].replace(' ', '') # ignore blank
+ txt = dic['transcription']
+ if 'score' in dic and float(dic['score']) < 0.5:
+ continue
+ if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ')
+ #while ' ' in txt:
+ # txt = txt.replace(' ', '')
+ poly = np.array(dic['points']).flatten()
+ if txt == "###":
+                    txt_tag = 1  # 1: this text region is ignored in evaluation
+ else:
+ txt_tag = 0
+ if mode == "gt":
+ gt_label = poly_to_string(poly) + "\t" + str(
+ txt_tag) + "\t" + txt + "\n"
+ else:
+ gt_label = poly_to_string(poly) + "\t" + txt + "\n"
+
+ gt_lists.append(gt_label)
+
+ gt_dict[img_path] = gt_lists
+ else:
+ continue
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ for img_name in gt_dict.keys():
+ save_name = img_name.split("/")[-1]
+ save_file = os.path.join(save_dir, save_name + ".txt")
+ with open(save_file, "w") as f:
+ f.writelines(gt_dict[img_name])
+
+ print("The convert label saved in {}".format(save_dir))
+
+
+if __name__ == "__main__":
+
+ ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt"
+ convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/")
+
+ ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt"
+ convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
diff --git a/tools/end2end/draw_html.py b/tools/end2end/draw_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f9a4719ba09d9b49d1b920ed397237d485eb925
--- /dev/null
+++ b/tools/end2end/draw_html.py
@@ -0,0 +1,52 @@
+import os
+
+
+def draw_debug_img(html_path):
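+    # Write an HTML table in which each row shows, for one image, the GT
+    # visualization next to the detection results of two inference models.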
+
+ err_cnt = 0
+ with open(html_path, 'w') as html:
+        html.write('<html>\n<body>\n')
+        html.write('<table border="1">\n')
+        html.write(
+            '<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />'
+        )
+ image_list = []
+ path = "./det_results/310_gt/"
+ #path = "infer_results/"
+ for i, filename in enumerate(sorted(os.listdir(path))):
+ if filename.endswith("txt"): continue
+ print(filename)
+ # The image path
+ base = "{}/{}".format(path, filename)
+ base_2 = "../PaddleOCR/det_results/ch_PPOCRV2_infer/{}".format(
+ filename)
+ base_3 = "../PaddleOCR/det_results/ch_ppocr_mobile_infer/{}".format(
+ filename)
+
+            html.write("<tr>\n")
+            html.write(f'<td> {filename}\n GT')
+            html.write('<td>GT\n<img src="%s"></td>' % (base))
+            html.write('<td>PPOCRV2\n<img src="%s"></td>' %
+                       (base_2))
+            html.write('<td>ppocr_mobile\n<img src="%s"></td>' %
+                       (base_3))
+            html.write("</tr>\n")
+
+        html.write('</table>\n')
+        html.write('</body>\n</html>\n')
+ print("ok")
+ #print("all cnt: {}, err cnt: {}, acc: {}".format(len(imgs), err_cnt, 1.0 * (len(imgs) - err_cnt) / len(imgs)))
+ return
+
+
+if __name__ == "__main__":
+
+ html_path = "sys_visual_iou_310.html"
+
+    draw_debug_img(html_path)
diff --git a/tools/end2end/eval_end2end.py b/tools/end2end/eval_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..88277786f28cc9a3db71735ef602c0c82e0ec467
--- /dev/null
+++ b/tools/end2end/eval_end2end.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+import os
+import re
+import sys
+# import Polygon
+import shapely
+from shapely.geometry import Polygon
+import numpy as np
+from collections import defaultdict
+import operator
+import editdistance
+
+
+
+def strQ2B(ustring):
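+    """Convert full-width characters in `ustring` to their half-width forms."""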
+ rstring = ""
+ for uchar in ustring:
+ inside_code = ord(uchar)
+ if inside_code == 12288:
+ inside_code = 32
+ elif (inside_code >= 65281 and inside_code <= 65374):
+ inside_code -= 65248
+ rstring += chr(inside_code)
+ return rstring
+
+
+def polygon_from_str(polygon_points):
+ """
+ Create a shapely polygon object from gt or dt line.
+ """
+ polygon_points = np.array(polygon_points).reshape(4, 2)
+ polygon = Polygon(polygon_points).convex_hull
+ return polygon
+
+
+def polygon_iou(poly1, poly2):
+ """
+ Intersection over union between two shapely polygons.
+ """
+ if not poly1.intersects(
+ poly2): # this test is fast and can accelerate calculation
+ iou = 0
+ else:
+ try:
+ inter_area = poly1.intersection(poly2).area
+ union_area = poly1.area + poly2.area - inter_area
+ iou = float(inter_area) / union_area
+ except shapely.geos.TopologicalError:
+ # except Exception as e:
+ # print(e)
+            print('shapely.geos.TopologicalError occurred, iou set to 0')
+ iou = 0
+ return iou
+
+
+def ed(str1, str2):
+ return editdistance.eval(str1, str2)
+
+
+def e2e_eval(gt_dir, res_dir):
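+    # `gt_dir` / `res_dir` hold one txt file per image (see convert_ppocr_label.py).
+    # A gt line is: 8 box coordinates, an ignore flag and the transcription,
+    # tab-separated; a prediction line is: 8 box coordinates and the transcription.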
+ print('start testing...')
+ iou_thresh = 0.5
+ val_names = os.listdir(gt_dir)
+ num_gt_chars = 0
+ gt_count = 0
+ dt_count = 0
+ hit = 0
+ ed_sum = 0
+
+ for i, val_name in enumerate(val_names):
+ with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
+ gt_lines = [o.strip() for o in f.readlines()]
+ gts = []
+ ignore_masks = []
+ for line in gt_lines:
+ parts = line.strip().split('\t')
+ # ignore illegal data
+ if len(parts) < 9:
+ continue
+ assert (len(parts) < 11)
+ if len(parts) == 9:
+ gts.append(parts[:8] + [''])
+ else:
+ gts.append(parts[:8] + [parts[-1]])
+
+ ignore_masks.append(parts[8])
+
+ val_path = os.path.join(res_dir, val_name)
+ if not os.path.exists(val_path):
+ dt_lines = []
+ else:
+ with open(val_path, encoding='utf-8') as f:
+ dt_lines = [o.strip() for o in f.readlines()]
+ dts = []
+ for line in dt_lines:
+ # print(line)
+ parts = line.strip().split("\t")
+ assert (len(parts) < 10), "line error: {}".format(line)
+ if len(parts) == 8:
+ dts.append(parts + [''])
+ else:
+ dts.append(parts)
+
+ dt_match = [False] * len(dts)
+ gt_match = [False] * len(gts)
+ all_ious = defaultdict(tuple)
+ for index_gt, gt in enumerate(gts):
+ gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
+ gt_poly = polygon_from_str(gt_coors)
+ for index_dt, dt in enumerate(dts):
+ dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
+ dt_poly = polygon_from_str(dt_coors)
+ iou = polygon_iou(dt_poly, gt_poly)
+ if iou >= iou_thresh:
+ all_ious[(index_gt, index_dt)] = iou
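+        # Greedily match gt/dt pairs in descending IoU order, each box at most once.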
+ sorted_ious = sorted(
+ all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
+
+ # matched gt and dt
+ for gt_dt_pair in sorted_gt_dt_pairs:
+ index_gt, index_dt = gt_dt_pair
+            if not gt_match[index_gt] and not dt_match[index_dt]:
+ gt_match[index_gt] = True
+ dt_match[index_dt] = True
+ # gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
+ # dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
+ gt_str = strQ2B(gts[index_gt][8])
+ dt_str = strQ2B(dts[index_dt][8])
+ if ignore_masks[index_gt] == '0':
+ ed_sum += ed(gt_str, dt_str)
+ num_gt_chars += len(gt_str)
+ if gt_str == dt_str:
+ hit += 1
+ gt_count += 1
+ dt_count += 1
+
+ # unmatched dt
+ for tindex, dt_match_flag in enumerate(dt_match):
+            if not dt_match_flag:
+ dt_str = dts[tindex][8]
+ gt_str = ''
+ ed_sum += ed(dt_str, gt_str)
+ dt_count += 1
+
+ # unmatched gt
+ for tindex, gt_match_flag in enumerate(gt_match):
+            if not gt_match_flag and ignore_masks[tindex] == '0':
+ dt_str = ''
+ gt_str = gts[tindex][8]
+ ed_sum += ed(gt_str, dt_str)
+ num_gt_chars += len(gt_str)
+ gt_count += 1
+
+ eps = 1e-9
+ print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ precision = hit / (dt_count + eps)
+ recall = hit / (gt_count + eps)
+ fmeasure = 2.0 * precision * recall / (precision + recall + eps)
+ avg_edit_dist_img = ed_sum / len(val_names)
+ avg_edit_dist_field = ed_sum / (gt_count + eps)
+ character_acc = 1 - ed_sum / (num_gt_chars + eps)
+
+ print('character_acc: %.2f' % (character_acc * 100) + "%")
+ print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
+ print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
+ print('precision: %.2f' % (precision * 100) + "%")
+ print('recall: %.2f' % (recall * 100) + "%")
+ print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 3:
+        print("usage: python3 eval_end2end.py gt_dir res_dir")
+        exit(-1)
+    gt_folder = sys.argv[1]
+    pred_folder = sys.argv[2]
+ e2e_eval(gt_folder, pred_folder)
diff --git a/tools/end2end/readme.md b/tools/end2end/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..257a6f382ad7a5d69493bdd1d2c31ecd0e105c9a
--- /dev/null
+++ b/tools/end2end/readme.md
@@ -0,0 +1,69 @@
+
+# Introduction
+
+The `tools/end2end` directory contains the metric evaluation code and visualization tools for the cascaded text detection + text recognition prediction pipeline. This document describes how to evaluate the end-to-end metrics of the combined text detection and text recognition system.
+
+
+## End-to-end evaluation steps
+
+**Step 1:**
+
+Run `tools/infer/predict_system.py` to obtain the saved prediction results:
+
+```
+python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
+```
+
+By default, the visualized detection and recognition results are saved in the `./ch_PP-OCRv2_results/` directory, and the prediction results are saved in `./ch_PP-OCRv2_results/results.txt` with the following format:
+```
+all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
+```
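+
+Each line of `results.txt` is an image path followed by a tab and a JSON list of predictions, so it can be consumed directly from Python. A minimal parsing sketch (the path below is just the default from Step 1; each prediction dict contains at least `transcription` and `points`, and may also carry a `score`):
+
+```python
+import json
+
+with open("./ch_PP-OCRv2_results/results.txt", "r", encoding="utf-8") as f:
+    for line in f:
+        img_path, anno = line.strip().split("\t", 1)
+        for pred in json.loads(anno):
+            print(img_path, pred["transcription"], pred["points"])
+```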
+
+
+**Step 2:**
+
+Convert the data saved in Step 1 into the format required by the end-to-end evaluation:
+Edit `tools/end2end/convert_ppocr_label.py` and, in the calls to `convert_label`, set the input label path, the mode, and the output directory, so that both the ground-truth labels and the prediction results are converted.
+
+```
+ppocr_label_gt = "gt_label.txt"
+convert_label(ppocr_label_gt, "gt", "./save_gt_label/")
+
+ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt"
+convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
+```
+
+Then run `convert_ppocr_label.py`:
+```
+python3 tools/end2end/convert_ppocr_label.py
+```
+
+This produces the following directories:
+```
+├── ./save_gt_label/
+├── ./save_PPOCRV2_infer/
+```
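+
+Each converted file in these directories is named after its image and contains one line per text box. Roughly (based on `convert_ppocr_label.py`), a ground-truth line holds the 8 flattened box coordinates, an ignore flag (`1` for `###` regions, `0` otherwise) and the transcription, all tab-separated; a prediction line omits the ignore flag:
+
+```
+x1	y1	x2	y2	x3	y3	x4	y4	<ignore_flag>	<transcription>
+```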
+
+**Step 3:**
+
+Run `tools/end2end/eval_end2end.py` to compute the end-to-end metrics, as follows:
+
+```
+python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir"
+```
+
+For example:
+
+```
+python3 tools/end2end/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
+```
+This produces output like the following, where `fmeasure` is the main metric of interest:
+```
+hit, dt_count, gt_count 1557 2693 3283
+character_acc: 61.77%
+avg_edit_dist_field: 3.08
+avg_edit_dist_img: 51.82
+precision: 57.82%
+recall: 47.43%
+fmeasure: 52.11%
+```
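+
+For reference, the headline metrics follow directly from the printed `hit`, `dt_count` and `gt_count` values (see `eval_end2end.py`); a quick sanity check with the sample numbers above:
+
+```python
+hit, dt_count, gt_count = 1557, 2693, 3283
+
+precision = hit / dt_count                                # ~0.5782
+recall = hit / gt_count                                   # ~0.4743
+fmeasure = 2 * precision * recall / (precision + recall)  # ~0.5211
+```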
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 8d674809a5fe22e458fcb0c68419a7313e71d5f6..e9aff6d210114a1ebcb42409a7b9480f69ead664 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -27,6 +27,7 @@ import numpy as np
import time
import logging
from PIL import Image
+import json
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
@@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes):
return _boxes
+def save_results_to_txt(results, path):
+    # If `path` is a directory (or has no file extension), save to
+    # `<path>/results.txt`; otherwise treat `path` itself as the output file.
+    if os.path.isdir(path) or os.path.splitext(path)[1] == "":
+        if not os.path.exists(path):
+            os.makedirs(path)
+        save_path = os.path.join(path, "results.txt")
+    else:
+        save_dir = os.path.dirname(path)
+        if save_dir and not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        save_path = path
+    with open(save_path, 'w') as f:
+        f.writelines(results)
+    logger.info("The results will be saved in {}".format(save_path))
+
+
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
- is_visualize = True
+ is_visualize = args.is_visualize
font_path = args.vis_font_path
drop_score = args.drop_score
@@ -139,6 +160,7 @@ def main(args):
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time()
count = 0
+ save_res = []
for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file)
@@ -152,6 +174,21 @@ def main(args):
elapse = time.time() - starttime
total_time += elapse
+ # save results
+ preds = []
+ dt_num = len(dt_boxes)
+ for dno in range(dt_num):
+ text, score = rec_res[dno]
+ if score >= drop_score:
+ preds.append({
+ "transcription": text,
+ "points": np.array(dt_boxes[dno]).tolist()
+ })
+ text_str = "%s, %.3f" % (text, score)
+ save_res.append(image_file + '\t' + json.dumps(
+ preds, ensure_ascii=False) + '\n')
+
+ # print predicted results
logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
@@ -180,6 +217,9 @@ def main(args):
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file))))
+    # The predicted results will be saved in os.path.join(args.draw_img_save_dir, "results.txt")
+ save_results_to_txt(save_res, args.draw_img_save_dir)
+
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
text_sys.text_detector.autolog.report()
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 33ed62125c0b59b5f23b72b5b8f6ecb3b0835cf3..7b7b81e3cd22c12561b30e6705eded1c92ec7761 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -114,6 +114,7 @@ def init_args():
#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
+ parser.add_argument("--is_visualize", type=str2bool, default=True)
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")