evaluator.py 6.0 KB
Newer Older
D
Dong Zhihong 已提交
1
import numpy as np
2
from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
Q
Qiao Longfei 已提交
3
import paddle.v2.fluid.core as core
武毅 已提交
4 5


D
Dong Zhihong 已提交
6 7 8 9 10 11 12 13 14 15 16
def _clone_var_in_block_(block, var):
    """
    Declare, in ``block``, a persistable variable mirroring ``var``.

    The new declaration copies ``var``'s name, shape, data type, variable
    type and LoD level, and is always marked persistable regardless of
    whether ``var`` itself is.

    Args:
        block: Block that receives the new variable declaration.
        var: Source ``Variable`` to mirror.

    Returns:
        Whatever ``block.create_var`` returns for the new declaration.
    """
    assert isinstance(var, Variable)
    clone_kwargs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.data_type,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True,
    )
    return block.create_var(**clone_kwargs)


D
Dong Zhihong 已提交
17 18 19
class Evaluator(object):
    """
    Base class for evaluators.

    An evaluator is expected to:
      * create metric state variables (kept in ``self._states``),
      * add the operators that compute the metric for one mini-batch, and
      * add the operators that accumulate that result into the states.
    """

    def __init__(self, name, **kwargs):
        """
        Initialize the (empty) table of metric states.

        Args:
            name: Evaluator name; not used by this base class.
            main_program: Optional keyword argument giving the program the
                states and update operators belong to. Falls back to
                ``g_main_program`` when missing or ``None``.
        """
        self._states = {}
        main_program = kwargs.get("main_program")
        # Treat an explicit ``main_program=None`` like "not given" instead of
        # storing None and failing later on ``.global_block()``.
        self._main_program = (g_main_program
                              if main_program is None else main_program)

    def _update_ops(self, *args, **kwargs):
        """
        Append the operators that update the metric states.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def reset(self, executor, reset_program=None):
        """
        Zero the metric states, e.g. at the start of each pass or after a
        user-chosen number of batches.

        Args:
            executor: Executor that runs the reset program.
            reset_program: Program to append the reset operators to. A new
                ``Program`` is created when it is ``None``.
        """
        if reset_program is None:
            reset_program = Program()
        block = reset_program.global_block()
        for var in self._states.values():
            # The clone reuses the state's name and is persistable, so it
            # presumably refers to the same scope variable as the state.
            g_var = _clone_var_in_block_(block, var)
            zeros = block.create_var(dtype="float32", persistable=True)
            block.append_op(
                type="fill_constant",
                outputs={"Out": [zeros]},
                attrs={
                    "shape": g_var.shape,
                    "value": .0,
                    "data_type": 5,  # float32
                })
            # NOTE(review): this copies a float32 zero tensor into the state
            # variable, whose declared dtype may differ (Accuracy declares
            # int64 states) -- confirm the scale op handles the conversion.
            block.append_op(
                type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
        executor.run(reset_program, fetch_list=list(self._states.values()))

    def eval(self, executor, eval_program=None):
        """
        Merge the accumulated mini-batch statistics into the final metric.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()
D
Dong Zhihong 已提交
71

D
Dong Zhihong 已提交
72 73

