#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

Y
yangyaming 已提交
17 18
import unittest
import numpy as np
19
from op_test import OpTest
Y
yangyaming 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39


def calc_precision(tp_count, fp_count):
    """Return the precision ``TP / (TP + FP)``.

    Args:
        tp_count: (Possibly weighted) true-positive count.
        fp_count: (Possibly weighted) false-positive count.

    Returns:
        ``tp_count / (tp_count + fp_count)`` when at least one of the counts
        is positive; otherwise the fallback value 1.0.
    """
    has_positive_count = tp_count > 0.0 or fp_count > 0.0
    if not has_positive_count:
        return 1.0
    predicted_total = tp_count + fp_count
    return tp_count / predicted_total


def calc_recall(tp_count, fn_count):
    """Return the recall ``TP / (TP + FN)``.

    Args:
        tp_count: (Possibly weighted) true-positive count.
        fn_count: (Possibly weighted) false-negative count.

    Returns:
        ``tp_count / (tp_count + fn_count)`` when at least one of the counts
        is positive; otherwise the fallback value 1.0.
    """
    has_positive_count = tp_count > 0.0 or fn_count > 0.0
    if not has_positive_count:
        return 1.0
    actual_total = tp_count + fn_count
    return tp_count / actual_total


def calc_f1_score(precision, recall):
    """Return the F1 score, the harmonic mean of precision and recall.

    Args:
        precision: Precision value.
        recall: Recall value.

    Returns:
        ``2 * precision * recall / (precision + recall)`` when at least one
        of the two inputs is positive; otherwise 0.0.
    """
    any_positive = precision > 0.0 or recall > 0.0
    if not any_positive:
        return 0.0
    numerator = 2 * precision * recall
    denominator = precision + recall
    return numerator / denominator


Y
yangyaming 已提交
40 41
def get_states(idxs, labels, cls_num, weights=None):
    """Accumulate per-class TP/FP/TN/FN counts for a batch of predictions.

    Args:
        idxs: Predicted class ids. ``idxs.shape[0]`` is the number of
            instances and ``idxs[i][0]`` is the prediction for instance i.
        labels: Ground-truth class ids, read as ``labels[i][0]``.
        cls_num: Number of classes (rows of the returned array).
        weights: Optional per-instance weights. ``weights[i]`` may be a
            scalar or a one-element sequence/array (for example a row of an
            ``(N, 1)`` array). Every instance weighs 1.0 when omitted.

    Returns:
        A float32 array of shape ``(cls_num, 4)`` whose columns hold, in
        order, the TP, FP, TN and FN totals of each class.

    Raises:
        ValueError: If ``weights[i]`` does not hold exactly one value.
    """
    # Column layout of the state matrix.
    tp_col, fp_col, tn_col, fn_col = 0, 1, 2, 3

    ins_num = idxs.shape[0]
    states = np.zeros((cls_num, 4), dtype='float32')
    for i in range(ins_num):
        if weights is None:
            w = 1.0
        else:
            # Collapse a one-element row to a true scalar (dtype preserved).
            # Assigning a size-1 array into a scalar slot is deprecated since
            # NumPy 1.25 and is slated to become an error.
            w = np.asarray(weights[i]).reshape(())[()]
        idx = idxs[i][0]
        label = labels[i][0]

        # Credit every class with a true negative first, then take it back
        # from the classes actually involved. The add-then-subtract order is
        # kept on purpose so float32 rounding of the TN column is unchanged.
        states[:, tn_col] += w
        if idx == label:
            states[idx, tp_col] += w
            states[idx, tn_col] -= w
        else:
            states[label, fn_col] += w
            states[idx, fp_col] += w
            states[label, tn_col] -= w
            states[idx, tn_col] -= w
    return states


Y
yangyaming 已提交
63
def compute_metrics(states, cls_num):
    """Derive macro- and micro-averaged precision/recall/F1 from class states.

    Args:
        states: Indexable as ``states[i][k]``; columns 0, 1 and 3 of row i
            are read as the TP, FP and FN counts of class i.
        cls_num: Number of classes to aggregate over.

    Returns:
        A float32 array with six entries, in order: macro precision, macro
        recall, macro F1, micro precision, micro recall, micro F1.
    """
    tp_sum = fp_sum = fn_sum = 0.0
    precision_sum = recall_sum = 0.0
    for cls in range(cls_num):
        row = states[cls]
        tp, fp, fn = row[0], row[1], row[3]
        # Micro averages pool the raw counts over all classes ...
        tp_sum += tp
        fp_sum += fp
        fn_sum += fn
        # ... while macro averages are the mean of the per-class scores.
        precision_sum += calc_precision(tp, fp)
        recall_sum += calc_recall(tp, fn)

    macro_precision = precision_sum / cls_num
    macro_recall = recall_sum / cls_num
    macro_f1 = calc_f1_score(macro_precision, macro_recall)

    micro_precision = calc_precision(tp_sum, fp_sum)
    micro_recall = calc_recall(tp_sum, fn_sum)
    micro_f1 = calc_f1_score(micro_precision, micro_recall)

    return np.array([
        macro_precision,
        macro_recall,
        macro_f1,
        micro_precision,
        micro_recall,
        micro_f1,
    ]).astype('float32')


