"""
All training events.

The following event types are defined:

* BeginIteration
* EndIteration
* BeginPass
* EndPass

TestResult is also exported; it carries evaluator metrics but is not one of
the begin/end events listed above.

TODO(yuyang18): Complete it!
"""
import py_paddle.swig_paddle as api

# Public names exported by ``from <this module> import *``.
__all__ = [
    'EndIteration',
    'BeginIteration',
    'BeginPass',
    'EndPass',
    'TestResult',
]


class WithMetric(object):
    """
    Base class for objects that expose the metrics of an ``api.Evaluator``.

    The evaluator is stored on the instance, and its current values are read
    on every access to ``metrics``.
    """

    def __init__(self, evaluator):
        """
        :param evaluator: evaluator whose metric values are reported.
        :type evaluator: api.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
        """
        if isinstance(evaluator, api.Evaluator):
            self.__evaluator__ = evaluator
        else:
            raise TypeError("Evaluator should be api.Evaluator type")

    @property
    def metrics(self):
        """
        Build a fresh snapshot of the evaluator's metrics.

        :return: a new dict mapping every name from ``getNames()`` to the
                 value returned by ``getValue(name)``.
        :rtype: dict
        """
        source = self.__evaluator__
        return {key: source.getValue(key) for key in source.getNames()}


class TestResult(WithMetric):
    """
    Result object carrying evaluator metrics, presumably from a test run.

    It adds nothing beyond ``WithMetric``; read the values via ``metrics``.
    """

    def __init__(self, evaluator):
        """
        :param evaluator: evaluator holding the metrics to report.
        :type evaluator: api.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``
                           (raised by ``WithMetric.__init__``).
        """
        parent = super(TestResult, self)
        parent.__init__(evaluator)


class BeginPass(object):
    """
    Event for the start of training on one pass.
    """

    def __init__(self, pass_id):
        """
        :param pass_id: identifier of the pass that is starting.
        """
        self.pass_id = pass_id


class EndPass(WithMetric):
    """
    Event for the completion of training on one pass.

    Carries the pass identifier and, via ``metrics``, the evaluator's values.
    """

    def __init__(self, pass_id, evaluator):
        """
        :param pass_id: identifier of the pass that just finished.
        :param evaluator: evaluator holding the metrics for this pass.
        :type evaluator: api.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``
                           (raised by ``WithMetric.__init__``).
        """
        self.pass_id = pass_id
        # Explicit base-class call, matching EndIteration.
        WithMetric.__init__(self, evaluator)


class BeginIteration(object):
    """
    Event for the start of training on one batch.
    """

    def __init__(self, pass_id, batch_id):
        """
        :param pass_id: identifier of the current pass.
        :param batch_id: identifier of the batch that is starting.
        """
        # Tuple assignment binds left to right: pass_id first, then batch_id.
        self.pass_id, self.batch_id = pass_id, batch_id


class EndIteration(WithMetric):
    """
    Event for the completion of training on one batch.

    Carries the pass/batch identifiers, the batch cost, and, via ``metrics``,
    the evaluator's values.
    """

    def __init__(self, pass_id, batch_id, cost, evaluator):
        """
        :param pass_id: identifier of the current pass.
        :param batch_id: identifier of the batch that just finished.
        :param cost: cost value supplied by the caller; stored as-is (not
                     validated here — presumably a float, confirm at the
                     call site).
        :param evaluator: evaluator holding the metrics for this batch.
        :type evaluator: api.Evaluator
        :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``
                           (raised by ``WithMetric.__init__``).
        """
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        WithMetric.__init__(self, evaluator)