evaluator.py 4.4 KB
Newer Older
D
Dong Zhihong 已提交
1
from paddle.v2.framework.framework import Program, g_program, unique_name
D
Dong Zhihong 已提交
2
from paddle.v2.framework.layer_helper import LayerHelper
武毅 已提交
3 4 5
import paddle.v2.framework.core as core


D
Dong Zhihong 已提交
6 7 8
class Evaluator(object):
    """
    Evaluator base class.

    A subclass builds its evaluation graph in the constructor:

    1. create persistable metric state variables and register them in
       ``self._states`` (state name -> Variable);
    2. append the operator(s) that compute the metric for one mini-batch;
    3. append operators that accumulate the mini-batch results into the
       states.

    ``reset`` zeroes every registered state; ``eval`` merges the
    accumulated states into the final metric.
    """

    def __init__(self, name, **kwargs):
        """
        Args:
            name: layer type handed to the underlying ``LayerHelper``.
            **kwargs: forwarded unchanged to ``LayerHelper``.
        """
        # Maps a state name (e.g. "Total") to its persistable Variable.
        self._states = {}
        self._helper = LayerHelper(layer_type=name, **kwargs)

    def reset(self, executor, program=None):
        """
        Clear the metric states, e.g. at the beginning of each pass or after
        a user-specified number of batches.

        Args:
            executor: executor used to run the reset program.
            program: program the zeroing operators are appended to; a fresh
                ``Program`` is created when None.
        """
        reset_program = Program() if program is None else program
        # The zeroing ops must be built inside reset_program. Appending them
        # through self._helper would put them into the training program
        # (clearing the states on every later training step) while the
        # reset_program executed below stayed empty.
        helper = LayerHelper(layer_type="evaluator_reset", program=reset_program)
        for var in self._states.values():
            zeros = helper.create_tmp_variable(dtype=var.data_type)
            # NOTE(review): fill_constant's own dtype attribute is not set
            # here; confirm its default matches var.data_type.
            helper.append_op(
                type="fill_constant",
                outputs={"Out": [zeros]},
                attrs={
                    "shape": var.shape,
                    "value": 0,
                })
            # Copy the zeros into the state variable.
            helper.append_op(
                type="scale", inputs={"X": [zeros]}, outputs={"Out": [var]})
        executor.run(reset_program)

    def eval(self, executor=None, program=None):
        """
        Merge the accumulated mini-batch statistics into the evaluation
        result over multiple mini-batches.

        Subclasses must override this. The optional parameters mirror the
        ``(executor, program=None)`` signature that subclasses implement.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
D
Dong Zhihong 已提交
55

D
Dong Zhihong 已提交
56 57

class Accuracy(Evaluator):
    """
    Top-k accuracy evaluator.

    Keeps two persistable int64 states, ``Total`` and ``Correct``, and after
    every mini-batch adds the ``Total``/``Correct`` outputs of the
    ``accuracy`` operator into them.
    """

    def __init__(self, input, label, k=1, **kwargs):
        """
        Build the accuracy graph.

        Args:
            input: prediction Variable fed to the ``top_k`` operator.
            label: label Variable fed to the ``accuracy`` operator.
            k: ``k`` attribute of the ``top_k`` operator.
            **kwargs: forwarded to ``Evaluator``; the optional ``out_dtype``
                key sets the dtype of the per-batch accuracy variable
                (default "float32").

        The per-batch accuracy Variable is exposed as ``self.batch_accuracy``
        (``__init__`` cannot return it: Python raises TypeError when
        ``__init__`` returns anything other than None).
        """
        super(Accuracy, self).__init__("accuracy", **kwargs)
        helper = self._helper
        g_total = helper.create_global_variable(
            name=unique_name("Total"),
            persistable=True,
            dtype="int64",
            shape=[1])
        g_correct = helper.create_global_variable(
            name=unique_name("Correct"),
            persistable=True,
            dtype="int64",
            shape=[1])
        self._states["Total"] = g_total
        self._states["Correct"] = g_correct

        topk_out = helper.create_tmp_variable(dtype=input.data_type)
        topk_indices = helper.create_tmp_variable(dtype="int64")
        helper.append_op(
            type="top_k",
            inputs={"X": [input]},
            outputs={"Out": [topk_out],
                     "Indices": [topk_indices]},
            attrs={"k": k})

        acc_out_dtype = kwargs.get("out_dtype", "float32")
        acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
        # Per-batch counts, declared int64 to match the states they are
        # summed into. NOTE(review): confirm the accuracy kernel writes
        # this dtype.
        correct = helper.create_tmp_variable(dtype="int64")
        total = helper.create_tmp_variable(dtype="int64")
        helper.append_op(
            type="accuracy",
            inputs={
                "Out": [topk_out],
                "Indices": [topk_indices],
                "Label": [label]
            },
            outputs={
                "Accuracy": [acc_out],
                "Correct": [correct],
                "Total": [total],
            })

        # Accumulate the batch counts into the persistable states in place.
        helper.append_op(
            type="sum",
            inputs={"X": [g_total, total]},
            outputs={"Out": [g_total]})
        helper.append_op(
            type="sum",
            inputs={"X": [g_correct, correct]},
            outputs={"Out": [g_correct]})

        self.batch_accuracy = acc_out

    def eval(self, executor, program=None):
        """
        Compute the accumulated accuracy, Correct / Total.

        Args:
            executor: executor used to run the evaluation program.
            program: program the division operator is appended to; a fresh
                ``Program`` is created when None.

        Returns:
            The result of ``executor.run`` fetching the single ratio
            variable.
        """
        eval_program = Program() if program is None else program
        # Build the op inside eval_program itself: going through
        # self._helper would add it to the training program and leave
        # eval_program empty.
        helper = LayerHelper(layer_type="accuracy_eval", program=eval_program)
        eval_out = helper.create_tmp_variable(dtype="float32")
        # NOTE(review): both states are int64; confirm elementwise_div
        # produces a fractional result for them rather than an integer one.
        helper.append_op(
            type="elementwise_div",
            inputs={"X": [self._states["Correct"]],
                    "Y": [self._states["Total"]]},
            outputs={"Out": [eval_out]})
        return executor.run(eval_program, fetch_list=[eval_out])
D
Dong Zhihong 已提交
122 123


D
Dong Zhihong 已提交
124
# Demo for composing low level op to compute the F1 metric
D
Dong Zhihong 已提交
125 126 127
class F1(Evaluator):
    def __init__(self, input, label, **kwargs):
        super(F1, self).__init__("F1", **kwargs)
D
Dong Zhihong 已提交
128 129 130 131 132 133 134
        g_tp = helper.create_global_variable(
            name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
        g_fp = helper.create_global_variable(
            name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])

        self._states["Tp"] = g_tp
        self._states["Fp"] = g_fp