"""
All training events.

The events defined in this module are:

* BeginIteration
* EndIteration
* BeginPass
* EndPass

TODO(yuyang18): Complete it!
"""
import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']


class WithMetric(object):
    """
    Mixin giving an event access to an evaluator's metrics.

    :param evaluator: the ``api.Evaluator`` whose values ``metrics`` reads.
    :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
    """

    def __init__(self, evaluator):
        # Validate eagerly so a wrong argument fails at construction time
        # rather than on the first ``metrics`` access.
        if isinstance(evaluator, api.Evaluator):
            # NOTE(review): the dunder-style attribute name is reserved by
            # PEP 8 for Python itself; it is kept unchanged because code
            # outside this file may read it.
            self.__evaluator__ = evaluator
        else:
            raise TypeError("Evaluator should be api.Evaluator type")

    @property
    def metrics(self):
        """
        Current metric values, re-read from the evaluator on each access.

        :return: dict mapping every name returned by ``getNames()`` to the
                 result of ``getValue(name)``.
        """
        evaluator = self.__evaluator__
        names = evaluator.getNames()
        return {name: evaluator.getValue(name) for name in names}


class BeginPass(object):
    """
    Event On One Pass Training Start.

    :param pass_id: id of the pass that is about to start.
    """

    def __init__(self, pass_id):
        self.pass_id = pass_id

    def __repr__(self):
        # Readable form for handlers that log or print events; the default
        # object repr only shows the class name and memory address.
        return '%s(pass_id=%r)' % (type(self).__name__, self.pass_id)


class EndPass(WithMetric):
    """
    Event On One Pass Training Complete.

    Besides ``pass_id``, the pass metrics are available through the
    ``metrics`` property inherited from ``WithMetric``.

    :param pass_id: id of the pass that just finished.
    :param evaluator: ``api.Evaluator`` backing ``metrics``.
    :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
    """

    def __init__(self, pass_id, evaluator):
        self.pass_id = pass_id
        super(EndPass, self).__init__(evaluator)

    def __repr__(self):
        # ``metrics`` is deliberately left out: it queries the evaluator on
        # every access, so keeping it out means repr never touches the
        # evaluator.
        return '%s(pass_id=%r)' % (type(self).__name__, self.pass_id)


class BeginIteration(object):
    """
    Event On One Batch Training Start.

    :param pass_id: id of the pass the batch belongs to.
    :param batch_id: id of the batch that is about to be trained.
    """

    def __init__(self, pass_id, batch_id):
        self.pass_id = pass_id
        self.batch_id = batch_id

    def __repr__(self):
        # Readable form for handlers that log or print events; the default
        # object repr only shows the class name and memory address.
        return '%s(pass_id=%r, batch_id=%r)' % (
            type(self).__name__, self.pass_id, self.batch_id)


class EndIteration(WithMetric):
    """
    Event On One Batch Training Complete.

    Besides the ids and ``cost``, the batch metrics are available through
    the ``metrics`` property inherited from ``WithMetric``.

    :param pass_id: id of the pass the batch belongs to.
    :param batch_id: id of the batch that just finished.
    :param cost: cost value passed in for this batch.
    :param evaluator: ``api.Evaluator`` backing ``metrics``.
    :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
    """

    def __init__(self, pass_id, batch_id, cost, evaluator):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        super(EndIteration, self).__init__(evaluator)

    def __repr__(self):
        # ``metrics`` is deliberately left out: it queries the evaluator on
        # every access, so keeping it out means repr never touches the
        # evaluator.
        return '%s(pass_id=%r, batch_id=%r, cost=%r)' % (
            type(self).__name__, self.pass_id, self.batch_id, self.cost)