import numpy as np
from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
import paddle.v2.fluid.core as core


def _clone_var_in_block_(block, var):
    """Declare a persistable variable in ``block`` that mirrors ``var``.

    The new variable copies the name, shape, dtype, type and LoD level of
    ``var`` and is always created with ``persistable=True``. Because it shares
    the name of ``var``, a freshly built Program can refer to the same state
    by name (presumably resolving to the same storage at run time -- confirm
    against the executor's scope handling).

    Args:
        block: block in which the new variable is declared.
        var: the ``Variable`` to mirror; anything else fails the assertion.

    Returns:
        Whatever ``block.create_var`` returns for the new declaration.
    """
    assert isinstance(var, Variable)
    mirrored = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
    )
    return block.create_var(persistable=True, **mirrored)


class Evaluator(object):
    """
    Evaluator base class.

    A subclass is expected to:
      * create the metric state variables (kept in ``self._states``),
      * append the operators that calculate the mini-batch metric,
      * append the operators that accumulate those results into the states.
    """

    def __init__(self, name, **kwargs):
        """
        Initialize the (empty) metric states.

        Args:
            name: evaluator name; not used by the base class itself.
            **kwargs: ``main_program`` selects the Program that subclasses add
                their state variables and ops to; it falls back to
                ``g_main_program``. Other keys are ignored here.
        """
        self._states = {}
        # Membership test instead of dict.has_key(), which Python 3 removed.
        if "main_program" in kwargs:
            self._main_program = kwargs["main_program"]
        else:
            self._main_program = g_main_program

    def states(self):
        """Return the dict mapping state names to their state Variables."""
        return self._states

    def _update_ops(self, *args, **kwargs):
        """
        Append the ops that update the metric states. Must be overridden.
        """
        raise NotImplementedError()

    def reset(self, executor, reset_program=None):
        """
        Clear the metric states, e.g. at the beginning of each pass or after
        a user-specified number of batches.

        For every state this appends a ``fill_constant`` op producing zeros
        and a ``scale`` op writing them into a clone of the state, then runs
        the program with ``executor``.

        Args:
            executor: executor used to run the reset program.
            reset_program: Program to append the reset ops to; a fresh
                ``Program()`` is created when it is None.
        """
        # The former ``else: reset_program = program`` branch referenced an
        # undefined name, so supplying a program always raised NameError.
        if reset_program is None:
            reset_program = Program()
        block = reset_program.global_block()
        for var in self._states.values():
            g_var = _clone_var_in_block_(block, var)
            zeros = block.create_var(dtype="float32", persistable=True)
            block.append_op(
                type="fill_constant",
                outputs={"Out": [zeros]},
                attrs={
                    "shape": g_var.shape,
                    "value": .0,
                    "dtype": 5,  # float32, matching `zeros`
                })
            block.append_op(
                type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
        # Materialize the values so the executor gets a real list on Python 3.
        executor.run(reset_program, fetch_list=list(self._states.values()))

    def eval(self, executor, eval_program=None):
        """
        Merge the mini-batch statistics to form the evaluation result over
        multiple mini-batches. Must be overridden.
        """
        raise NotImplementedError()

class Accuracy(Evaluator):
    """
    Top-k accuracy accumulated over mini-batches.

    Accuracy needs two state variables, ``Total`` and ``Correct``, both
    declared persistable with dtype int64 and shape [1].
    """

    def __init__(self, *args, **kwargs):
        """
        Create the ``Total`` and ``Correct`` states in the global block of
        the main program.

        Positional arguments are accepted and ignored so that the same
        arguments can be forwarded to both this constructor and
        ``_update_ops``. ``main_program`` in ``kwargs`` is read by the base
        class.
        """
        super(Accuracy, self).__init__("accuracy", **kwargs)
        block = self._main_program.global_block()
        g_total = block.create_var(
            name=unique_name("Total"),
            persistable=True,
            dtype="int64",
            shape=[1])
        g_correct = block.create_var(
            name=unique_name("Correct"),
            persistable=True,
            dtype="int64",
            shape=[1])
        self._states["Total"] = g_total
        self._states["Correct"] = g_correct

    def _update_ops(self, input, label, k=1, **kwargs):
        """
        Append ops computing the mini-batch accuracy and accumulating the
        mini-batch ``Correct``/``Total`` counts into the states.

        Args:
            input: prediction variable fed to the ``top_k`` op.
            label: ground-truth label variable fed to the ``accuracy`` op.
            k: number of top predictions passed to ``top_k``.
            **kwargs: ``out_dtype`` sets the dtype of the returned accuracy
                variable (default ``"float32"``); other keys are ignored.

        Returns:
            The variable holding the mini-batch accuracy.
        """
        block = self._main_program.global_block()
        topk_out = block.create_var(dtype=input.dtype)
        topk_indices = block.create_var(dtype="int64")
        block.append_op(
            type="top_k",
            inputs={"X": [input]},
            outputs={"Out": [topk_out],
                     "Indices": [topk_indices]},
            attrs={"k": k})
        acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32"))
        correct = block.create_var(dtype="int64", persistable=True)
        total = block.create_var(dtype="int64", persistable=True)
        block.append_op(
            type="accuracy",
            inputs={
                "Out": [topk_out],
                "Indices": [topk_indices],
                "Label": [label]
            },
            outputs={
                "Accuracy": [acc_out],
                "Correct": [correct],
                "Total": [total],
            })

        # Cast each state in place before accumulating.
        # NOTE(review): the states are declared int64 in __init__, yet these
        # casts are annotated float32 (5) -> int32 (2); confirm the intended
        # runtime dtypes.
        for state_name in ("Total", "Correct"):
            state = self._states[state_name]
            block.append_op(
                type="cast",
                inputs={"X": [state]},
                outputs={"Out": [state]},
                attrs={
                    "in_dtype": 5,  # float32
                    "out_dtype": 2,  # int32
                })

        # Add the mini-batch counts into the states (state += batch count).
        for state_name, batch_count in (("Total", total),
                                        ("Correct", correct)):
            state = self._states[state_name]
            block.append_op(
                type="elementwise_add",
                inputs={"X": [state],
                        "Y": [batch_count]},
                outputs={"Out": [state]})

        return acc_out

    def eval(self, executor, eval_program=None):
        """
        Compute ``Correct / Total`` from the accumulated states.

        Appends ops that cast clones of the two states (annotated
        int32 -> float32) and divide them, then runs the program.

        Args:
            executor: executor used to run the evaluation program.
            eval_program: Program to append the ops to; a fresh ``Program()``
                is created when it is None.

        Returns:
            ``np.array`` of the first fetched value, i.e. the division result.
        """
        if eval_program is None:
            eval_program = Program()
        block = eval_program.global_block()
        # NOTE(review): the quotient of float32-cast inputs is declared with
        # the state dtype (int64); confirm whether this should be float32.
        eval_out = block.create_var(dtype=self._states["Total"].dtype)
        e_total = _clone_var_in_block_(block, self._states["Total"])
        e_correct = _clone_var_in_block_(block, self._states["Correct"])
        for cloned in (e_total, e_correct):
            block.append_op(
                type="cast",
                inputs={"X": [cloned]},
                outputs={"Out": [cloned]},
                attrs={
                    "in_dtype": 2,  # int32
                    "out_dtype": 5,  # float32
                })
        block.append_op(
            type="elementwise_div",
            inputs={"X": e_correct,
                    "Y": e_total},
            outputs={"Out": eval_out})
        out = executor.run(eval_program, fetch_list=[eval_out])
        return np.array(out[0])


def accuracy(*args, **kwargs):
    """
    Build an ``Accuracy`` evaluator and append its update ops.

    The same positional and keyword arguments are forwarded to both the
    ``Accuracy`` constructor and ``Accuracy._update_ops``.

    Returns:
        A ``(evaluator, acc_out)`` tuple, where ``acc_out`` is the value
        returned by ``_update_ops``.
    """
    evaluator = Accuracy(*args, **kwargs)
    return evaluator, evaluator._update_ops(*args, **kwargs)