提交 d9d51f7d 编写于 作者: Y yukavio

rm eval_det_utils

上级 d4f1758d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import numpy as np
import paddle.fluid as fluid
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
__all__ = ['eval_det_run']
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
import cv2
import json
from copy import deepcopy
from ppocr.utils.utility import create_module
from ppocr.data.reader_main import reader_main
from tools.eval_utils.eval_det_iou import DetectionIoUEvaluator
def cal_det_res(exe, config, eval_info_dict):
global_params = config['Global']
save_res_path = global_params['save_res_path']
postprocess_params = deepcopy(config["PostProcess"])
postprocess_params.update(global_params)
postprocess = create_module(postprocess_params['function']) \
(params=postprocess_params)
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
with open(save_res_path, "wb") as fout:
tackling_num = 0
for data in eval_info_dict['reader']():
img_num = len(data)
tackling_num = tackling_num + img_num
logger.info("test tackling num:%d", tackling_num)
img_list = []
ratio_list = []
img_name_list = []
for ino in range(img_num):
img_list.append(data[ino][0])
ratio_list.append(data[ino][1])
img_name_list.append(data[ino][2])
try:
img_list = np.concatenate(img_list, axis=0)
except:
err = "concatenate error usually caused by different input image shapes in evaluation or testing.\n \
Please set \"test_batch_size_per_card\" in main yml as 1\n \
or add \"test_image_shape: [h, w]\" in reader yml for EvalReader."
raise Exception(err)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'])
outs_dict = {}
for tno in range(len(outs)):
fetch_name = eval_info_dict['fetch_name_list'][tno]
fetch_value = np.array(outs[tno])
outs_dict[fetch_name] = fetch_value
dt_boxes_list = postprocess(outs_dict, ratio_list)
for ino in range(img_num):
dt_boxes = dt_boxes_list[ino]
img_name = img_name_list[ino]
dt_boxes_json = []
for box in dt_boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_json.append(tmp_json)
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
return
def load_label_infor(label_file_path, do_ignore=False):
img_name_label_dict = {}
with open(label_file_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
substr = line.decode().strip("\n").split("\t")
bbox_infor = json.loads(substr[1])
bbox_num = len(bbox_infor)
for bno in range(bbox_num):
text = bbox_infor[bno]['transcription']
ignore = False
if text == "###" and do_ignore:
ignore = True
bbox_infor[bno]['ignore'] = ignore
img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
return img_name_label_dict
def cal_det_metrics(gt_label_path, save_res_path):
"""
calculate the detection metrics
Args:
gt_label_path(string): The groundtruth detection label file path
save_res_path(string): The saved predicted detection label path
return:
claculated metrics including Hmean, precision and recall
"""
evaluator = DetectionIoUEvaluator()
gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
dt_label_infor = load_label_infor(save_res_path)
results = []
for img_name in gt_label_infor:
gt_label = gt_label_infor[img_name]
if img_name not in dt_label_infor:
dt_label = []
else:
dt_label = dt_label_infor[img_name]
result = evaluator.evaluate_image(gt_label, dt_label)
results.append(result)
methodMetrics = evaluator.combine_results(results)
return methodMetrics
def eval_det_run(eval_args, mode='eval'):
exe = eval_args['exe']
config = eval_args['config']
eval_info_dict = eval_args['eval_info_dict']
cal_det_res(exe, config, eval_info_dict)
save_res_path = config['Global']['save_res_path']
if mode == "eval":
gt_label_path = config['EvalReader']['label_file_path']
metrics = cal_det_metrics(gt_label_path, save_res_path)
else:
gt_label_path = config['TestReader']['label_file_path']
do_eval = config['TestReader']['do_eval']
if do_eval:
metrics = cal_det_metrics(gt_label_path, save_res_path)
else:
metrics = {}
return metrics['hmean']
...@@ -104,14 +104,6 @@ def main(): ...@@ -104,14 +104,6 @@ def main():
# compile program for multi-devices # compile program for multi-devices
init_model(config, train_program, exe) init_model(config, train_program, exe)
# params = get_pruned_params(train_program)
'''
sens_file = ['sensitivities_'+ str(x) for x in range(0,4)]
sens = []
for f in sens_file:
sens.append(load_sensitivities(f+'.data'))
sen = merge_sensitive(sens)
'''
sen = load_sensitivities("sensitivities_0.data") sen = load_sensitivities("sensitivities_0.data")
for i in skip_list: for i in skip_list:
sen.pop(i) sen.pop(i)
...@@ -161,28 +153,7 @@ def main(): ...@@ -161,28 +153,7 @@ def main():
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
def test_reader():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
train_reader = reader_main(config=config, mode="train")
import time
starttime = time.time()
count = 0
try:
for data in train_reader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
print("reader:", count, len(data), batch_time)
except Exception as e:
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
if __name__ == '__main__': if __name__ == '__main__':
parser = program.ArgsParser() parser = program.ArgsParser()
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
# test_reader()
...@@ -42,7 +42,7 @@ import cv2 ...@@ -42,7 +42,7 @@ import cv2
from paddle import fluid from paddle import fluid
import paddleslim as slim import paddleslim as slim
from copy import deepcopy from copy import deepcopy
from eval_det_utils import eval_det_run from tools.eval_utils.eval_det_utils import eval_det_run
from tools import program from tools import program
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
...@@ -65,6 +65,14 @@ def get_pruned_params(program): ...@@ -65,6 +65,14 @@ def get_pruned_params(program):
return params return params
def eval_function(eval_args, mode='eval'):
exe = eval_args['exe']
config = eval_args['config']
eval_info_dict = eval_args['eval_info_dict']
metrics = eval_det_run(exe, config, eval_info_dict, mode=mode)
return metrics['hmean']
def main(): def main():
config = program.load_config(FLAGS.config) config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
...@@ -99,7 +107,7 @@ def main(): ...@@ -99,7 +107,7 @@ def main():
'fetch_varname_list':eval_fetch_varname_list} 'fetch_varname_list':eval_fetch_varname_list}
eval_args = dict() eval_args = dict()
eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict} eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict}
metrics = eval_det_run(eval_args) metrics = eval_function(eval_args)
print("Baseline: {}".format(metrics)) print("Baseline: {}".format(metrics))
params = get_pruned_params(eval_program) params = get_pruned_params(eval_program)
...@@ -108,7 +116,7 @@ def main(): ...@@ -108,7 +116,7 @@ def main():
eval_program, eval_program,
place, place,
params, params,
eval_det_run, eval_function,
sensitivities_file="sensitivities_0.data", sensitivities_file="sensitivities_0.data",
pruned_ratios=[0.1], pruned_ratios=[0.1],
eval_args=eval_args, eval_args=eval_args,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册