evaluator.py 6.8 KB
Newer Older
D
Dong Zhihong 已提交
1
import numpy as np
武毅 已提交
2

3
import layers
4
from framework import Program, unique_name, Variable, program_guard
5
from layer_helper import LayerHelper
武毅 已提交
6

7 8 9 10
# Public API of this module: the evaluators exposed by a star-import.
__all__ = ['Accuracy', 'ChunkEvaluator']
Y
Yu Yang 已提交
11 12 13


def _clone_var_(block, var):
    """
    Declare a persistable variable in ``block`` mirroring ``var``.

    The new variable copies the name, shape, dtype, type and lod_level of
    ``var`` and is always created with ``persistable=True``, so it refers to
    the same (persistable) storage when used from another program.

    Args:
        block: The block in which to create the variable.
        var(Variable): The variable whose description is copied.

    Returns:
        The variable returned by ``block.create_var``.
    """
    assert isinstance(var, Variable)
    clone_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**clone_attrs)


D
Dong Zhihong 已提交
24 25
class Evaluator(object):
    """
    Base class for all evaluators.

    An evaluator keeps persistable state variables that accumulate across
    mini-batches, plus a list of metric variables computed per mini-batch.

    Args:
        name(str): The name of the evaluator, such as "accuracy". Used to
            generate temporary variable names.
        main_program(Program, optional): The program the evaluator is added
            to. Default: default_main_program().
        startup_program(Program, optional): The program parameters are added
            to. Default: default_startup_program().

        Note that main_program / startup_program are not named parameters of
        ``__init__``; they travel through ``**kwargs`` to ``LayerHelper``.

    Attributes:
        states(list): The state variables. Every state is reset to zero when
            `reset` is invoked.
        metrics(list): The metric variables, calculated for every mini-batch.
    """

    def __init__(self, name, **kwargs):
        self.states = []
        self.metrics = []
        self.helper = LayerHelper(name, **kwargs)

    def reset(self, executor, reset_program=None):
        """
        Reset every state variable to zero, e.g. at the beginning of each
        pass or after a user-specified number of batches.

        Args:
            executor: Used to run the program holding the fill ops.
            reset_program(Program, optional): Program to which the zero-fill
                ops are appended. A fresh Program is used when None. A
                program passed here gains new ops on every call.
        """
        program = Program() if reset_program is None else reset_program

        with program_guard(main_program=program):
            for state in self.states:
                assert isinstance(state, Variable)
                # Mirror the state into the reset program, then overwrite it.
                mirrored = _clone_var_(program.current_block(), state)
                layers.fill_constant(
                    shape=mirrored.shape,
                    value=0.0,
                    dtype=mirrored.dtype,
                    out=mirrored)

        executor.run(program)

    def eval(self, executor, eval_program=None):
        """
        Evaluate the statistics merged across multiple mini-batches.

        Subclasses must override this; the base class always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def create_state(self, suffix, dtype, shape):
        """
        Create a persistable state variable and register it in `states`.

        NOTE: It is not a public API.

        Args:
            suffix(str): Appended to a unique evaluator-based name.
            dtype(str|core.DataType): The data type of the state.
            shape(tuple|list): The shape of the state.

        Returns:
            The newly created state variable.
        """
        state_name = "_".join((unique_name(self.helper.name), suffix))
        new_state = self.helper.create_variable(
            name=state_name, persistable=True, dtype=dtype, shape=shape)
        self.states.append(new_state)
        return new_state
D
Dong Zhihong 已提交
91

D
Dong Zhihong 已提交
92 93

class Accuracy(Evaluator):
    """
    Average accuracy accumulated over multiple mini-batches.

    Args:
        input(Variable): Prediction passed to ``layers.accuracy``.
        label(Variable): Labels passed to ``layers.accuracy``.
        k(int): The ``k`` argument of ``layers.accuracy``. Default 1.
        **kwargs: Forwarded to ``Evaluator`` (and from there to LayerHelper).

    Raises:
        ValueError: If not constructed in the root block of the main program.
    """

    def __init__(self, input, label, k=1, **kwargs):
        super(Accuracy, self).__init__("accuracy", **kwargs)
        # The accumulators must be created in the root block.
        if self.helper.main_program.current_block().idx != 0:
            raise ValueError("You can only invoke Evaluator in root block")

        # Running totals across mini-batches (int64, shape [1]).
        self.total = self.create_state(dtype='int64', shape=[1], suffix='total')
        self.correct = self.create_state(
            dtype='int64', shape=[1], suffix='correct')

        # Per-batch counters written by the accuracy op.
        batch_total = self.helper.create_tmp_variable(dtype='int')
        batch_correct = self.helper.create_tmp_variable(dtype='int')
        batch_acc = layers.accuracy(
            input=input,
            label=label,
            k=k,
            total=batch_total,
            correct=batch_correct)

        # Cast both counters to the accumulators' dtype, then add each one
        # into its running total in place.
        increments = [
            layers.cast(x=counter, dtype='int64')
            for counter in (batch_total, batch_correct)
        ]
        for state, increment in zip((self.total, self.correct), increments):
            layers.sums(input=[state, increment], out=state)

        self.metrics.append(batch_acc)

    def eval(self, executor, eval_program=None):
        """
        Compute accumulated accuracy as correct / total.

        Args:
            executor: Runs the evaluation program.
            eval_program(Program, optional): Program to build the division
                in; a fresh Program is used when None.

        Returns:
            numpy.ndarray: The fetched float32 ratio. The division is not
            guarded here; if total is 0 (nothing accumulated since the last
            reset) the result is presumably NaN/inf — TODO confirm.
        """
        program = Program() if eval_program is None else eval_program
        target_block = program.current_block()
        with program_guard(main_program=program):
            mirrored = [
                _clone_var_(target_block, state)
                for state in (self.total, self.correct)
            ]
            total_f32, correct_f32 = [
                layers.cast(x=var, dtype='float32') for var in mirrored
            ]
            ratio = layers.elementwise_div(x=correct_f32, y=total_f32)
        fetched = executor.run(program, fetch_list=[ratio])
        return np.array(fetched[0])
G
guosheng 已提交
129 130 131 132


class ChunkEvaluator(Evaluator):
    """
    Accumulate the counters output by chunk_eval over mini-batches and
    compute precision, recall and F1-score from the accumulated counters.

    Args:
        input(Variable): Prediction passed to ``layers.chunk_eval``.
        label(Variable): Labels passed to ``layers.chunk_eval``.
        chunk_scheme: Passed through to ``layers.chunk_eval``.
        num_chunk_types: Passed through to ``layers.chunk_eval``.
        excluded_chunk_types(optional): Passed through to
            ``layers.chunk_eval``. Default None.
        **kwargs: Forwarded to ``Evaluator`` (and from there to LayerHelper),
            e.g. ``main_program`` / ``startup_program``, as ``Accuracy``
            already allows.

    Raises:
        ValueError: If not constructed in the root block of the main program.
    """

    def __init__(self,
                 input,
                 label,
                 chunk_scheme,
                 num_chunk_types,
                 excluded_chunk_types=None,
                 **kwargs):
        # Forward **kwargs so callers can choose the target programs; before,
        # they were silently impossible to pass here.
        super(ChunkEvaluator, self).__init__("chunk_eval", **kwargs)
        main_program = self.helper.main_program
        if main_program.current_block().idx != 0:
            raise ValueError("You can only invoke Evaluator in root block")

        # Persistable int64 accumulators; zeroed by `reset`.
        self.num_infer_chunks = self.create_state(
            dtype='int64', shape=[1], suffix='num_infer_chunks')
        self.num_label_chunks = self.create_state(
            dtype='int64', shape=[1], suffix='num_label_chunks')
        self.num_correct_chunks = self.create_state(
            dtype='int64', shape=[1], suffix='num_correct_chunks')
        (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
         num_correct_chunks) = layers.chunk_eval(
             input=input,
             label=label,
             chunk_scheme=chunk_scheme,
             num_chunk_types=num_chunk_types,
             excluded_chunk_types=excluded_chunk_types)

        # Add this mini-batch's counters into the accumulators in place.
        layers.sums(
            input=[self.num_infer_chunks, num_infer_chunks],
            out=self.num_infer_chunks)
        layers.sums(
            input=[self.num_label_chunks, num_label_chunks],
            out=self.num_label_chunks)
        layers.sums(
            input=[self.num_correct_chunks, num_correct_chunks],
            out=self.num_correct_chunks)

        self.metrics.extend([precision, recall, f1_score])

    @staticmethod
    def _compute_metrics(num_infer_chunks, num_label_chunks,
                         num_correct_chunks):
        """Return (precision, recall, f1_score), using 0 where undefined."""
        precision = (float(num_correct_chunks) / num_infer_chunks
                     if num_infer_chunks else 0)
        recall = (float(num_correct_chunks) / num_label_chunks
                  if num_label_chunks else 0)
        # Guard on the actual denominator: identical to guarding on
        # num_correct_chunks for consistent counters, but cannot raise
        # ZeroDivisionError if the counters are inconsistent.
        f1_score = (float(2 * precision * recall) / (precision + recall)
                    if precision + recall else 0)
        return precision, recall, f1_score

    def eval(self, executor, eval_program=None):
        """
        Compute precision, recall and F1 from the accumulated counters.

        Args:
            executor: Runs ``eval_program`` to fetch the accumulators.
            eval_program(Program, optional): Program the accumulators are
                mirrored into; a fresh Program is used when None.

        Returns:
            tuple: (precision, recall, f1_score), each a float32 numpy array
            holding one value.
        """
        if eval_program is None:
            eval_program = Program()
        block = eval_program.current_block()
        # Fetch the named accumulators explicitly rather than relying on the
        # insertion order of self.states.
        accumulators = [
            self.num_infer_chunks, self.num_label_chunks,
            self.num_correct_chunks
        ]
        num_infer_chunks, num_label_chunks, num_correct_chunks = executor.run(
            eval_program,
            fetch_list=[_clone_var_(block, state) for state in accumulators])
        # Each state was created with shape [1]; take its single element.
        precision, recall, f1_score = self._compute_metrics(
            num_infer_chunks[0], num_label_chunks[0], num_correct_chunks[0])
        return (np.array([precision], dtype='float32'),
                np.array([recall], dtype='float32'),
                np.array([f1_score], dtype='float32'))