trainer.py 6.2 KB
Newer Older
Q
qijun 已提交
1 2 3
"""
Trainer package
"""
Y
Yu Yang 已提交
4
import collections
Y
Yu Yang 已提交
5

Y
Yu Yang 已提交
6 7
import py_paddle.swig_paddle as api

8
from data_feeder import DataFeeder
Q
qiaolongfei 已提交
9
from topology import Topology
Q
qiaolongfei 已提交
10
from . import event as v2_event
Y
Yu Yang 已提交
11 12 13
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters

14
__all__ = ['SGD']
Q
qijun 已提交
15

Y
Yu Yang 已提交
16 17

def default_event_handler(event):
Y
Yu Yang 已提交
18 19 20 21 22 23 24
    """
    Default event handler. It will print some log and save mode.

    TODO(yuyang18): Complete it!
    :param event:
    :return:
    """
Y
Yu Yang 已提交
25 26 27
    pass


Y
Yu Yang 已提交
28 29 30
class SGD(object):
    """
    Simple SGD Trainer.
Q
qijun 已提交
31 32
    SGD Trainer combines data reader, network topolopy and update_equation together
    to train/test a neural network.
Y
Yu Yang 已提交
33 34 35 36 37 38 39 40

    :param update_equation: The optimizer object.
    :type update_equation: paddle.v2.optimizer.Optimizer
    :param cost: Target cost that neural network should be optimized.
    :type cost: paddle.v2.config_base.Layer
    :param parameters: The parameters dictionary.
    :type parameters: paddle.v2.parameters.Parameters
    """
Y
Yu Yang 已提交
41

Y
Yu Yang 已提交
42
    def __init__(self, cost, parameters, update_equation):
43

Y
Yu Yang 已提交
44 45 46
        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError('parameters should be parameters')

Y
Yu Yang 已提交
47
        if not isinstance(update_equation, v2_optimizer.Optimizer):
Y
Yu Yang 已提交
48 49
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
50
        topology = Topology(cost)
Y
Yu Yang 已提交
51
        self.__optimizer__ = update_equation
Y
Yu Yang 已提交
52 53
        self.__topology__ = topology
        self.__parameters__ = parameters
54
        self.__topology_in_proto__ = topology.proto()
Y
Yu Yang 已提交
55
        self.__data_types__ = topology.data_type()
Y
Yu Yang 已提交
56 57 58 59 60 61
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()
Y
Yu Yang 已提交
62
        parameters.append_gradient_machine(gm)
Y
Yu Yang 已提交
63

Y
Yu Yang 已提交
64
    def train(self, reader, num_passes=1, event_handler=None, feeding=None):
Y
Yu Yang 已提交
65 66 67
        """
        Training method. Will train num_passes of input data.

Q
qijun 已提交
68 69 70
        :param reader: A reader that reads and yeilds data items. Usually we use a
                       batched reader to do mini-batch training.
        :type reader: collections.Iterable
Y
Yu Yang 已提交
71 72 73 74
        :param num_passes: The total train passes.
        :param event_handler: Event handler. A method will be invoked when event
                              occurred.
        :type event_handler: (BaseEvent) => None
Y
Yu Yang 已提交
75 76 77
        :param feeding: Feeding is a map of neural network input name and array
                        index that reader returns.
        :type feeding: dict
Y
Yu Yang 已提交
78 79
        :return:
        """
Y
Yu Yang 已提交
80 81 82 83 84
        if event_handler is None:
            event_handler = default_event_handler
        __check_train_args__(**locals())

        updater = self.__optimizer__.create_local_updater()
Y
Yu Yang 已提交
85
        updater.init(self.__gradient_machine__)
Y
Yu Yang 已提交
86

Y
Yu Yang 已提交
87 88
        self.__gradient_machine__.start()
        batch_evaluator = self.__gradient_machine__.makeEvaluator()
Y
Yu Yang 已提交
89
        assert isinstance(batch_evaluator, api.Evaluator)
Y
Yu Yang 已提交
90
        pass_evaluator = self.__gradient_machine__.makeEvaluator()
Y
Yu Yang 已提交
91
        assert isinstance(pass_evaluator, api.Evaluator)
Y
Yu Yang 已提交
92
        out_args = api.Arguments.createArguments(0)
