# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

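"""Utilities for evaluating a text detection model: run inference over the
eval or test set, dump the predicted boxes to disk, and score them against
the ground-truth labels with DetectionIoUEvaluator."""
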
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
from copy import deepcopy

import cv2
import numpy as np
import paddle.fluid as fluid

from ppocr.utils.utility import create_module
from ppocr.data.reader_main import reader_main
from .eval_det_iou import DetectionIoUEvaluator

__all__ = ['eval_det_run']

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def cal_det_res(exe, config, eval_info_dict):
    """Run detection inference over the eval dataset and write the predicted
    boxes to the file given by Global.save_res_path, one image per line."""
    global_params = config['Global']
    save_res_path = global_params['save_res_path']
    # instantiate the post-processing module named by PostProcess.function
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(global_params)
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))
    with open(save_res_path, "wb") as fout:
        tackling_num = 0
        for data in eval_info_dict['reader']():
            img_num = len(data)
            tackling_num = tackling_num + img_num
            logger.info("number of test images processed: %d", tackling_num)
            img_list = []
            ratio_list = []
            img_name_list = []
            for ino in range(img_num):
                img_list.append(data[ino][0])
                ratio_list.append(data[ino][1])
                img_name_list.append(data[ino][2])
            try:
                img_list = np.concatenate(img_list, axis=0)
            except Exception:
                err = "concatenate error, usually caused by different " \
                      "input image shapes during evaluation or testing.\n" \
                      "Please set \"test_batch_size_per_card\" to 1 in the " \
                      "main yml, or add \"test_image_shape: [h, w]\" to " \
                      "EvalReader in the reader yml."
                raise Exception(err)
            # run the detection model on the batch
            outs = exe.run(eval_info_dict['program'],
                           feed={'image': img_list},
                           fetch_list=eval_info_dict['fetch_varname_list'])
            outs_dict = {}
            for tno in range(len(outs)):
                fetch_name = eval_info_dict['fetch_name_list'][tno]
                fetch_value = np.array(outs[tno])
                outs_dict[fetch_name] = fetch_value
            # decode the raw network outputs into quadrilateral text boxes
            dt_boxes_list = postprocess(outs_dict, ratio_list)
            for ino in range(img_num):
                dt_boxes = dt_boxes_list[ino]
                img_name = img_name_list[ino]
                dt_boxes_json = []
                for box in dt_boxes:
                    tmp_json = {"transcription": ""}
                    tmp_json['points'] = box.tolist()
                    dt_boxes_json.append(tmp_json)
                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
                fout.write(otstr.encode())
    return


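# Label/result file format (illustration; the values are hypothetical): each
# line written by cal_det_res, like each line of the ground-truth label file,
# is "<image path>\t<json list of boxes>", e.g.
#   img_10.jpg	[{"transcription": "", "points": [[310, 104], [416, 104], [416, 134], [310, 134]]}]
# load_label_infor below parses exactly this tab-separated layout.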
def load_label_infor(label_file_path, do_ignore=False):
    """Load a detection label file into a dict mapping each image's base name
    to its list of box dicts; if do_ignore is True, boxes whose transcription
    is "###" are flagged as ignored (don't-care) for the evaluation."""
    img_name_label_dict = {}
    with open(label_file_path, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            substr = line.decode().strip("\n").split("\t")
            bbox_infor = json.loads(substr[1])
            bbox_num = len(bbox_infor)
            for bno in range(bbox_num):
                text = bbox_infor[bno]['transcription']
                ignore = False
                if text == "###" and do_ignore:
                    ignore = True
                bbox_infor[bno]['ignore'] = ignore
            img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
    return img_name_label_dict
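
# Illustration (hypothetical values): a ground-truth box whose transcription
# is "###" marks a don't-care region, which load_label_infor flags with
# "ignore": True when do_ignore=True:
#   img_10.jpg	[{"transcription": "###", "points": [[0, 0], [10, 0], [10, 10], [0, 10]]}]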


def cal_det_metrics(gt_label_path, save_res_path):
    """
    Calculate the detection metrics.
    Args:
        gt_label_path(string): path of the ground-truth detection label file
        save_res_path(string): path of the saved predicted detection results
    Returns:
        the calculated metrics, including Hmean, precision and recall
    """
    evaluator = DetectionIoUEvaluator()
    gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
    dt_label_infor = load_label_infor(save_res_path)
    results = []
    for img_name in gt_label_infor:
        gt_label = gt_label_infor[img_name]
        # images without any predicted boxes are scored against an empty list
        dt_label = dt_label_infor.get(img_name, [])
        result = evaluator.evaluate_image(gt_label, dt_label)
        results.append(result)
    methodMetrics = evaluator.combine_results(results)
    return methodMetrics
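
# Note (hedged): combine_results is expected to yield the combined precision,
# recall and Hmean described in the docstring above; the exact dict keys are
# defined by DetectionIoUEvaluator in eval_det_iou.py, not here.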


def eval_det_run(exe, config, eval_info_dict, mode):
    """Run detection inference and dump the results; in "eval" mode, or in
    "test" mode when TestReader.do_eval is set, also score the results
    against the ground-truth labels."""
    cal_det_res(exe, config, eval_info_dict)

    save_res_path = config['Global']['save_res_path']
    if mode == "eval":
        gt_label_path = config['EvalReader']['label_file_path']
        metrics = cal_det_metrics(gt_label_path, save_res_path)
    else:
        gt_label_path = config['TestReader']['label_file_path']
        do_eval = config['TestReader']['do_eval']
        if do_eval:
            metrics = cal_det_metrics(gt_label_path, save_res_path)
        else:
            metrics = {}
    return metrics
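

# Minimal usage sketch (hedged; nothing below is defined in this module).
# eval_det_run expects a started Executor, the merged config, and a dict
# describing the compiled eval program, a reader that yields batches of
# (image, ratio, image_name) tuples, and matching fetch name/varname lists:
#
#     eval_info_dict = {
#         'program': eval_program,
#         'reader': eval_reader,
#         'fetch_name_list': fetch_name_list,
#         'fetch_varname_list': fetch_varname_list,
#     }
#     metrics = eval_det_run(exe, config, eval_info_dict, mode="eval")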