class Accuracy(Evaluator):
    """
    Accuracy accumulated over multiple mini-batches.

    Keeps two int64 state variables, ``Total`` and ``Correct``: the number of
    samples seen and the number whose label is among the top-k predictions.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the ``Total`` and ``Correct`` state variables.

        Positional arguments are accepted (so ``accuracy()`` can forward its
        arguments unchanged) but ignored; keyword arguments are passed on to
        ``Evaluator.__init__``.
        """
        super(Accuracy, self).__init__("accuracy", **kwargs)
        block = self._main_program.global_block()
        for state_name in ("Total", "Correct"):
            self._states[state_name] = block.create_var(
                name=unique_name(state_name),
                persistable=True,
                dtype="int64",
                shape=[1])

    def _update_ops(self, input, label, k=1, **kwargs):
        """
        Append the per-batch accuracy ops and the state accumulation ops.

        Args:
            input: Prediction variable fed to ``top_k``.
            label: Ground-truth label variable.
            k: Number of top predictions checked against the label.
            out_dtype: Optional keyword; dtype of the returned accuracy
                variable (default ``"float32"``).

        Returns:
            The per-mini-batch accuracy variable.
        """
        block = self._main_program.global_block()
        topk_out = block.create_var(dtype=input.data_type)
        topk_indices = block.create_var(dtype="int64")
        block.append_op(
            type="top_k",
            inputs={"X": [input]},
            outputs={"Out": [topk_out],
                     "Indices": [topk_indices]},
            attrs={"k": k})
        acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32"))
        # NOTE(review): assumes the accuracy op writes its Correct/Total
        # counts as int32 -- confirm against the accuracy kernel.
        correct = block.create_var(dtype="int32", persistable=True)
        total = block.create_var(dtype="int32", persistable=True)
        block.append_op(
            type="accuracy",
            inputs={
                "Out": [topk_out],
                "Indices": [topk_indices],
                "Label": [label]
            },
            outputs={
                "Accuracy": [acc_out],
                "Correct": [correct],
                "Total": [total],
            })

        # Widen each batch count to the states' int64 type and add it to the
        # matching state. The cast writes into a fresh variable: casting the
        # int64 states in place (as before) rewrote the accumulators with a
        # different element type.
        for state_name, batch_count in (("Total", total),
                                        ("Correct", correct)):
            state = self._states[state_name]
            batch_count_64 = block.create_var(dtype="int64")
            block.append_op(
                type="cast",
                inputs={"X": [batch_count]},
                outputs={"Out": [batch_count_64]},
                attrs={
                    "in_data_type": 2,  # int32
                    "out_data_type": 3,  # int64
                })
            block.append_op(
                type="elementwise_add",
                inputs={"X": [state],
                        "Y": [batch_count_64]},
                outputs={"Out": [state]})

        return acc_out

    def eval(self, executor, eval_program=None):
        """
        Compute ``Correct / Total`` from the accumulated states.

        Args:
            executor: Executor that runs the evaluation program.
            eval_program: Program to append the evaluation ops to. A new
                ``Program`` is created when it is ``None``.

        Returns:
            ``numpy.ndarray`` holding the accumulated accuracy.
        """
        if eval_program is None:
            eval_program = Program()
        block = eval_program.global_block()
        # The division runs on float32 copies, so the result is float32
        # rather than the states' int64 type.
        eval_out = block.create_var(dtype="float32")
        e_total = _clone_var_in_block_(block, self._states["Total"])
        e_correct = _clone_var_in_block_(block, self._states["Correct"])
        # The clones share the states' names, so an in-place cast would
        # presumably convert the persistent accumulators themselves; cast
        # into fresh float32 variables instead.
        f_total = block.create_var(dtype="float32")
        f_correct = block.create_var(dtype="float32")
        for src, dst in ((e_total, f_total), (e_correct, f_correct)):
            block.append_op(
                type="cast",
                inputs={"X": [src]},
                outputs={"Out": [dst]},
                attrs={
                    "in_data_type": 3,  # int64
                    "out_data_type": 5,  # float32
                })
        block.append_op(
            type="elementwise_div",
            inputs={"X": [f_correct],
                    "Y": [f_total]},
            outputs={"Out": [eval_out]})
        out = executor.run(eval_program, fetch_list=[eval_out])
        return np.array(out[0])
D
Dong Zhihong 已提交
182 183


D
Dong Zhihong 已提交
184 185 186 187
def accuracy(*args, **kwargs):
    """
    Build an ``Accuracy`` evaluator and append its update operators.

    All positional and keyword arguments are forwarded both to
    ``Accuracy.__init__`` and to ``Accuracy._update_ops``.

    Returns:
        A ``(evaluator, accuracy_var)`` tuple, where ``accuracy_var`` is the
        per-mini-batch accuracy variable returned by ``_update_ops``.
    """
    evaluator = Accuracy(*args, **kwargs)
    return evaluator, evaluator._update_ops(*args, **kwargs)