import collections

import py_paddle.swig_paddle as api

from data_feeder import DataFeeder
from topology import Topology
from . import event as v2_event
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters

__all__ = ['SGD']
"""
Trainer package
TODO(yuyang18): Complete comments.
"""


def default_event_handler(event):
    """
    Default event handler. It will print logs and save the model.

    TODO(yuyang18): Complete it!
    :param event:
    :return:
    """
    pass


class SGD(object):
    """
    Simple SGD Trainer.
    TODO(yuyang18): Complete comments

    :param update_equation: The optimizer object.
    :type update_equation: paddle.v2.optimizer.Optimizer
    :param cost: The target cost that the neural network should optimize.
    :type cost: paddle.v2.config_base.Layer
    :param parameters: The parameters dictionary.
    :type parameters: paddle.v2.parameters.Parameters
    :param extra_layers: Layers in the neural network graph that are not on
                         the path of the cost layer but should still be
                         included in the topology.
    :type extra_layers: paddle.v2.config_base.Layer
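
    Example of typical usage (a minimal sketch; ``my_cost`` and the network
    that produces it are assumptions, not part of this module):

    ..  code-block:: python

        import paddle.v2 as paddle

        # my_cost is a user-defined cost layer, e.g. a classification
        # cost over some network output.
        parameters = paddle.parameters.create(my_cost)
        optimizer = paddle.optimizer.Momentum(momentum=0.9)
        trainer = paddle.trainer.SGD(cost=my_cost,
                                     parameters=parameters,
                                     update_equation=optimizer)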
    """

    def __init__(self,
                 cost,
                 parameters,
                 update_equation,
                 extra_layers=None,
                 is_local=True):

        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError('parameters should be '
                            'paddle.v2.parameters.Parameters')

        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
        topology = Topology(cost, extra_layers=extra_layers)
        self.__optimizer__ = update_equation
        self.__topology__ = topology
        self.__parameters__ = parameters
        self.__topology_in_proto__ = topology.proto()
        self.__is_local__ = is_local

        self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
        # In local mode, disable sparse_remote_update.
        if is_local:
            for param in self.__topology_in_proto__.parameters:
                if param.sparse_remote_update:
                    param.sparse_remote_update = False

        # A sparse updater requires a dedicated gradient machine create mode.
        self.__gm_create_mode__ = (api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
                                   if self.__use_sparse_updater__ else
                                   api.CREATE_MODE_NORMAL)
        self.__data_types__ = topology.data_type()
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, self.__gm_create_mode__,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()
        parameters.append_gradient_machine(gm)
        self.__parameter_updater__ = None

    def use_remote_sparse_updater(self):
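        """
        Whether the remote sparse updater is used: true when the topology
        uses sparse updates and the trainer runs in non-local (cluster)
        mode.

        :return: True if parameter rows must be fetched from the remote
                 parameter server.
        :rtype: bool
        """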
        return self.__use_sparse_updater__ and not self.__is_local__

    def train(self, reader, num_passes=1, event_handler=None, feeding=None):
        """
        Training method. Trains ``num_passes`` passes over the input data.

        :param reader: A function that returns an iterable of training data
                       batches (e.g. a batched reader created by
                       ``paddle.batch``).
        :param num_passes: The total train passes.
        :param event_handler: Event handler. A callback invoked whenever a
                              training event occurs.
        :type event_handler: (BaseEvent) => None
        :param feeding: A mapping from neural network input names to the
                        indices of the corresponding fields in each sample
                        that the reader returns.
        :type feeding: dict|list
        :return:
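
        Example (a sketch; ``my_reader`` is an assumed user-provided reader
        that yields single training samples):

        ..  code-block:: python

            def event_handler(event):
                if isinstance(event, paddle.event.EndIteration):
                    print "pass %d, batch %d, cost %f" % (
                        event.pass_id, event.batch_id, event.cost)

            trainer.train(
                reader=paddle.batch(my_reader, batch_size=128),
                num_passes=10,
                event_handler=event_handler,
                feeding={'image': 0, 'label': 1})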
        """
        if event_handler is None:
            event_handler = default_event_handler
        __check_train_args__(**locals())

        if self.__is_local__:
            parameter_updater = self.__optimizer__.create_local_updater()
        else:
            parameter_updater = self.__optimizer__.create_remote_updater(
                num_passes, self.__use_sparse_updater__)
        self.__parameter_updater__ = parameter_updater
        parameter_updater.init(self.__gradient_machine__)

        self.__gradient_machine__.start()
        batch_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(batch_evaluator, api.Evaluator)
        pass_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(pass_evaluator, api.Evaluator)
        out_args = api.Arguments.createArguments(0)
        feeder = DataFeeder(self.__data_types__, feeding)
        for pass_id in xrange(num_passes):
            event_handler(v2_event.BeginPass(pass_id))
            pass_evaluator.start()
            parameter_updater.startPass()
            for batch_id, data_batch in enumerate(reader()):
                batch_evaluator.start()
                event_handler(
                    v2_event.BeginIteration(
                        pass_id=pass_id, batch_id=batch_id))
                pass_type = parameter_updater.startBatch(len(data_batch))
                in_args = feeder(data_batch)
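                # With a remote sparse updater, first compute which parameter
                # rows this batch touches, then fetch only those rows from
                # the parameter server before running forward/backward.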
                if self.use_remote_sparse_updater():
                    self.__gradient_machine__.prefetch(in_args)
                    parameter_updater.getParametersRemote()
                self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                          pass_type)
                self.__gradient_machine__.eval(pass_evaluator)
                self.__gradient_machine__.eval(batch_evaluator)
                params = self.__gradient_machine__.getNonStaticParameters()
                for each_param in params:
                    parameter_updater.update(each_param)
                cost_sum = out_args.sum()
                cost = cost_sum / len(data_batch)
                parameter_updater.finishBatch(cost)
                batch_evaluator.finish()
                event_handler(
                    v2_event.EndIteration(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        evaluator=batch_evaluator))

            parameter_updater.finishPass()
            pass_evaluator.finish()
            event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
        self.__gradient_machine__.finish()

    def test(self, reader, feeding=None):
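        """
        Testing method. Runs one forward pass over the data produced by
        ``reader`` and returns the average cost together with the evaluator
        result.

        :param reader: A function that returns an iterable of test data
                       batches.
        :param feeding: A mapping from neural network input names to the
                        indices of the corresponding fields in each sample
                        that the reader returns.
        :type feeding: dict|list
        :return: The test result.
        :rtype: paddle.v2.event.TestResult
        """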
        feeder = DataFeeder(self.__data_types__, feeding)
        evaluator = self.__gradient_machine__.makeEvaluator()
        out_args = api.Arguments.createArguments(0)
        evaluator.start()
        total_cost = 0
        num_samples = 0.0
        for data_batch in reader():
            num_samples += len(data_batch)
            in_args = feeder(data_batch)
            if self.use_remote_sparse_updater():
                self.__gradient_machine__.prefetch(in_args)
                self.__parameter_updater__.getParametersRemote()
            self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
            total_cost += out_args.sum()
            self.__gradient_machine__.eval(evaluator)

        evaluator.finish()
        return v2_event.TestResult(
            evaluator=evaluator, cost=total_cost / num_samples)


def __check_train_args__(reader, event_handler, **kwargs):
    """
    Check train function's argument types
    """
    if not callable(reader) or not isinstance(reader(), collections.Iterator):
        raise TypeError('train_data_reader should be a function '
                        'that returns an iterator')
    if not callable(event_handler):
        raise TypeError('event handler should be a function')