evaluator.py 4.4 KB
Newer Older
D
Dong Zhihong 已提交
1
import numpy as np
武毅 已提交
2

Y
Yu Yang 已提交
3 4 5 6
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program, unique_name, \
    Variable
from paddle.v2.fluid.layer_helper import LayerHelper
武毅 已提交
7

Y
Yu Yang 已提交
8 9 10 11
# Public names exported by ``from ... import *``: only ``Accuracy``.
__all__ = ['Accuracy']


def _clone_var_(block, var):
    """
    Declare a persistable copy of ``var`` in ``block``.

    The copy has the same name, shape, dtype, type and LoD level as ``var``
    and is always marked persistable.

    Args:
        block: The block whose ``create_var`` is called to declare the copy.
        var(Variable): The variable to copy.

    Returns:
        The variable returned by ``block.create_var``.
    """
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)


D
Dong Zhihong 已提交
22 23
class Evaluator(object):
    """
    Base class for all evaluators.

    Args:
        name(str): The name of the evaluator, such as "accuracy". It is used
            to generate temporary variable names.
        main_program(Program, optional): The program the evaluator's ops are
            added to. Default g_main_program.
        startup_program(Program, optional): The program parameters are added
            to. Default g_startup_program.

    ``main_program`` and ``startup_program`` are passed as keyword arguments
    and forwarded unchanged to ``LayerHelper``.

    Attributes:
        states(list): The state variables. They are reset to zero when
            `reset` is invoked.
        metrics(list): The metric variables. They are computed for every
            mini-batch.
    """

    def __init__(self, name, **kwargs):
        self.states = []
        self.metrics = []
        self.helper = LayerHelper(name, **kwargs)

    def reset(self, executor, reset_program=None):
        """
        Reset the metric states to zero.

        Call this at the beginning of each pass, or after a user-specified
        number of batches.

        Args:
            executor: Executor that runs the reset program.
            reset_program(Program, optional): Program to which the zero-fill
                ops are appended. A new ``Program()`` is created when omitted.
        """
        if reset_program is None:
            reset_program = Program()

        for var in self.states:
            assert isinstance(var, Variable)
            # Declare the state (same name, persistable) in the reset
            # program's block, then overwrite it with zeros through ``out``.
            g_var = _clone_var_(reset_program.current_block(), var)
            layers.fill_constant(
                shape=g_var.shape,
                value=0.0,
                dtype=g_var.dtype,
                out=g_var,
                main_program=reset_program)

        executor.run(reset_program)

    def eval(self, executor, eval_program=None):
        """
        Evaluate the statistics merged over multiple mini-batches.

        Subclasses must override this method.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def create_state(self, suffix, dtype, shape):
        """
        Create a persistable state variable and register it in ``self.states``.

        NOTE: It is not a public API.

        Args:
            suffix(str): The state suffix. It is joined with "_" to a unique
                name derived from the evaluator's name.
            dtype(str|core.DataType): The state data type.
            shape(tuple|list): The shape of the state.

        Returns:
            The created state variable.
        """
        state = self.helper.create_variable(
            name="_".join([unique_name(self.helper.name), suffix]),
            persistable=True,
            dtype=dtype,
            shape=shape)
        self.states.append(state)
        return state
D
Dong Zhihong 已提交
92

D
Dong Zhihong 已提交
93 94

class Accuracy(Evaluator):
    """
    Average accuracy over multiple mini-batches.

    Two int64 state variables, ``total`` and ``correct``, accumulate the
    per-batch counts produced by ``layers.accuracy``; ``eval`` divides them.
    The variable returned by ``layers.accuracy`` is appended to
    ``self.metrics``.

    Args:
        input: Prediction variable passed to ``layers.accuracy``.
        label: Label variable passed to ``layers.accuracy``.
        k(int): Forwarded as ``k`` to ``layers.accuracy`` (presumably top-k;
            confirm against that layer). Default 1.
        **kwargs: Forwarded to ``Evaluator.__init__`` and thus to
            ``LayerHelper`` (e.g. ``main_program``, ``startup_program``).

    Raises:
        ValueError: If the main program's current block is not the root
            block (index 0).
    """

    def __init__(self, input, label, k=1, **kwargs):
        super(Accuracy, self).__init__("accuracy", **kwargs)
        main_program = self.helper.main_program
        if main_program.current_block().idx != 0:
            raise ValueError("You can only invoke Evaluator in root block")

        self.total = self.create_state(dtype='int64', shape=[1], suffix='total')
        self.correct = self.create_state(
            dtype='int64', shape=[1], suffix='correct')

        # A dedicated name instead of rebinding the ``kwargs`` parameter:
        # every op below must be appended to the evaluator's main program.
        program_kwargs = {'main_program': main_program}
        total = self.helper.create_tmp_variable(dtype='int')
        correct = self.helper.create_tmp_variable(dtype='int')
        acc = layers.accuracy(
            input=input,
            label=label,
            k=k,
            total=total,
            correct=correct,
            **program_kwargs)
        # Cast the per-batch counts to the states' int64 dtype, then add them
        # into the states in place (``out`` is the state itself).
        total = layers.cast(x=total, dtype='int64', **program_kwargs)
        correct = layers.cast(x=correct, dtype='int64', **program_kwargs)
        layers.sums(
            input=[self.total, total], out=self.total, **program_kwargs)
        layers.sums(
            input=[self.correct, correct], out=self.correct, **program_kwargs)

        self.metrics.append(acc)

    def eval(self, executor, eval_program=None):
        """
        Compute the accumulated accuracy as ``correct / total``.

        Args:
            executor: Executor used to run the evaluation program.
            eval_program(Program, optional): Program to which the evaluation
                ops are appended. A new ``Program()`` is created when omitted.

        Returns:
            numpy.ndarray: The first fetched result, i.e. the ratio.

        NOTE(review): if nothing was accumulated since the last reset,
        ``total`` is zero and the float32 division presumably yields NaN or
        inf -- confirm against the elementwise_div kernel.
        """
        if eval_program is None:
            eval_program = Program()
        block = eval_program.current_block()
        program_kwargs = {'main_program': eval_program}
        # Re-declare the states (same names, persistable) in the evaluation
        # program's block; presumably the executor resolves them to the
        # values accumulated by the main program.
        total = _clone_var_(block, self.total)
        correct = _clone_var_(block, self.correct)
        total = layers.cast(x=total, dtype='float32', **program_kwargs)
        correct = layers.cast(x=correct, dtype='float32', **program_kwargs)
        out = layers.elementwise_div(x=correct, y=total, **program_kwargs)
        return np.array(executor.run(eval_program, fetch_list=[out])[0])