eval_rec_utils.py 6.0 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import numpy as np

import paddle.fluid as fluid

__all__ = ['eval_rec_run', 'test_rec_benchmark']

# Module-wide logging setup. `logging` is already imported above, so it is not
# re-imported here. basicConfig configures the root logger at import time.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

T
tink2123 已提交
32
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
L
LDOUBLEV 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
from ppocr.utils.character import convert_rec_label_to_lod
from ppocr.utils.character import convert_rec_attention_infer_res
from ppocr.utils.utility import create_module
import json
from copy import deepcopy
import cv2
from ppocr.data.reader_main import reader_main


def _concat_field(batch, index, dtype):
    """Concatenate field `index` of every sample in `batch` along axis 0 and cast to `dtype`."""
    return np.concatenate([sample[index] for sample in batch], axis=0).astype(dtype)


def eval_rec_run(exe, config, eval_info_dict, mode):
    """
    Run the recognition evaluation program over one reader and aggregate accuracy.

    Args:
        exe: executor whose ``run`` method is called with
            ``eval_info_dict['program']``, a ``feed`` dict,
            ``fetch_list=eval_info_dict['fetch_varname_list']`` and
            ``return_numpy=False``.
        config: parsed config dict. Reads ``config['Global']['char_ops']`` and
            ``config['Global']['loss_type']``; for ``loss_type == "srn"`` also
            ``config['Global']['max_text_length']``.
        eval_info_dict: dict holding ``'reader'`` (a callable returning an
            iterable of batches), ``'program'`` and ``'fetch_varname_list'``.
        mode: ``"eval"`` keeps duplicated predicted characters when scoring;
            any other value asks the accuracy helper to remove duplicates.

    Returns:
        dict with ``'avg_acc'``, ``'total_acc_num'`` and ``'total_sample_num'``.
        ``'avg_acc'`` is 0.0 when the reader yields no samples (previously this
        raised ZeroDivisionError).
    """
    char_ops = config['Global']['char_ops']
    loss_type = config['Global']['loss_type'] if False else None  # placeholder, resolved lazily below
    total_sample_num = 0
    total_acc_num = 0
    is_remove_duplicate = mode != "eval"

    for data in eval_info_dict['reader']():
        # Every sample carries the image at index 0 and the label at index 1;
        # SRN samples carry four extra model inputs at indices 2-5.
        img_list = [sample[0] for sample in data]
        label_list = [sample[1] for sample in data]

        if config['Global']['loss_type'] != "srn":
            outs = exe.run(eval_info_dict['program'],
                           feed={'image': np.concatenate(img_list, axis=0)},
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])

            if config['Global']['loss_type'] == "attention":
                preds, preds_lod = convert_rec_attention_infer_res(preds)
            else:
                # With return_numpy=False the fetched output keeps its LoD,
                # which delimits each sample's predicted sequence.
                preds_lod = outs[0].lod()[0]
            labels, labels_lod = convert_rec_label_to_lod(label_list)
            _, acc_num, sample_num = cal_predicts_accuracy(
                char_ops, preds, preds_lod, labels, labels_lod,
                is_remove_duplicate)
        else:
            feed = {
                'image': np.concatenate(img_list, axis=0),
                'encoder_word_pos': _concat_field(data, 2, np.int64),
                'gsrm_word_pos': _concat_field(data, 3, np.int64),
                'gsrm_slf_attn_bias1': _concat_field(data, 4, np.float32),
                'gsrm_slf_attn_bias2': _concat_field(data, 5, np.float32),
            }
            labels = np.concatenate(label_list, axis=0)
            outs = exe.run(eval_info_dict['program'],
                           feed=feed,
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])
            _, acc_num, sample_num = cal_predicts_accuracy_srn(
                char_ops, preds, labels, config['Global']['max_text_length'])

        total_acc_num += acc_num
        total_sample_num += sample_num

    # Guard against an empty reader instead of raising ZeroDivisionError.
    avg_acc = total_acc_num * 1.0 / total_sample_num if total_sample_num else 0.0
    metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num,
               "total_sample_num": total_sample_num}
    return metrics


def test_rec_benchmark(exe, config, eval_info_dict, eval_data_list=None):
    """
    Evaluate lmdb dataset: run the recognizer on each benchmark subset and log accuracies.

    Each subset is read from ``config['TestReader']['lmdb_sets_dir'] + "/" + name``.
    The original ``lmdb_sets_dir`` value is restored before returning, even if
    evaluation raises. Previously the config was left pointing at the last
    subset, so a second call built nested, non-existent paths.

    Args:
        exe: executor forwarded to ``eval_rec_run``.
        config: parsed config dict. ``config['TestReader']['lmdb_sets_dir']``
            must name the directory that contains the subset folders.
        eval_info_dict: dict forwarded to ``eval_rec_run``. Its ``'reader'``
            entry is overwritten with each subset's reader in turn.
        eval_data_list: optional list of subset folder names. Defaults to the
            ten standard scene-text benchmark subsets used previously.
    """
    if eval_data_list is None:
        eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867',
                          'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077',
                          'SVTP', 'CUTE80']
    eval_data_dir = config['TestReader']['lmdb_sets_dir']
    total_evaluation_data_number = 0
    total_correct_number = 0
    eval_data_acc_info = {}
    try:
        for eval_data in eval_data_list:
            # The reader is built from `config`, so point the dataset dir at the
            # current subset before calling reader_main.
            config['TestReader']['lmdb_sets_dir'] = \
                eval_data_dir + "/" + eval_data
            eval_info_dict['reader'] = reader_main(config=config, mode="test")
            metrics = eval_rec_run(exe, config, eval_info_dict, "test")
            total_evaluation_data_number += metrics['total_sample_num']
            total_correct_number += metrics['total_acc_num']
            eval_data_acc_info[eval_data] = metrics
    finally:
        # Undo the per-subset override so callers get their config back intact.
        # NOTE(review): if reader_main reads config lazily, the reader left in
        # eval_info_dict['reader'] will now see the base dir — confirm.
        config['TestReader']['lmdb_sets_dir'] = eval_data_dir

    # Guard against every subset being empty instead of raising ZeroDivisionError.
    if total_evaluation_data_number:
        avg_acc = total_correct_number * 1.0 / total_evaluation_data_number
    else:
        avg_acc = 0.0
    logger.info('-' * 50)
    report = ["\n {}, accuracy:{:.6f}".format(
        eval_data, eval_data_acc_info[eval_data]['avg_acc'])
        for eval_data in eval_data_list]
    report.append("\n average, accuracy:{:.6f}".format(avg_acc))
    logger.info("".join(report))
    logger.info('-' * 50)