event.py 1.7 KB
Newer Older
Y
Yu Yang 已提交
1 2
"""
All training events.
3 4 5 6 7 8 9 10 11

There are:

* BeginIteration
* EndIteration
* BeginPass
* EndPass

TODO(yuyang18): Complete it!
Y
Yu Yang 已提交
12
"""
13
import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
14 15 16 17

# Names exported when this module is star-imported.
__all__ = [
    'EndIteration',
    'BeginIteration',
    'BeginPass',
    'EndPass',
    'TestResult',
]
Y
Yu Yang 已提交
18 19


20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
class WithMetric(object):
    """
    Base class for events that hold an evaluator and expose its metrics.

    :param evaluator: must be an ``api.Evaluator`` instance.
    :raises TypeError: if ``evaluator`` is not an ``api.Evaluator``.
    """

    def __init__(self, evaluator):
        # Accept only a real api.Evaluator; anything else is rejected.
        if isinstance(evaluator, api.Evaluator):
            self.__evaluator__ = evaluator
        else:
            raise TypeError("Evaluator should be api.Evaluator type")

    @property
    def metrics(self):
        """
        Return a dict mapping every name from the evaluator's ``getNames()``
        to the corresponding ``getValue(name)`` result.
        """
        source = self.__evaluator__
        return {
            metric_name: source.getValue(metric_name)
            for metric_name in source.getNames()
        }


Y
Yu Yang 已提交
36
class TestResult(WithMetric):
    """
    Event that pairs a cost with evaluator metrics.

    NOTE(review): presumably emitted after a test/evaluation run -- confirm
    against the code that constructs it.

    :param evaluator: forwarded to ``WithMetric.__init__``, which raises
        ``TypeError`` unless it is an ``api.Evaluator``.
    :param cost: stored unchanged as the ``cost`` attribute.
    """

    def __init__(self, evaluator, cost):
        super(TestResult, self).__init__(evaluator)
        self.cost = cost
Y
Yu Yang 已提交
40 41


42 43 44 45 46 47 48 49 50 51 52 53 54 55
class BeginPass(object):
    """
    Event for the start of one training pass.

    :param pass_id: stored unchanged as the ``pass_id`` attribute.
    """

    def __init__(self, pass_id):
        self.pass_id = pass_id


class EndPass(WithMetric):
    """
    Event On One Pass Training Complete.

    :param pass_id: stored unchanged as the ``pass_id`` attribute.
    :param cost: stored unchanged as the ``cost`` attribute.
    :param evaluator: forwarded to ``WithMetric.__init__``, which raises
        ``TypeError`` unless it is an ``api.Evaluator``.
    """

    def __init__(self, pass_id, cost, evaluator):
        self.pass_id = pass_id
        self.cost = cost
        # Same base-class initialisation style as TestResult.
        super(EndPass, self).__init__(evaluator)


class BeginIteration(object):
    """
    Event for the start of training on one batch.

    :param pass_id: stored unchanged as the ``pass_id`` attribute.
    :param batch_id: stored unchanged as the ``batch_id`` attribute.
    """

    def __init__(self, pass_id, batch_id):
        self.pass_id, self.batch_id = pass_id, batch_id


class EndIteration(WithMetric):
    """
    Event On One Batch Training Complete.

    :param pass_id: stored unchanged as the ``pass_id`` attribute.
    :param batch_id: stored unchanged as the ``batch_id`` attribute.
    :param cost: stored unchanged as the ``cost`` attribute.
    :param evaluator: forwarded to ``WithMetric.__init__``, which raises
        ``TypeError`` unless it is an ``api.Evaluator``.
    """

    def __init__(self, pass_id, batch_id, cost, evaluator):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        # Same base-class initialisation style as TestResult.
        super(EndIteration, self).__init__(evaluator)
新手
引导
客服 返回
顶部