Commit dbaab10e authored by LDOUBLEV

add end2end

Parent 44ea70d5
import numpy as np
import json
import os


def poly_to_string(poly):
    # Flatten the polygon and join its coordinates with tabs.
    if len(poly.shape) > 1:
        poly = np.array(poly).flatten()

    string = "\t".join(str(i) for i in poly)
    return string


def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
    if not os.path.exists(label_dir):
        raise ValueError(f"The file {label_dir} does not exist!")

    assert label_dir != save_dir, \
        "The label path and the save path must be different!"

    with open(label_dir, 'r') as label_file:
        data = label_file.readlines()

    gt_dict = {}

    for line in data:
        # Each line is "<image path>\t<JSON annotations>"; fall back to a
        # space separator for labels that are not tab-separated.
        try:
            tmp = line.split('\t')
            assert len(tmp) == 2, ""
        except AssertionError:
            tmp = line.strip().split(' ')

        gt_lists = []
        img_path = tmp[0]
        anno = json.loads(tmp[1])
        for dic in anno:
            txt = dic['transcription']
            # Skip low-confidence predictions.
            if 'score' in dic and float(dic['score']) < 0.5:
                continue
            # Normalize full-width spaces.
            if u'\u3000' in txt:
                txt = txt.replace(u'\u3000', u' ')
            poly = np.array(dic['points']).flatten()
            # "###" marks instances to be ignored during evaluation.
            txt_tag = 1 if txt == "###" else 0
            if mode == "gt":
                gt_label = poly_to_string(poly) + "\t" + str(
                    txt_tag) + "\t" + txt + "\n"
            else:
                gt_label = poly_to_string(poly) + "\t" + txt + "\n"
            gt_lists.append(gt_label)

        gt_dict[img_path] = gt_lists

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Write one label file per image.
    for img_name in gt_dict.keys():
        save_name = img_name.split("/")[-1]
        save_file = os.path.join(save_dir, save_name + ".txt")
        with open(save_file, "w") as f:
            f.writelines(gt_dict[img_name])

    print("The converted labels have been saved in {}".format(save_dir))


if __name__ == "__main__":
    ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt"
    convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/")

    ppocr_label_pred = "./infer_results/ch_PPOCRV2_infer.txt"
    convert_label(ppocr_label_pred, "pred", "./save_PPOCRV2_infer/")
import os


def draw_debug_img(html_path):
    # Build an HTML page that shows, per image, the GT visualization next to
    # the detection results of two models, one table row per image.
    with open(html_path, 'w') as html:
        html.write('<html>\n<body>\n')
        html.write('<table border="1">\n')
        html.write(
            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
        )
        path = "./det_results/310_gt/"
        for i, filename in enumerate(sorted(os.listdir(path))):
            if filename.endswith("txt"):
                continue
            print(filename)
            # The image paths of the GT and the two model predictions
            base = "{}/{}".format(path, filename)
            base_2 = "../PaddleOCR/det_results/ch_PPOCRV2_infer/{}".format(
                filename)
            base_3 = "../PaddleOCR/det_results/ch_ppocr_mobile_infer/{}".format(
                filename)
            html.write("<tr>\n")
            html.write(f'<td> {filename}\n GT')
            html.write('<td>GT\n<img src="%s" width=640></td>' % (base))
            html.write('<td>PPOCRV2\n<img src="%s" width=640></td>' %
                       (base_2))
            html.write('<td>ppocr_mobile\n<img src="%s" width=640></td>' %
                       (base_3))
            html.write("</tr>\n")
        html.write('<style>\n')
        html.write('span {\n')
        html.write('    color: red;\n')
        html.write('}\n')
        html.write('</style>\n')
        html.write('</table>\n')
        html.write('</body>\n</html>\n')
    print("ok")


if __name__ == "__main__":
    html_path = "sys_visual_iou_310.html"
    draw_debug_img(html_path)
#!/usr/bin/env python
import os
import sys
from collections import defaultdict
import operator

import numpy as np
import shapely
from shapely.geometry import Polygon
import editdistance


def strQ2B(ustring):
    # Convert full-width characters to their half-width counterparts.
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space
            inside_code = 32
        elif (inside_code >= 65281 and inside_code <= 65374):
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring


def polygon_from_str(polygon_points):
    """
    Create a shapely polygon object from a gt or dt line.
    """
    polygon_points = np.array(polygon_points).reshape(4, 2)
    polygon = Polygon(polygon_points).convex_hull
    return polygon


def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.
    """
    if not poly1.intersects(poly2):
        # This test is fast and can accelerate the calculation.
        iou = 0
    else:
        try:
            inter_area = poly1.intersection(poly2).area
            union_area = poly1.area + poly2.area - inter_area
            iou = float(inter_area) / union_area
        except shapely.geos.TopologicalError:
            print('shapely.geos.TopologicalError occurred, iou set to 0')
            iou = 0
    return iou


def ed(str1, str2):
    return editdistance.eval(str1, str2)


def e2e_eval(gt_dir, res_dir):
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for i, val_name in enumerate(val_names):
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])
            ignore_masks.append(parts[8])

        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        # Collect all gt/dt pairs whose IoU passes the threshold, then match
        # them greedily in descending IoU order.
        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if gt_match[index_gt] == False and dt_match[index_dt] == False:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                gt_str = strQ2B(gts[index_gt][8])
                dt_str = strQ2B(dts[index_dt][8])
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if dt_match_flag == False:
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if gt_match_flag == False and ignore_masks[tindex] == '0':
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")


if __name__ == '__main__':
    if len(sys.argv) != 3:
        print("python3 eval_end2end.py gt_dir res_dir")
        exit(-1)
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
# Introduction

The `tools/end2end` directory contains the metric evaluation code and the visualization tools for the chained text detection + text recognition pipeline. This section describes how to evaluate the end-to-end metrics of text detection + text recognition.

## End-to-end evaluation steps

**Step 1:**

Run `tools/infer/predict_system.py` to generate and save the prediction results:

```
python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
```

The visualized detection and recognition images are saved in `./ch_PP-OCRv2_results/` by default, and the prediction results are saved in `./ch_PP-OCRv2_results/results.txt` by default, in the following format:

```
all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
```
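
Each `results.txt` line is the image path and a JSON list of predictions separated by a tab (see `save_results_to_txt` in the diff below). As a minimal sketch, assuming exactly this layout (the helper name `parse_result_line` is ours), one line can be parsed like so:

```python
import json

def parse_result_line(line):
    # A results.txt line is: <image path>\t<JSON list of predictions>
    img_path, anno_str = line.strip().split("\t", 1)
    annos = json.loads(anno_str)  # each item: {"transcription", "points", "score"}
    return img_path, annos
```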
**Step 2:**

Convert the data saved in Step 1 into the format required by the end-to-end evaluation:

Edit the code in `tools/convert_ppocr_label.py`: in the `convert_label` calls, set the input label path, the mode, and the save path, so that both the GT labels of the test data and the prediction labels are converted.

```
ppocr_label_gt = "gt_label.txt"
convert_label(ppocr_label_gt, "gt", "./save_gt_label/")

ppocr_label_pred = "./infer_results/ch_PPOCRV2_infer.txt"
convert_label(ppocr_label_pred, "pred", "./save_PPOCRV2_infer/")
```

Run `convert_ppocr_label.py`:

```
python3 tools/convert_ppocr_label.py
```

This produces the following directories:

```
├── ./save_gt_label/
├── ./save_PPOCRV2_infer/
```
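
Judging from the `convert_label` implementation above, each image gets its own `<image_name>.txt` file in the save directory, with one text instance per line: the eight flattened polygon coordinates, the ignore tag (`gt` mode only; `1` for `###` instances), and the transcription, all tab-separated. A GT line for the first box in the example above would look like:

```
8.0	48.0	157.0	44.0	159.0	115.0	10.0	119.0	0	超赞
```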
**Step 3:**

Run `tools/eval_end2end.py` to compute the end-to-end metrics, as follows:

```
python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir"
```

For example:

```
python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
```

This prints results like the following, where fmeasure is the primary metric of interest:

```
hit, dt_count, gt_count 1557 2693 3283
character_acc: 61.77%
avg_edit_dist_field: 3.08
avg_edit_dist_img: 51.82
precision: 57.82%
recall: 47.43%
fmeasure: 52.11%
```
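
For reference, in `eval_end2end.py` a prediction counts as a hit only when it matches a ground-truth box with IoU >= 0.5 and the two transcriptions are identical after full-width-to-half-width normalization. The headline metrics then reduce to the following (a sketch using the numbers from the sample output above):

```python
hit, dt_count, gt_count = 1557, 2693, 3283  # from the sample output above

precision = hit / dt_count                                # 1557 / 2693 ≈ 57.82%
recall = hit / gt_count                                   # 1557 / 3283 ≈ 47.43%
fmeasure = 2 * precision * recall / (precision + recall)  # ≈ 52.11%
```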
tools/infer/predict_system.py:

@@ -27,6 +27,7 @@ import numpy as np
 import time
 import logging
 from PIL import Image
+import json
 import tools.infer.utility as utility
 import tools.infer.predict_rec as predict_rec
 import tools.infer.predict_det as predict_det
@@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes):
     return _boxes
 
 
+def save_results_to_txt(results, path):
+    if os.path.isdir(path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+        with open(os.path.join(path, "results.txt"), 'w') as f:
+            f.writelines(results)
+            f.close()
+        logger.info("The results will be saved in {}".format(
+            os.path.join(path, "results.txt")))
+    else:
+        draw_img_save = os.path.dirname(path)
+        if not os.path.exists(draw_img_save):
+            os.makedirs(draw_img_save)
+        with open(path, 'w') as f:
+            f.writelines(results)
+            f.close()
+        logger.info("The results will be saved in {}".format(path))
+
+
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     image_file_list = image_file_list[args.process_id::args.total_process_num]
     text_sys = TextSystem(args)
-    is_visualize = True
+    is_visualize = args.is_visualize
     font_path = args.vis_font_path
     drop_score = args.drop_score
@@ -139,6 +160,7 @@ def main(args):
     cpu_mem, gpu_mem, gpu_util = 0, 0, 0
     _st = time.time()
     count = 0
+    save_res = []
     for idx, image_file in enumerate(image_file_list):
         img, flag = check_and_read_gif(image_file)
@@ -152,6 +174,21 @@ def main(args):
         elapse = time.time() - starttime
         total_time += elapse
 
+        # save results
+        preds = []
+        dt_num = len(dt_boxes)
+        for dno in range(dt_num):
+            text, score = rec_res[dno]
+            if score >= drop_score:
+                preds.append({
+                    "transcription": text,
+                    "points": np.array(dt_boxes[dno]).tolist()
+                })
+                text_str = "%s, %.3f" % (text, score)
+        save_res.append(image_file + '\t' + json.dumps(
+            preds, ensure_ascii=False) + '\n')
+
+        # print predicted results
         logger.debug(
             str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
         for text, score in rec_res:
@@ -180,6 +217,9 @@ def main(args):
             logger.debug("The visualized image saved in {}".format(
                 os.path.join(draw_img_save_dir, os.path.basename(image_file))))
 
+    # The predicted results will be saved in os.path.join(args.draw_img_save_dir, "results.txt")
+    save_results_to_txt(save_res, args.draw_img_save_dir)
+
     logger.info("The predict total time is {}".format(time.time() - _st))
     if args.benchmark:
         text_sys.text_detector.autolog.report()
...
tools/infer/utility.py:

@@ -114,6 +114,7 @@ def init_args():
     #
     parser.add_argument(
         "--draw_img_save_dir", type=str, default="./inference_results")
+    parser.add_argument("--is_visualize", type=str2bool, default=True)
     parser.add_argument("--save_crop_res", type=str2bool, default=False)
     parser.add_argument("--crop_res_save_dir", type=str, default="./output")
...