#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module Trainer: the SGD trainer of the paddle.v2 API, which trains and tests
a network through a py_paddle GradientMachine.
"""
import collections
from topology import Topology
from . import event as v2_event
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters

__all__ = ['SGD']

def default_event_handler(event):
    """
    Event handler used by SGD.train when the caller does not supply one.

    It currently ignores every event: it does not log anything and does not
    save the model.

    TODO(yuyang18): Complete it!
    :param event: The training event being dispatched (a paddle.v2.event
                  object such as BeginPass or EndIteration).
    :return: None
    """
    pass


class SGD(object):
    """
    Simple SGD Trainer.

    SGD Trainer combines data reader, network topology and update_equation
    together to train/test a neural network.

    :param cost: Target cost that neural network should be optimized.
    :type cost: paddle.v2.config_base.Layer
    :param parameters: The parameters dictionary.
    :type parameters: paddle.v2.parameters.Parameters
    :param update_equation: The optimizer object.
    :type update_equation: paddle.v2.optimizer.Optimizer
    :param extra_layers: Some layers in the neural network graph are not
                         in the path of cost layer.
    :type extra_layers: paddle.v2.config_base.Layer
    :param is_local: Whether training locally.
    :type is_local: bool
    :param pserver_spec: Comma-separated string of pserver locations,
                         e.g. 127.10.0.10:3000,127.10.0.11:3000.
                         This parameter is only used for fault-tolerant
                         mode cluster training.
    :type pserver_spec: string
    :param use_etcd: Whether to use an etcd pserver.
    :type use_etcd: bool
    """

    def __init__(self,
                 cost,
                 parameters,
                 update_equation,
                 extra_layers=None,
                 is_local=True,
                 pserver_spec=None,
                 use_etcd=True):

        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError("parameters must be "
                            "paddle.v2.parameters.Parameters")

        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
        import py_paddle.swig_paddle as api
        topology = Topology(cost, extra_layers=extra_layers)
        # HACK(typhoonzero): update ParameterConfig(proto) in case of optimizers
        # are defined after layers, or between layers.
        topology.update_from_default()
        parameters.update_param_conf(topology.proto())

        self.__optimizer__ = update_equation
        self.__topology__ = topology
        self.__parameters__ = parameters
        self.__topology_in_proto__ = topology.proto()
        self.__is_local__ = is_local
        self.__pserver_spec__ = pserver_spec
        self.__use_etcd__ = use_etcd

        self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
        # In local mode, disable sparse_remote_update on every parameter.
        # The field is only written when it is currently set, so parameters
        # that never had it are left untouched in the proto.
        if is_local:
            for param in self.__topology_in_proto__.parameters:
                if param.sparse_remote_update:
                    param.sparse_remote_update = False

        if self.__use_sparse_updater__:
            self.__gm_create_mode__ = api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
        else:
            self.__gm_create_mode__ = api.CREATE_MODE_NORMAL
        self.__data_types__ = topology.data_type()
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, self.__gm_create_mode__,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()
        self.__parameters__.append_gradient_machine(gm)
        # Created in train(), because create_updater() needs num_passes.
        self.__parameter_updater__ = None

    def get_topology_proto(self):
        """
        Return the topology proto this trainer built its gradient machine
        from (in local mode, with sparse_remote_update cleared).
        """
        return self.__topology_in_proto__

    def __use_remote_sparse_updater__(self):
        """
        Return True when the topology uses a sparse updater and training is
        not local, i.e. parameters must be fetched from the pserver.
        """
        return self.__use_sparse_updater__ and not self.__is_local__

    def __prepare_parameter__(self, in_args):
        """
        Prepare parameters before forward/backward.

        1. When the remote sparse updater is used, the parameters needed by
           this batch are prefetched and then fetched from the pserver
           according to the input arguments.

        :param in_args: Input arguments of this batch.
        :return: None
        """
        if self.__use_remote_sparse_updater__():
            self.__gradient_machine__.prefetch(in_args)
            self.__parameter_updater__.getParametersRemote()

    def save_parameter_to_tar(self, f):
        """
        Save the current parameters into ``f`` via Parameters.to_tar.

        The parameter updater only exists once train() has started, so this
        must be called after that (for example from inside the event
        handler). ``restore()`` is always called after ``apply()``, even if
        fetching or writing the parameters fails, so a failed save does not
        leave the applied values in the trainer.

        :param f: File object handed to Parameters.to_tar (presumably opened
                  for binary writing -- confirm against Parameters.to_tar).
        :return: None
        """
        updater = self.__parameter_updater__
        updater.catchUpWith()
        updater.apply()
        try:
            updater.getParametersRemote(True, True)
            self.__parameters__.to_tar(f)
        finally:
            updater.restore()

    def train(self, reader, num_passes=1, event_handler=None, feeding=None):
        """
        Training method. Will train num_passes of input data.

        :param reader: A callable that returns an iterator over data batches.
                       Usually we use a batched reader to do mini-batch
                       training.
        :type reader: callable
        :param num_passes: The total train passes.
        :param event_handler: Event handler. A method will be invoked when event
                              occurred.
        :type event_handler: (BaseEvent) => None
        :param feeding: Feeding is a map of neural network input name and array
                        index that reader returns.
        :type feeding: dict|list
        :return: None
        """
        import py_paddle.swig_paddle as api
        from data_feeder import DataFeeder
        if event_handler is None:
            event_handler = default_event_handler
        __check_train_args__(reader=reader, event_handler=event_handler)

        # create_updater() takes num_passes, so a new updater is built on
        # every call to train().
        self.__parameter_updater__ = self.__optimizer__.create_updater(
            self.__is_local__, num_passes, self.__use_sparse_updater__,
            self.__pserver_spec__, self.__use_etcd__)
        self.__parameter_updater__.init(self.__gradient_machine__)

        self.__gradient_machine__.start()
        batch_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(batch_evaluator, api.Evaluator)
        pass_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(pass_evaluator, api.Evaluator)
        out_args = api.Arguments.createArguments(0)
        feeder = DataFeeder(self.__data_types__, feeding)
        for pass_id in xrange(num_passes):
            event_handler(v2_event.BeginPass(pass_id))
            pass_evaluator.start()
            self.__parameter_updater__.startPass()
            for batch_id, data_batch in enumerate(reader()):
                batch_evaluator.start()
                event_handler(
                    v2_event.BeginIteration(
                        pass_id=pass_id, batch_id=batch_id))
                pass_type = self.__parameter_updater__.startBatch(
                    len(data_batch))
                in_args = feeder(data_batch)
                self.__prepare_parameter__(in_args)
                self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                          pass_type)
                self.__gradient_machine__.eval(pass_evaluator)
                self.__gradient_machine__.eval(batch_evaluator)
                event_handler(
                    v2_event.EndForwardBackward(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        gm=self.__gradient_machine__))
                non_static_params = \
                    self.__gradient_machine__.getNonStaticParameters()
                for each_param in non_static_params:
                    self.__parameter_updater__.update(each_param)
                cost_sum = out_args.sum()
                cost = cost_sum / len(data_batch)
                self.__parameter_updater__.finishBatch(cost)
                batch_evaluator.finish()
                event_handler(
                    v2_event.EndIteration(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        evaluator=batch_evaluator,
                        gm=self.__gradient_machine__))

            self.__parameter_updater__.finishPass()
            pass_evaluator.finish()
            event_handler(
                v2_event.EndPass(
                    pass_id,
                    evaluator=pass_evaluator,
                    gm=self.__gradient_machine__))
        self.__gradient_machine__.finish()

    def test(self, reader, feeding=None):
        """
        Testing method. Will test input data.

        :param reader: A callable that returns an iterator over data batches,
                       e.g. one built with paddle.v2.batch.
        :type reader: callable
        :param feeding: Feeding is a map of neural network input name and array
                        index that reader returns.
        :type feeding: dict|list
        :return: A paddle.v2.event.TestResult holding the evaluator and the
                 average cost per sample. The cost is NaN when the reader
                 yields no samples.
        """
        import py_paddle.swig_paddle as api
        from data_feeder import DataFeeder
        feeder = DataFeeder(self.__data_types__, feeding)
        evaluator = self.__gradient_machine__.makeEvaluator()
        out_args = api.Arguments.createArguments(0)
        evaluator.start()
        total_cost = 0
        num_samples = 0.0
        for data_batch in reader():
            num_samples += len(data_batch)
            in_args = feeder(data_batch)
            self.__prepare_parameter__(in_args)
            self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
            total_cost += out_args.sum()
            self.__gradient_machine__.eval(evaluator)

        evaluator.finish()
        if num_samples:
            avg_cost = total_cost / num_samples
        else:
            # An empty reader previously crashed with ZeroDivisionError;
            # the average cost over zero samples is undefined, so report NaN.
            avg_cost = float('nan')
        return v2_event.TestResult(evaluator=evaluator, cost=avg_cost)


def __check_train_args__(reader, event_handler, **kwargs):
    """
    Check the argument types of SGD.train.

    :param reader: Must be callable and must return an iterator when called.
                   Note that it is called once here and the returned iterator
                   is discarded.
    :param event_handler: Must be callable.
    :param kwargs: Any other train() arguments; accepted and ignored.
    :raises TypeError: If reader or event_handler fails its check.
    """
    # collections.Iterator was removed in Python 3.10; the ABC lives in
    # collections.abc on Python 3 and directly in collections on Python 2.
    try:
        from collections.abc import Iterator
    except ImportError:  # Python 2
        from collections import Iterator

    if not callable(reader) or not isinstance(reader(), Iterator):
        raise TypeError('train_data_reader should be a function, '
                        'which can return a iterator')
    if not callable(event_handler):
        raise TypeError('event handler should be a function')