trainer.py 7.0 KB
Newer Older
Y
Yu Yang 已提交
1
import collections
Y
Yu Yang 已提交
2

Y
Yu Yang 已提交
3
import py_paddle.swig_paddle as api
Q
qiaolongfei 已提交
4
from paddle.proto.ModelConfig_pb2 import ModelConfig
D
dangqingqing 已提交
5
from data_feeder import DataFeeder
Y
Yu Yang 已提交
6

Q
qiaolongfei 已提交
7 8
from . import event as v2_event
from . import layer as v2_layer
Y
Yu Yang 已提交
9 10 11
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters

Y
Yu Yang 已提交
12
__all__ = ['ITrainer', 'SGD']
Y
Yu Yang 已提交
13 14 15


def default_event_handler(event):
Y
Yu Yang 已提交
16 17 18 19 20 21 22
    """
    Default event handler. It will print some log and save mode.

    TODO(yuyang18): Complete it!
    :param event:
    :return:
    """
Y
Yu Yang 已提交
23 24 25
    pass


Y
Yu Yang 已提交
26 27 28 29 30 31 32
def __bfs_travel_topology__(callback, *topologies):
    for each_layer in topologies:
        callback(each_layer)
        __bfs_travel_topology__(callback,
                                *each_layer.__parent_layers__.values())


Y
Yu Yang 已提交
33
class ITrainer(object):
Y
Yu Yang 已提交
34 35 36 37
    """
    The interface of Trainer. The only exposed method is `train`.
    """

Y
Yu Yang 已提交
38
    def train(self, reader, topology, parameters, event_handler=None):
Y
Yu Yang 已提交
39 40 41
        """
        train method.

Y
Yu Yang 已提交
42
        :param reader:
Y
Yu Yang 已提交
43 44 45 46 47 48
        :param topology:
        :param parameters:
        :param event_handler:
        :return:
        """

Y
Yu Yang 已提交
49 50 51
        raise NotImplementedError()


Y
Yu Yang 已提交
52
class SGD(ITrainer):
Y
Yu Yang 已提交
53
    def __init__(self, topology, parameters, update_equation):
Y
Yu Yang 已提交
54 55 56
        """
        Simple SGD Trainer.

Y
Yu Yang 已提交
57 58
        :param update_equation: The optimizer object.
        :type update_equation: v2_optimizer.Optimizer
Y
Yu Yang 已提交
59
        """
Y
Yu Yang 已提交
60 61 62
        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError('parameters should be parameters')

Y
Yu Yang 已提交
63
        if not isinstance(update_equation, v2_optimizer.Optimizer):
Y
Yu Yang 已提交
64 65
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
Y
Yu Yang 已提交
66
        self.__optimizer__ = update_equation
Y
Yu Yang 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
        self.__topology__ = topology
        self.__parameters__ = parameters
        self.__topology_in_proto__ = v2_layer.parse_network(topology)
        data_types = dict()

        def __travel__(l):
            if hasattr(l, 'type'):
                data_types[l.name] = l.type

        if not isinstance(topology, collections.Sequence):
            topology = [topology]
        __bfs_travel_topology__(__travel__, *topology)
        self.__data_types__ = [
            (iname, data_types[iname])
            for iname in self.__topology_in_proto__.input_layer_names
        ]

        if not isinstance(self.__topology_in_proto__, ModelConfig):
            raise TypeError('topology should be a model config')
Y
Yu Yang 已提交
86

Y
Yu Yang 已提交
87 88 89 90 91 92 93 94 95
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        parameters.append_gradient_machine(gm)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()

    def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
Y
Yu Yang 已提交
96 97 98
        """
        Training method. Will train num_passes of input data.

Y
Yu Yang 已提交
99
        :param reader:
Q
qiaolongfei 已提交
100
        :param topology: Network Topology, use one or more Layers to represent it.
Y
Yu Yang 已提交
101 102 103 104 105 106 107 108
        :param parameters: The parameter pools.
        :param num_passes: The total train passes.
        :param event_handler: Event handler. A method will be invoked when event
                              occurred.
        :type event_handler: (BaseEvent) => None
        :param data_types: Not important, will be removed after data refactor.
        :return:
        """
Y
Yu Yang 已提交
109 110 111
        if event_handler is None:
            event_handler = default_event_handler

Y
Yu Yang 已提交
112 113
        if reader_dict is None:
            reader_dict = self.default_reader_dict()
Q
qiaolongfei 已提交
114

Y
Yu Yang 已提交
115 116
        __check_train_args__(**locals())
        updater = self.__optimizer__.create_local_updater()
