# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import json
from copy import deepcopy

import cv2
import numpy as np
import paddle.fluid as fluid

from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
from ppocr.utils.character import convert_rec_label_to_lod
from ppocr.utils.character import convert_rec_attention_infer_res
from ppocr.utils.utility import create_module
from ppocr.data.reader_main import reader_main

__all__ = ['eval_rec_run', 'test_rec_benchmark']

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def eval_rec_run(exe, config, eval_info_dict, mode):
    """
    Run the evaluation program over every batch produced by
    eval_info_dict['reader'] and accumulate recognition accuracy.

    eval_info_dict must provide the compiled evaluation 'program', the
    batch 'reader' and the 'fetch_varname_list' of output variable names.
    Returns a dict with 'avg_acc', 'total_acc_num' and 'total_sample_num'.
    """
    char_ops = config['Global']['char_ops']
    total_sample_num = 0
    total_acc_num = 0
    total_batch_num = 0
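
    # In "test" mode, repeated characters are collapsed when decoding
    # predictions (CTC-style decoding); "eval" mode keeps the raw
    # per-timestep output.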
    is_remove_duplicate = (mode != "eval")

    for data in eval_info_dict['reader']():
        img_num = len(data)
        img_list = []
        label_list = []
        for ino in range(img_num):
            img_list.append(data[ino][0])
            label_list.append(data[ino][1])

        if config['Global']['loss_type'] != "srn":
            # CTC / attention heads consume only the image tensor.
            img_list = np.concatenate(img_list, axis=0)
            outs = exe.run(eval_info_dict['program'],
                           feed={'image': img_list},
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])

            # Attention-style heads emit dense [N, max_len] predictions that
            # must be converted to LoD form; CTC output already carries LoD
            # information describing each sequence's length.
            if preds.shape[1] != 1:
                preds, preds_lod = convert_rec_attention_infer_res(preds)
            else:
                preds_lod = outs[0].lod()[0]
            labels, labels_lod = convert_rec_label_to_lod(label_list)
            acc, acc_num, sample_num = cal_predicts_accuracy(
                char_ops, preds, preds_lod, labels, labels_lod,
                is_remove_duplicate)
        else:
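            # The SRN head additionally needs positional encodings and
            # self-attention bias tensors, which the reader supplies at
            # sample indices 2-5.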
            encoder_word_pos_list = []
            gsrm_word_pos_list = []
            gsrm_slf_attn_bias1_list = []
            gsrm_slf_attn_bias2_list = []
            for ino in range(img_num):
                encoder_word_pos_list.append(data[ino][2])
                gsrm_word_pos_list.append(data[ino][3])
                gsrm_slf_attn_bias1_list.append(data[ino][4])
                gsrm_slf_attn_bias2_list.append(data[ino][5])

            # Batch the inputs and cast them to the dtypes the SRN program
            # expects.
            img_list = np.concatenate(img_list, axis=0)
            label_list = np.concatenate(label_list, axis=0)
            encoder_word_pos_list = np.concatenate(
                encoder_word_pos_list, axis=0).astype(np.int64)
            gsrm_word_pos_list = np.concatenate(
                gsrm_word_pos_list, axis=0).astype(np.int64)
            gsrm_slf_attn_bias1_list = np.concatenate(
                gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
            gsrm_slf_attn_bias2_list = np.concatenate(
                gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)

            labels = label_list

            outs = exe.run(eval_info_dict['program'],
                           feed={'image': img_list,
                                 'encoder_word_pos': encoder_word_pos_list,
                                 'gsrm_word_pos': gsrm_word_pos_list,
                                 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
                                 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list},
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])
            acc, acc_num, sample_num = cal_predicts_accuracy_srn(
                char_ops, preds, labels, config['Global']['max_text_length'])

        total_acc_num += acc_num
        total_sample_num += sample_num
        logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
        total_batch_num += 1
    avg_acc = total_acc_num * 1.0 / total_sample_num
    metrics = {'avg_acc': avg_acc, 'total_acc_num': total_acc_num,
               'total_sample_num': total_sample_num}
    return metrics


def test_rec_benchmark(exe, config, eval_info_dict):
    """
    Evaluate the model on the standard scene-text recognition LMDB
    benchmarks and log per-dataset and average accuracy.
    """
    eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860',
                      'IC13_857', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
    eval_data_dir = config['TestReader']['lmdb_sets_dir']
    total_evaluation_data_number = 0
    total_correct_number = 0
    eval_data_acc_info = {}
    for eval_data in eval_data_list:
        # Point the reader at this benchmark's LMDB subdirectory.
        config['TestReader']['lmdb_sets_dir'] = \
            eval_data_dir + "/" + eval_data
        eval_reader = reader_main(config=config, mode="test")
        eval_info_dict['reader'] = eval_reader
        metrics = eval_rec_run(exe, config, eval_info_dict, "test")
        total_evaluation_data_number += metrics['total_sample_num']
        total_correct_number += metrics['total_acc_num']
        eval_data_acc_info[eval_data] = metrics

    avg_acc = total_correct_number * 1.0 / total_evaluation_data_number
    logger.info('-' * 50)
    strs = ""
    for eval_data in eval_data_list:
        eval_acc = eval_data_acc_info[eval_data]['avg_acc']
        strs += "\n {}, accuracy:{:.6f}".format(eval_data, eval_acc)
    strs += "\n average, accuracy:{:.6f}".format(avg_acc)
    logger.info(strs)
    logger.info('-' * 50)
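

# Minimal usage sketch (hypothetical wiring; the project's eval entry
# scripts build the evaluation program and fetch list, e.g. with
# create_module, before calling these helpers):
#
#     exe = fluid.Executor(fluid.CUDAPlace(0))
#     eval_info_dict = {'program': eval_program,
#                       'reader': reader_main(config=config, mode="test"),
#                       'fetch_varname_list': fetch_varname_list}
#     metrics = eval_rec_run(exe, config, eval_info_dict, "test")
#     logger.info("avg_acc: {}".format(metrics['avg_acc']))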