event.py 1.7 KB
Newer Older
Y
Yu Yang 已提交
1 2
"""
All training events.
3 4 5 6 7 8 9 10 11

There are:

* BeginIteration
* EndIteration
* BeginPass
* EndPass

TODO(yuyang18): Complete it!
Y
Yu Yang 已提交
12
"""
Y
Yu Yang 已提交
13
import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
14 15 16 17

# Public names exported by ``from <this module> import *``.
__all__ = [
    'EndIteration',
    'BeginIteration',
    'BeginPass',
    'EndPass',
    'TestResult',
]
Y
Yu Yang 已提交
18 19


Y
Yu Yang 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
class WithMetric(object):
    """
    Base class for events that carry evaluator metrics.

    Holds an ``api.Evaluator`` and exposes its named values through the
    ``metrics`` property.
    """

    def __init__(self, evaluator):
        """
        :param evaluator: the evaluator whose values back ``metrics``.
        :type evaluator: api.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
        """
        if not isinstance(evaluator, api.Evaluator):
            raise TypeError("Evaluator should be api.Evaluator type")
        # NOTE(review): ``__name__``-style identifiers are reserved for Python
        # special names; the attribute is kept as-is only for backward
        # compatibility with code that may read it directly.
        self.__evaluator__ = evaluator

    @property
    def metrics(self):
        """
        Build a fresh dict of the evaluator's current values.

        :return: a new dict mapping each name returned by ``getNames()`` to
                 the result of ``getValue(name)``.
        :rtype: dict
        """
        evaluator = self.__evaluator__
        return {name: evaluator.getValue(name) for name in evaluator.getNames()}


Y
Yu Yang 已提交
36
class TestResult(WithMetric):
    """
    A cost value together with evaluator metrics.

    NOTE(review): presumably produced by a test/evaluation run (per the
    name); confirm against callers.
    """

    def __init__(self, evaluator, cost):
        """
        :param evaluator: evaluator whose values are exposed via ``metrics``;
                          ``WithMetric`` raises TypeError unless it is an
                          ``api.Evaluator``.
        :param cost: the cost value to record; stored unchanged as ``cost``.
        """
        super(TestResult, self).__init__(evaluator)
        self.cost = cost
Y
Yu Yang 已提交
40 41


Y
Yu Yang 已提交
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
class BeginPass(object):
    """
    Event signalling that one training pass is starting.

    :param pass_id: id of the pass; stored as ``pass_id``.
    """

    def __init__(self, pass_id):
        super(BeginPass, self).__init__()
        self.pass_id = pass_id


class EndPass(WithMetric):
    """
    Event On One Pass Training Complete.

    Carries the id of the finished pass and an evaluator whose values are
    exposed through ``metrics``.
    """

    def __init__(self, pass_id, evaluator):
        """
        :param pass_id: id of the pass that finished.
        :param evaluator: passed to ``WithMetric``, which raises TypeError
                          unless it is an ``api.Evaluator``.
        """
        self.pass_id = pass_id
        # Cooperative super() call, consistent with TestResult, instead of
        # hard-coding the base class.
        super(EndPass, self).__init__(evaluator)


class BeginIteration(object):
    """
    Event signalling that training on one batch is starting.

    :param pass_id: id of the current pass; stored as ``pass_id``.
    :param batch_id: id of the batch; stored as ``batch_id``.
    """

    def __init__(self, pass_id, batch_id):
        super(BeginIteration, self).__init__()
        self.pass_id = pass_id
        self.batch_id = batch_id


class EndIteration(WithMetric):
    """
    Event On One Batch Training Complete.

    Records the pass and batch ids, a cost value, and an evaluator whose
    values are exposed through ``metrics``.
    """

    def __init__(self, pass_id, batch_id, cost, evaluator):
        """
        :param pass_id: id of the current pass.
        :param batch_id: id of the batch that just finished.
        :param cost: cost value for the batch; stored unchanged as ``cost``.
        :param evaluator: passed to ``WithMetric``, which raises TypeError
                          unless it is an ``api.Evaluator``.
        """
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        # Cooperative super() call, consistent with TestResult, instead of
        # hard-coding the base class.
        super(EndIteration, self).__init__(evaluator)