class TestPrecisionRecallOp_0(OpTest):
    """Check ``precision_recall`` on an unweighted batch.

    Neither 'Weights' nor 'StatesInfo' is fed, and the expected accumulated
    outputs are the batch values themselves.
    """

    def setUp(self):
        """Build random predictions/labels and the reference outputs."""
        self.op_type = "precision_recall"
        batch_size = 64
        class_count = 10
        column_shape = (batch_size, 1)

        probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        predicted = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')
        truth = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')

        expected_states = get_states(predicted, truth, class_count)
        expected_metrics = compute_metrics(expected_states, class_count)

        self.attrs = {'class_number': class_count}
        self.inputs = {
            'MaxProbs': probs,
            'Indices': predicted,
            'Labels': truth
        }
        self.outputs = {
            'BatchMetrics': expected_metrics,
            'AccumMetrics': expected_metrics,
            'AccumStatesInfo': expected_states
        }

    def test_check_output(self):
        """Compare the operator outputs against the reference values."""
        self.check_output()


class TestPrecisionRecallOp_1(OpTest):
    """Check ``precision_recall`` on a batch with per-instance weights.

    'Weights' is fed but 'StatesInfo' is not, and the expected accumulated
    outputs are the batch values themselves.
    """

    def setUp(self):
        """Build random predictions/labels/weights and reference outputs."""
        self.op_type = "precision_recall"
        batch_size = 64
        class_count = 10
        column_shape = (batch_size, 1)

        probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        predicted = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')
        sample_weights = np.random.uniform(
            0, 1.0, column_shape).astype('float32')
        truth = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')

        expected_states = get_states(predicted, truth, class_count,
                                     sample_weights)
        expected_metrics = compute_metrics(expected_states, class_count)

        self.attrs = {'class_number': class_count}
        self.inputs = {
            'MaxProbs': probs,
            'Indices': predicted,
            'Labels': truth,
            'Weights': sample_weights
        }
        self.outputs = {
            'BatchMetrics': expected_metrics,
            'AccumMetrics': expected_metrics,
            'AccumStatesInfo': expected_states
        }

    def test_check_output(self):
        """Compare the operator outputs against the reference values."""
        self.check_output()


class TestPrecisionRecallOp_2(OpTest):
    """Check ``precision_recall`` with weights and a prior 'StatesInfo'.

    The batch metrics are computed from this batch's states alone; the
    accumulated metrics are computed from the batch states plus the random
    prior states that are fed in as 'StatesInfo'.
    """

    def setUp(self):
        """Build random inputs, prior states and the reference outputs."""
        self.op_type = "precision_recall"
        batch_size = 64
        class_count = 10
        column_shape = (batch_size, 1)

        probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        predicted = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')
        sample_weights = np.random.uniform(
            0, 1.0, column_shape).astype('float32')
        truth = np.random.choice(
            range(class_count), batch_size).reshape(column_shape).astype('int32')
        prior_states = np.random.randint(
            0, 30, (class_count, 4)).astype('float32')

        batch_states = get_states(predicted, truth, class_count,
                                  sample_weights)
        batch_metrics = compute_metrics(batch_states, class_count)
        # Accumulated states are this batch's states on top of the prior ones.
        accum_states = batch_states + prior_states
        accum_metrics = compute_metrics(accum_states, class_count)

        self.attrs = {'class_number': class_count}
        self.inputs = {
            'MaxProbs': probs,
            'Indices': predicted,
            'Labels': truth,
            'Weights': sample_weights,
            'StatesInfo': prior_states
        }
        self.outputs = {
            'BatchMetrics': batch_metrics,
            'AccumMetrics': accum_metrics,
            'AccumStatesInfo': accum_states
        }

    def test_check_output(self):
        """Compare the operator outputs against the reference values."""
        self.check_output()


# Run the test cases defined in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()