#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
yangyaming 已提交
15 16
import unittest
import numpy as np
17
from op_test import OpTest
Y
yangyaming 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37


def calc_precision(tp_count, fp_count):
    """Return the precision ``tp_count / (tp_count + fp_count)``.

    Args:
        tp_count: true-positive count (may be a weighted, fractional value).
        fp_count: false-positive count (may be a weighted, fractional value).

    Returns:
        The precision ratio, or 1.0 when neither count is positive.
    """
    has_counts = tp_count > 0.0 or fp_count > 0.0
    if not has_counts:
        # Neither count is positive: report 1.0 instead of dividing by zero.
        return 1.0
    predicted_positive = tp_count + fp_count
    return tp_count / predicted_positive


def calc_recall(tp_count, fn_count):
    """Return the recall ``tp_count / (tp_count + fn_count)``.

    Args:
        tp_count: true-positive count (may be a weighted, fractional value).
        fn_count: false-negative count (may be a weighted, fractional value).

    Returns:
        The recall ratio, or 1.0 when neither count is positive.
    """
    if not (tp_count > 0.0 or fn_count > 0.0):
        # Neither count is positive: report 1.0 instead of dividing by zero.
        return 1.0
    actual_positive = tp_count + fn_count
    return tp_count / actual_positive


def calc_f1_score(precision, recall):
    """Return the F1 score ``2 * precision * recall / (precision + recall)``.

    Args:
        precision: precision value.
        recall: recall value.

    Returns:
        The F1 score, or 0.0 when neither input is positive.
    """
    # The ratio is only evaluated when at least one input is positive,
    # which avoids the 0 / 0 case.
    is_defined = precision > 0.0 or recall > 0.0
    return 2 * precision * recall / (precision + recall) if is_defined else 0.0


Y
yangyaming 已提交
38 39
def get_states(idxs, labels, cls_num, weights=None):
    """Accumulate per-class TP/FP/TN/FN counts for a batch of predictions.

    Args:
        idxs: array whose ``idxs[i][0]`` is the predicted class of instance i;
            ``idxs.shape[0]`` gives the number of instances.
        labels: array whose ``labels[i][0]`` is the true class of instance i.
        cls_num: number of classes (rows of the returned array).
        weights: optional per-instance weights, one value per instance
            (either shape ``(ins_num, 1)`` or ``(ins_num,)``). Every
            instance counts as 1.0 when omitted.

    Returns:
        float32 array of shape ``(cls_num, 4)`` whose columns are
        TP, FP, TN, FN.
    """
    ins_num = idxs.shape[0]
    # Flatten the weights so each one is a true scalar. Indexing an
    # (ins_num, 1) array by row yields a 1-element array, and storing that
    # into a scalar slot relies on implicit array-to-scalar conversion,
    # which NumPy has deprecated.
    flat_weights = None if weights is None else np.asarray(weights).reshape(-1)
    # TP FP TN FN
    states = np.zeros((cls_num, 4)).astype('float32')
    for i in range(ins_num):
        w = flat_weights[i] if flat_weights is not None else 1.0
        idx = idxs[i][0]
        label = labels[i][0]
        # Credit every class with a true negative first; the classes that
        # actually took part in this instance are corrected just below.
        states[:, 2] += w
        if idx == label:
            states[idx][0] += w
            states[idx][2] -= w
        else:
            states[label][3] += w
            states[idx][1] += w
            states[label][2] -= w
            states[idx][2] -= w
    return states


Y
yangyaming 已提交
61
def compute_metrics(states, cls_num):
    """Derive macro- and micro-averaged metrics from per-class states.

    Args:
        states: per-class counts indexed as ``states[i][k]``; columns 0, 1
            and 3 of row i are read as the TP, FP and FN counts of class i.
        cls_num: number of leading rows of ``states`` to aggregate.

    Returns:
        float32 array of six values, in order: macro precision, macro
        recall, macro F1, micro precision, micro recall, micro F1.
    """
    tp_sum = 0.0
    fp_sum = 0.0
    fn_sum = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    for cls in range(cls_num):
        row = states[cls]
        tp, fp, fn = row[0], row[1], row[3]
        # Micro averages pool the raw counts over all classes ...
        tp_sum += tp
        fp_sum += fp
        fn_sum += fn
        # ... while macro averages take the mean of per-class ratios.
        precision_sum += calc_precision(tp, fp)
        recall_sum += calc_recall(tp, fn)

    macro_precision = precision_sum / cls_num
    macro_recall = recall_sum / cls_num
    micro_precision = calc_precision(tp_sum, fp_sum)
    micro_recall = calc_recall(tp_sum, fn_sum)

    metrics = [
        macro_precision,
        macro_recall,
        calc_f1_score(macro_precision, macro_recall),
        micro_precision,
        micro_recall,
        calc_f1_score(micro_precision, micro_recall),
    ]
    return np.array(metrics).astype('float32')