Y
Yu Yang 已提交
117
        updater.init(self.__gradient_machine__)
Y
Yu Yang 已提交
118

Y
Yu Yang 已提交
119 120
        self.__gradient_machine__.start()
        batch_evaluator = self.__gradient_machine__.makeEvaluator()
Y
Yu Yang 已提交
121
        assert isinstance(batch_evaluator, api.Evaluator)
Y
Yu Yang 已提交
122
        pass_evaluator = self.__gradient_machine__.makeEvaluator()
Y
Yu Yang 已提交
123
        assert isinstance(pass_evaluator, api.Evaluator)
Y
Yu Yang 已提交
124 125
        out_args = api.Arguments.createArguments(0)

Y
Yu Yang 已提交
126
        feeder = DataFeeder(self.__data_types__, reader_dict)
Y
Yu Yang 已提交
127 128

        for pass_id in xrange(num_passes):
Y
Yu Yang 已提交
129 130
            event_handler(v2_event.BeginPass(pass_id))
            pass_evaluator.start()
Y
Yu Yang 已提交
131
            updater.startPass()
Y
Yu Yang 已提交
132
            for batch_id, data_batch in enumerate(reader()):
Y
Yu Yang 已提交
133
                pass_type = updater.startBatch(len(data_batch))
Y
Yu Yang 已提交
134 135
                self.__gradient_machine__.forwardBackward(
                    feeder(data_batch), out_args, pass_type)
Y
Yu Yang 已提交
136 137 138 139
                batch_evaluator.start()
                event_handler(
                    v2_event.BeginIteration(
                        pass_id=pass_id, batch_id=batch_id))
Y
Yu Yang 已提交
140
                pass_type = updater.startBatch(len(data_batch))
Y
Yu Yang 已提交
141 142 143 144 145
                self.__gradient_machine__.forwardBackward(
                    feeder(data_batch), out_args, pass_type)
                self.__gradient_machine__.eval(pass_evaluator)
                self.__gradient_machine__.eval(batch_evaluator)
                for each_param in self.__gradient_machine__.getParameters():
Y
Yu Yang 已提交
146 147 148 149 150 151
                    updater.update(each_param)
                # Get cost. We use numpy to calculate total cost for this batch.
                cost_vec = out_args.getSlotValue(0)
                cost_vec = cost_vec.copyToNumpyMat()
                cost = cost_vec.sum() / len(data_batch)
                updater.finishBatch(cost)
Y
Yu Yang 已提交
152
                batch_evaluator.finish()
Y
Yu Yang 已提交
153
                event_handler(
Y
Yu Yang 已提交
154
                    v2_event.EndIteration(
Y
Yu Yang 已提交
155 156 157 158
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        evaluator=batch_evaluator))
Y
Yu Yang 已提交
159 160

            updater.finishPass()
Y
Yu Yang 已提交
161 162
            pass_evaluator.finish()
            event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
Y
Yu Yang 已提交
163 164 165 166 167 168 169
        self.__gradient_machine__.finish()

    def default_reader_dict(self):
        reader_dict = dict()
        for i, tp in enumerate(self.__data_types__):
            reader_dict[tp[0]] = i
        return reader_dict
Y
Yu Yang 已提交
170

Y
Yu Yang 已提交
171 172 173 174 175 176 177 178 179 180 181 182
    def test(self, reader, reader_dict=None):
        if reader_dict is None:
            reader_dict = self.default_reader_dict()

        feeder = DataFeeder(self.__data_types__, reader_dict)
        evaluator = self.__gradient_machine__.makeEvaluator()
        out_args = api.Arguments.createArguments(0)
        evaluator.start()
        for data_batch in reader():
            self.__gradient_machine__.forward(
                feeder(data_batch), out_args, api.PASS_TEST)
            self.__gradient_machine__.eval(evaluator)
Y
Yu Yang 已提交
183

Y
Yu Yang 已提交
184 185 186 187 188
        evaluator.finish()
        return v2_event.TestResult(evaluator=evaluator)


def __check_train_args__(reader, event_handler, **kwargs):
Y
Yu Yang 已提交
189 190 191
    """
    Check train function's argument types
    """
Y
Yu Yang 已提交
192
    if not callable(reader) or not isinstance(reader(), collections.Iterator):
Y
Yu Yang 已提交
193 194
        raise TypeError('train_data_reader should be a function, '
                        'which can return a iterator')
Y
Yu Yang 已提交
195 196

    if not callable(event_handler):
Y
Yu Yang 已提交
197
        raise TypeError('event handler should be a function')