#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Testing and training events.

There are:

* TestResult
* BeginIteration
* EndForwardBackward
* EndIteration
* BeginPass
* EndPass
"""
Y
Yu Yang 已提交
25
# Names exported by ``from <this module> import *``. The WithMetric base
# class is deliberately absent from this list, matching the original.
__all__ = [
    'EndIteration',
    'BeginIteration',
    'BeginPass',
    'EndPass',
    'TestResult',
    'EndForwardBackward',
]
Y
Yu Yang 已提交
29 30


Y
Yu Yang 已提交
31 32
class WithMetric(object):
    """
    Base class for events that expose the values of an evaluator.

    The evaluator is type-checked on construction and its named values are
    made available through the read-only ``metrics`` property.
    """

    def __init__(self, evaluator):
        """
        :param evaluator: evaluator whose named values back ``metrics``.
        :type evaluator: py_paddle.swig_paddle.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
        """
        # Function-scope import: the SWIG bindings are only looked up when an
        # event is actually constructed, not when this module is imported.
        import py_paddle.swig_paddle as api
        if not isinstance(evaluator, api.Evaluator):
            raise TypeError("Evaluator should be api.Evaluator type")
        # The name ends with two underscores, so Python does not apply
        # private-name mangling; the attribute is literally ``__evaluator__``.
        self.__evaluator__ = evaluator

    @property
    def metrics(self):
        """
        Build a fresh dict mapping each evaluator name to its current value.

        :return: ``{name: evaluator.getValue(name)}`` for every name returned
                 by ``evaluator.getNames()``.
        :rtype: dict
        """
        evaluator = self.__evaluator__
        return {name: evaluator.getValue(name) for name in evaluator.getNames()}


Y
Yu Yang 已提交
48
class TestResult(WithMetric):
    """
    Result returned by trainer.test.
    """

    def __init__(self, evaluator, cost):
        """
        :param evaluator: evaluator forwarded to ``WithMetric.__init__``,
                          which type-checks it and backs ``metrics``.
        :param cost: value stored unchanged as ``self.cost``.
        """
        super(TestResult, self).__init__(evaluator)
        self.cost = cost
Y
Yu Yang 已提交
56 57


Y
Yu Yang 已提交
58 59 60 61 62 63 64 65 66 67 68 69
class BeginPass(object):
    """
    Event fired when one training pass starts.
    """

    def __init__(self, pass_id):
        """
        :param pass_id: identifier of the pass, stored as ``self.pass_id``.
        """
        self.pass_id = pass_id


class EndPass(WithMetric):
    """
    Event On One Pass Training Complete.

    To get the output of a specific layer, add
    "event.gm.getLayerOutputs('predict_layer')" in your event_handler
    callback.
    """

    def __init__(self, pass_id, evaluator, gm):
        """
        :param pass_id: identifier of the pass, stored as ``self.pass_id``.
        :param evaluator: evaluator forwarded to ``WithMetric.__init__``,
                          which type-checks it and backs ``metrics``.
        :param gm: object stored as ``self.gm``; the class docstring calls
                   ``getLayerOutputs`` on it (presumably the trainer's
                   gradient machine -- confirm against the caller).
        """
        self.pass_id = pass_id
        self.gm = gm
        # Use super() like TestResult does, rather than naming the base class.
        super(EndPass, self).__init__(evaluator)


class BeginIteration(object):
    """
    Event fired when training on one batch starts.
    """

    def __init__(self, pass_id, batch_id):
        """
        :param pass_id: identifier of the pass, stored as ``self.pass_id``.
        :param batch_id: identifier of the batch, stored as ``self.batch_id``.
        """
        self.pass_id, self.batch_id = pass_id, batch_id


武毅 已提交
90 91 92 93 94 95 96 97 98 99 100
class EndForwardBackward(object):
    """
    Event fired when the ForwardBackward step of one batch completes.
    """

    def __init__(self, pass_id, batch_id, gm):
        """
        :param pass_id: identifier of the pass, stored as ``self.pass_id``.
        :param batch_id: identifier of the batch, stored as ``self.batch_id``.
        :param gm: object stored unchanged as ``self.gm``.
        """
        self.pass_id, self.batch_id = pass_id, batch_id
        self.gm = gm


Y
Yu Yang 已提交
101
class EndIteration(WithMetric):
    """
    Event On One Batch Training Complete.

    To get the output of a specific layer, add
    "event.gm.getLayerOutputs('predict_layer')" in your event_handler
    callback.
    """

    def __init__(self, pass_id, batch_id, cost, evaluator, gm):
        """
        :param pass_id: identifier of the pass, stored as ``self.pass_id``.
        :param batch_id: identifier of the batch, stored as ``self.batch_id``.
        :param cost: value stored unchanged as ``self.cost``.
        :param evaluator: evaluator forwarded to ``WithMetric.__init__``,
                          which type-checks it and backs ``metrics``.
        :param gm: object stored as ``self.gm``; the class docstring calls
                   ``getLayerOutputs`` on it (presumably the trainer's
                   gradient machine -- confirm against the caller).
        """
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        self.gm = gm
        # Use super() like TestResult does, rather than naming the base class.
        super(EndIteration, self).__init__(evaluator)