#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

D
dzhwinter 已提交
15
import warnings
D
Dong Zhihong 已提交
16
import numpy as np
武毅 已提交
17

18
import paddle
19 20 21 22 23
from . import layers
from .framework import Program, Variable, program_guard
from . import unique_name
from .layer_helper import LayerHelper
from .initializer import Constant
24
from .layers import detection
武毅 已提交
25

Y
Yu Yang 已提交
26 27

def _clone_var_(block, var):
    """Create a persistable variable in ``block`` that mirrors ``var``.

    The new variable copies ``var``'s name, shape, dtype, type and LoD level,
    and is always created with ``persistable=True`` regardless of whether
    ``var`` itself is persistable.

    Args:
        block: The block that will own the new variable; it must provide a
            ``create_var`` method.
        var(Variable): The variable whose metadata is copied.

    Returns:
        The object returned by ``block.create_var``.

    Raises:
        AssertionError: If ``var`` is not a ``Variable``.
    """
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True,
    )
D
Dong Zhihong 已提交
37 38


39
class Evaluator:
    """
    Base class for all evaluators.

    Warning: prefer the ``fluid.metrics.*`` classes. They offer more flexible
    support via pure Python and Operators and are decoupled from the
    executor. This short documentation is intended to steer new users
    towards Metrics.

    Args:
        name(str): The name of the evaluator, such as "accuracy". It is used
            to generate temporary variable names.
        **kwargs: Forwarded unchanged to ``LayerHelper``. The original
            documentation lists ``main_program`` (Program, optional; the
            program the evaluator is added to, default
            ``default_main_program()``) and ``startup_program`` (Program,
            optional; the program the parameters are added to, default
            ``default_startup_program()``) as the expected entries.
            NOTE(review): their handling lives in ``LayerHelper`` -- confirm
            there.

    Attributes:
        states(list): The list of state variables. The states are reset to
            zero when `reset` is invoked.
        metrics(list): The list of metric variables. They are meant to be
            calculated on every mini-batch.
        helper(LayerHelper): The helper used to create state variables.
    """

    def __init__(self, name, **kwargs):
        # Emitted on every construction with the generic ``Warning`` category;
        # the message points at the same-named class under fluid.metrics.
        warnings.warn(
            "The %s is deprecated, because maintain a modified program inside evaluator cause bug easily, please use fluid.metrics.%s instead."
            % (self.__class__.__name__, self.__class__.__name__),
            Warning,
        )
        self.states = []
        self.metrics = []
        self.helper = LayerHelper(name, **kwargs)

    def reset(self, executor, reset_program=None):
        """
        Reset the metric states at the beginning of each pass or of a
        user-specified batch.

        For every variable in ``self.states`` a persistable clone is created
        in ``reset_program`` and filled with the constant 0.0; the program is
        then run with ``executor``.

        Args:
            executor(Executor|ParallelExecutor): the executor used to run
                ``reset_program``.
            reset_program(Program, optional): the Program that receives the
                reset operations. A new ``Program`` is created when this is
                None.

        Raises:
            AssertionError: If an entry of ``self.states`` is not a
                ``Variable``.
        """
        if reset_program is None:
            reset_program = Program()

        with program_guard(main_program=reset_program):
            for var in self.states:
                assert isinstance(var, Variable)
                # Clone under the same name so that the fill writes to the
                # state variable itself rather than to a fresh temporary.
                g_var = _clone_var_(reset_program.current_block(), var)
                layers.fill_constant(
                    shape=g_var.shape, value=0.0, dtype=g_var.dtype, out=g_var
                )

        executor.run(reset_program)

    def eval(self, executor, eval_program=None):
        """
        Evaluate the statistics merged over multiple mini-batches.

        The base class provides no implementation; subclasses must override
        this method.

        Args:
            executor(Executor|ParallelExecutor): the executor used to run
                ``eval_program``.
            eval_program(Program, optional): the Program used for the
                evaluation process.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def _create_state(self, suffix, dtype, shape):
        """
        Create a persistable state variable and register it in ``self.states``.

        The variable name is a unique name generated from the helper's name,
        joined to ``suffix`` with an underscore.

        Args:
            suffix(str): the state suffix.
            dtype(str|core.VarDesc.VarType): the state data type.
            shape(tuple|list): the shape of the state.

        Returns:
            The newly created state variable.
        """
        state = self.helper.create_variable(
            name="_".join([unique_name.generate(self.helper.name), suffix]),
            persistable=True,
            dtype=dtype,
            shape=shape,
        )
        self.states.append(state)
        return state