# test_precision_recall_op.py
# Unit tests for the "precision_recall" operator (last committed by yangyaming).
import unittest
import numpy as np
from op_test import OpTest


def calc_precision(tp_count, fp_count):
    """Return precision, i.e. TP / (TP + FP).

    If neither count is positive (no positive predictions were made),
    precision is defined as 1.0 instead of dividing by zero.
    """
    has_predictions = tp_count > 0.0 or fp_count > 0.0
    if not has_predictions:
        return 1.0
    predicted_positive = tp_count + fp_count
    return tp_count / predicted_positive


def calc_recall(tp_count, fn_count):
    """Return recall, i.e. TP / (TP + FN).

    If neither count is positive (no positive ground-truth instances),
    recall is defined as 1.0 instead of dividing by zero.
    """
    has_positives = tp_count > 0.0 or fn_count > 0.0
    if not has_positives:
        return 1.0
    actual_positive = tp_count + fn_count
    return tp_count / actual_positive


def calc_f1_score(precision, recall):
    """Return the F1 score, the harmonic mean of precision and recall.

    If neither precision nor recall is positive, the score is 0.0; this
    avoids a 0 / 0 when both are zero.
    """
    if not (precision > 0.0 or recall > 0.0):
        return 0.0
    denominator = precision + recall
    return 2 * precision * recall / denominator


def get_states(idxs, labels, cls_num, weights=None):
    """Accumulate per-class confusion counts for one batch of predictions.

    Args:
        idxs: predicted class index per instance, an integer array of shape
            (ins_num, 1) (a flat (ins_num,) array is also accepted).
        labels: ground-truth class index per instance, same layout as idxs.
        cls_num: number of classes; every index must be in [0, cls_num).
        weights: optional per-instance weights, shape (ins_num, 1) or
            (ins_num,). Each instance counts as 1.0 when omitted.

    Returns:
        float32 array of shape (cls_num, 4) whose columns are
        [TP, FP, TN, FN] for each class.

    Raises:
        ValueError: if labels or weights do not have one entry per instance.
    """
    # Flatten so each element is a scalar: indexing an (N, 1) array row by
    # row yields shape-(1,) arrays, and assigning those into a scalar slot
    # relies on deprecated array-to-scalar conversion.
    idxs = np.asarray(idxs).reshape(-1)
    labels = np.asarray(labels).reshape(-1)
    ins_num = idxs.shape[0]
    if labels.shape[0] != ins_num:
        raise ValueError("labels must have one entry per instance")
    if weights is None:
        weights = np.ones(ins_num, dtype='float32')
    else:
        weights = np.asarray(weights).reshape(-1)
        if weights.shape[0] != ins_num:
            raise ValueError("weights must have one entry per instance")

    # Columns: TP, FP, TN, FN.
    states = np.zeros((cls_num, 4), dtype='float32')
    for idx, label, w in zip(idxs, labels, weights):
        # An instance is a true negative for every class it neither predicts
        # nor belongs to: credit all classes, then undo it for the predicted
        # and labelled classes below.
        states[:, 2] += w
        if idx == label:
            states[idx, 0] += w
            states[idx, 2] -= w
        else:
            states[idx, 1] += w
            states[label, 3] += w
            states[idx, 2] -= w
            states[label, 2] -= w
    return states


def compute_metrics(states, cls_num):
    """Compute macro- and micro-averaged precision, recall and F1.

    Args:
        states: array whose first cls_num rows hold per-class
            [TP, FP, TN, FN] counts, as returned by get_states.
        cls_num: number of classes to average over.

    Returns:
        float32 array of six values, in this order:
        [macro_precision, macro_recall, macro_f1,
         micro_precision, micro_recall, micro_f1].
        Macro F1 is the F1 of the macro-averaged precision and recall, not
        the mean of per-class F1 scores.
    """
    total_tp_count = 0.0
    total_fp_count = 0.0
    total_fn_count = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    for i in range(cls_num):
        tp_count, fp_count, _, fn_count = states[i]
        total_tp_count += tp_count
        total_fp_count += fp_count
        total_fn_count += fn_count
        precision_sum += calc_precision(tp_count, fp_count)
        recall_sum += calc_recall(tp_count, fn_count)

    macro_avg_precision = precision_sum / cls_num
    macro_avg_recall = recall_sum / cls_num
    # Micro averages pool the counts across classes before dividing.
    micro_avg_precision = calc_precision(total_tp_count, total_fp_count)
    micro_avg_recall = calc_recall(total_tp_count, total_fn_count)

    metrics = [
        macro_avg_precision,
        macro_avg_recall,
        calc_f1_score(macro_avg_precision, macro_avg_recall),
        micro_avg_precision,
        micro_avg_recall,
        calc_f1_score(micro_avg_precision, micro_avg_recall),
    ]
    return np.array(metrics).astype('float32')


class TestPrecisionRecallOp_0(OpTest):
    """Unweighted batch with no prior 'StatesInfo'.

    Without prior states, the expected accumulated outputs equal the batch
    outputs.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        # MaxProbs is fed to the op but not used by the reference computation.
        max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        idxs = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')
        labels = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')
        states = get_states(idxs, labels, cls_num)
        metrics = compute_metrics(states, cls_num)

        self.attrs = {'class_number': cls_num}

        self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels}

        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_1(OpTest):
    """Weighted batch with no prior 'StatesInfo'.

    Each instance contributes its 'Weights' entry instead of 1 to the
    per-class counts. Without prior states, the expected accumulated outputs
    equal the batch outputs.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        # MaxProbs is fed to the op but not used by the reference computation.
        max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        idxs = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')
        weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        labels = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')

        states = get_states(idxs, labels, cls_num, weights)
        metrics = compute_metrics(states, cls_num)

        self.attrs = {'class_number': cls_num}

        self.inputs = {
            'MaxProbs': max_probs,
            'Indices': idxs,
            'Labels': labels,
            'Weights': weights
        }

        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_2(OpTest):
    """Weighted batch combined with previously accumulated states.

    'StatesInfo' supplies prior per-class [TP, FP, TN, FN] counts. The
    expected 'AccumStatesInfo' is those counts plus the batch's counts.
    'BatchMetrics' reflects the current batch alone.
    """

    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        cls_num = 10
        # MaxProbs is fed to the op but not used by the reference computation.
        max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        idxs = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')
        weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        labels = np.random.randint(cls_num, size=(ins_num, 1)).astype('int32')
        # Arbitrary prior counts standing in for earlier batches.
        states = np.random.randint(0, 30, (cls_num, 4)).astype('float32')

        batch_states = get_states(idxs, labels, cls_num, weights)
        batch_metrics = compute_metrics(batch_states, cls_num)
        accum_states = batch_states + states
        accum_metrics = compute_metrics(accum_states, cls_num)

        self.attrs = {'class_number': cls_num}

        self.inputs = {
            'MaxProbs': max_probs,
            'Indices': idxs,
            'Labels': labels,
            'Weights': weights,
            'StatesInfo': states
        }

        self.outputs = {
            'BatchMetrics': batch_metrics,
            'AccumMetrics': accum_metrics,
            'AccumStatesInfo': accum_states
        }

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()