class TestPrecisionRecallOp_0(OpTest):
    """Check the precision_recall op without weights or previous states.

    The same expected metrics array is used for both 'BatchMetrics' and
    'AccumMetrics'.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        column_shape = (ins_num, 1)

        # Random inputs: probabilities, predicted classes, true classes.
        max_probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        idxs = np.random.choice(range(cls_num), ins_num)
        idxs = idxs.reshape(column_shape).astype('int32')
        labels = np.random.choice(range(cls_num), ins_num)
        labels = labels.reshape(column_shape).astype('int32')

        # Reference results computed by the pure-Python helpers.
        states = get_states(idxs, labels, cls_num)
        metrics = compute_metrics(states, cls_num)

        self.attrs = {'class_number': cls_num}
        self.inputs = {
            'MaxProbs': max_probs,
            'Indices': idxs,
            'Labels': labels,
        }
        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states,
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_1(OpTest):
    """Check the precision_recall op with per-instance weights.

    No previous states are supplied, so the same expected metrics array is
    used for both 'BatchMetrics' and 'AccumMetrics'.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        column_shape = (ins_num, 1)

        # Random inputs: probabilities, predicted classes, weights, labels.
        max_probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        idxs = np.random.choice(range(cls_num), ins_num)
        idxs = idxs.reshape(column_shape).astype('int32')
        weights = np.random.uniform(0, 1.0, column_shape).astype('float32')
        labels = np.random.choice(range(cls_num), ins_num)
        labels = labels.reshape(column_shape).astype('int32')

        # Reference results computed by the pure-Python helpers.
        states = get_states(idxs, labels, cls_num, weights)
        metrics = compute_metrics(states, cls_num)

        self.attrs = {'class_number': cls_num}
        self.inputs = {
            'MaxProbs': max_probs,
            'Indices': idxs,
            'Labels': labels,
            'Weights': weights,
        }
        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states,
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_2(OpTest):
    """Check the precision_recall op with weights and previous states.

    The batch metrics are computed from the current batch alone; the
    accumulated metrics are computed after adding the 'StatesInfo' input
    to the batch states.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        column_shape = (ins_num, 1)

        # Random inputs: probabilities, predicted classes, weights, labels,
        # and previously accumulated per-class states.
        max_probs = np.random.uniform(0, 1.0, column_shape).astype('float32')
        idxs = np.random.choice(range(cls_num), ins_num)
        idxs = idxs.reshape(column_shape).astype('int32')
        weights = np.random.uniform(0, 1.0, column_shape).astype('float32')
        labels = np.random.choice(range(cls_num), ins_num)
        labels = labels.reshape(column_shape).astype('int32')
        states = np.random.randint(0, 30, (cls_num, 4)).astype('float32')

        # Batch-only reference first, then the accumulated reference.
        batch_states = get_states(idxs, labels, cls_num, weights)
        batch_metrics = compute_metrics(batch_states, cls_num)
        accum_states = batch_states + states
        accum_metrics = compute_metrics(accum_states, cls_num)

        self.attrs = {'class_number': cls_num}
        self.inputs = {
            'MaxProbs': max_probs,
            'Indices': idxs,
            'Labels': labels,
            'Weights': weights,
            'StatesInfo': states,
        }
        self.outputs = {
            'BatchMetrics': batch_metrics,
            'AccumMetrics': accum_metrics,
            'AccumStatesInfo': accum_states,
        }

    def test_check_output(self):
        self.check_output()


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()