Y
Yu Yang 已提交
93
        feeder = DataFeeder(self.__data_types__, feeding)
Y
Yu Yang 已提交
94
        for pass_id in xrange(num_passes):
Y
Yu Yang 已提交
95 96
            event_handler(v2_event.BeginPass(pass_id))
            pass_evaluator.start()
Y
Yu Yang 已提交
97
            updater.startPass()
Y
Yu Yang 已提交
98
            for batch_id, data_batch in enumerate(reader()):
Y
Yu Yang 已提交
99 100 101 102
                batch_evaluator.start()
                event_handler(
                    v2_event.BeginIteration(
                        pass_id=pass_id, batch_id=batch_id))
Y
Yu Yang 已提交
103
                pass_type = updater.startBatch(len(data_batch))
Y
Yu Yang 已提交
104 105 106 107
                self.__gradient_machine__.forwardBackward(
                    feeder(data_batch), out_args, pass_type)
                self.__gradient_machine__.eval(pass_evaluator)
                self.__gradient_machine__.eval(batch_evaluator)
L
liaogang 已提交
108 109
                for each_param in self.__gradient_machine__.getNonStaticParameters(
                ):
Y
Yu Yang 已提交
110
                    updater.update(each_param)
Y
Yu Yang 已提交
111
                cost_sum = out_args.sum()
Y
Yu Yang 已提交
112
                cost = cost_sum / len(data_batch)
Y
Yu Yang 已提交
113
                updater.finishBatch(cost)
Y
Yu Yang 已提交
114
                batch_evaluator.finish()
Y
Yu Yang 已提交
115
                event_handler(
Y
Yu Yang 已提交
116
                    v2_event.EndIteration(
Y
Yu Yang 已提交
117 118 119 120
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        evaluator=batch_evaluator))
Y
Yu Yang 已提交
121 122

            updater.finishPass()
Y
Yu Yang 已提交
123 124
            pass_evaluator.finish()
            event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
Y
Yu Yang 已提交
125 126
        self.__gradient_machine__.finish()

Y
Yu Yang 已提交
127
    def test(self, reader, feeding=None):
Q
qijun 已提交
128 129 130 131
        """
        Testing method. Will test input data.

        :param reader: A reader that reads and yeilds data items.
Q
qijun 已提交
132
        :type reader: collections.Iterable  
Q
qijun 已提交
133 134 135 136 137
        :param feeding: Feeding is a map of neural network input name and array
                        index that reader returns.
        :type feeding: dict
        :return:
        """
Y
Yu Yang 已提交
138
        feeder = DataFeeder(self.__data_types__, feeding)
Y
Yu Yang 已提交
139 140 141
        evaluator = self.__gradient_machine__.makeEvaluator()
        out_args = api.Arguments.createArguments(0)
        evaluator.start()
Y
Yu Yang 已提交
142 143
        total_cost = 0
        num_samples = 0.0
Y
Yu Yang 已提交
144
        for data_batch in reader():
Y
Yu Yang 已提交
145
            num_samples += len(data_batch)
Y
Yu Yang 已提交
146 147
            self.__gradient_machine__.forward(
                feeder(data_batch), out_args, api.PASS_TEST)
Y
Yu Yang 已提交
148
            total_cost += out_args.sum()
Y
Yu Yang 已提交
149
            self.__gradient_machine__.eval(evaluator)
Y
Yu Yang 已提交
150

Y
Yu Yang 已提交
151
        evaluator.finish()
Y
Yu Yang 已提交
152 153
        return v2_event.TestResult(
            evaluator=evaluator, cost=total_cost / num_samples)
Y
Yu Yang 已提交
154 155 156


def __check_train_args__(reader, event_handler, **kwargs):
Y
Yu Yang 已提交
157 158 159
    """
    Check train function's argument types
    """
Y
Yu Yang 已提交
160
    if not callable(reader) or not isinstance(reader(), collections.Iterator):
Y
Yu Yang 已提交
161 162
        raise TypeError('train_data_reader should be a function, '
                        'which can return a iterator')
Y
Yu Yang 已提交
163
    if not callable(event_handler):
Y
Yu Yang 已提交
164
        raise TypeError('event handler should be a function')