trainer.py 6.9 KB
Newer Older
Y
Yu Yang 已提交
1
import collections
Y
Yu Yang 已提交
2

Y
Yu Yang 已提交
3 4 5
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter

6
from data_feeder import DataFeeder
Q
qiaolongfei 已提交
7
from . import event as v2_event
Y
Yu Yang 已提交
8 9
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters
Q
qiaolongfei 已提交
10
from . import topology as v2_topology
Y
Yu Yang 已提交
11

Y
Yu Yang 已提交
12
# Names exported by ``import *``: the trainer interface and its SGD
# implementation.
__all__ = ["ITrainer", "SGD"]
Y
Yu Yang 已提交
13 14 15


def default_event_handler(event):
    """
    Fallback event handler; it currently ignores every event.

    TODO(yuyang18): Complete it! The intent is to print some log and
    (presumably) save the model.

    :param event: the training event being reported.
    :return: None
    """
    return None


class ITrainer(object):
    """
    Trainer interface.

    A trainer exposes a single public method, `train`, which concrete
    subclasses must override.
    """

    def train(self, train_data_reader, topology, parameters,
              test_data_reader=None, event_handler=None):
        """
        Train the given topology. Subclasses must implement this.

        :param train_data_reader: reader supplying training data.
        :param topology: the network to train.
        :param parameters: the parameters to train.
        :param test_data_reader: optional reader supplying test data.
        :param event_handler: optional callback for training events.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


Y
Yu Yang 已提交
51
class SGD(ITrainer):
    """
    Simple SGD trainer driven by a ``v2_optimizer.Optimizer``.
    """

    def __init__(self, update_equation):
        """
        Simple SGD Trainer.

        :param update_equation: The optimizer object.
        :type update_equation: v2_optimizer.Optimizer
        :raises ValueError: if ``update_equation`` is not a
                            ``v2_optimizer.Optimizer``.
        """
        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise ValueError("update equation parameter must be "
                             "paddle.v2.optimizer.Optimizer")
        self.__optimizer__ = update_equation

    def train(self,
              train_data_reader,
              topology,
              parameters,
              num_passes=1,
              test_data_reader=None,
              event_handler=None,
              batch_size=32,
              data_types=None,
              reader_dict=None):
        """
        Training method. Will train num_passes of input data.

        :param train_data_reader: callable returning an iterator of samples.
        :param topology: cost layers, use one or more Layers to represent it.
        :param parameters: The parameter pools.
        :param num_passes: The total train passes.
        :param test_data_reader: Optional. Only type-checked; this trainer
                                 does not run any test pass with it yet.
        :param event_handler: Event handler. A method will be invoked when event
                              occurred. Defaults to ``default_event_handler``.
        :type event_handler: (BaseEvent) => None
        :param batch_size: Not important, will be removed after data refactor.
        :param data_types: Not important, will be removed after data refactor.
        :param reader_dict: Passed to ``DataFeeder`` together with
                            ``data_types``.
        :raises ValueError: if an argument fails ``__check_train_args__``.
        :return: None
        """
        if event_handler is None:
            event_handler = default_event_handler

        topology = v2_topology.Topology(topology)

        # Every local defined so far is forwarded; the checker ignores the
        # ones it does not validate via **kwargs.
        __check_train_args__(**locals())

        gm = api.GradientMachine.createFromConfigProto(
            topology.proto(), api.CREATE_MODE_NORMAL,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        parameters.append_gradient_machine(gm)
        gm.randParameters()
        updater = self.__optimizer__.create_local_updater()
        updater.init(gm)

        gm.start()
        # Whatever happens below (including exceptions raised by a user
        # event_handler), the gradient machine started above is finished.
        try:
            batch_evaluator = gm.makeEvaluator()
            assert isinstance(batch_evaluator, api.Evaluator)
            pass_evaluator = gm.makeEvaluator()
            assert isinstance(pass_evaluator, api.Evaluator)
            out_args = api.Arguments.createArguments(0)

            feeder = DataFeeder(data_types, reader_dict)

            for pass_id in xrange(num_passes):
                event_handler(v2_event.BeginPass(pass_id))
                pass_evaluator.start()
                updater.startPass()
                for batch_id, data_batch in enumerate(
                        __data_reader_to_batch__(train_data_reader, batch_size,
                                                 topology)):
                    batch_evaluator.start()
                    event_handler(
                        v2_event.BeginIteration(
                            pass_id=pass_id, batch_id=batch_id))
                    pass_type = updater.startBatch(len(data_batch))
                    gm.forwardBackward(feeder(data_batch), out_args, pass_type)
                    gm.eval(pass_evaluator)
                    gm.eval(batch_evaluator)
                    for each_param in gm.getParameters():
                        updater.update(each_param)
                    # Average cost of the batch: sum the cost slot with numpy
                    # and divide by the number of samples in the batch.
                    cost_vec = out_args.getSlotValue(0).copyToNumpyMat()
                    cost = cost_vec.sum() / len(data_batch)
                    updater.finishBatch(cost)
                    batch_evaluator.finish()
                    event_handler(
                        v2_event.EndIteration(
                            pass_id=pass_id,
                            batch_id=batch_id,
                            cost=cost,
                            evaluator=batch_evaluator))

                updater.finishPass()
                pass_evaluator.finish()
                event_handler(
                    v2_event.EndPass(pass_id, evaluator=pass_evaluator))
        finally:
            gm.finish()


Y
Yu Yang 已提交
153 154 155 156 157 158 159 160
def __data_reader_to_batch__(reader, batch_size, topology):
    """
    Turn a sample reader into a generator of mini-batches.

    Each item produced by ``reader()`` is indexed by the topology's input
    layer names and converted into a list in that order. The reordered items
    are then grouped into batches of ``batch_size``.

    This function is not important, and will be removed when data refactored.

    :param reader: callable returning an iterable of samples; each sample is
                   indexed by input layer name.
    :param batch_size: number of samples per batch.
    :param topology: object whose ``proto().input_layer_names`` gives the
                     input order.
    :return: generator of lists of reordered samples.
    """

    def input_reorder(func):
        # Build the proto once per pass instead of once per sample.
        # NOTE(review): assumes topology.proto() is stable during a pass.
        layer_names = list(topology.proto().input_layer_names)
        for item in func():
            yield [item[layer_name] for layer_name in layer_names]

    return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)


Y
Yu Yang 已提交
168
def __generator_to_batch__(generator, batch_size):
    """
    Group the items of ``generator`` into lists of ``batch_size`` items.

    The last batch may be shorter; an empty batch is never yielded.

    This function is not important, and will be removed when data refactored.

    :param generator: any iterable of samples.
    :param batch_size: positive number of items per batch.
    :raises ValueError: if ``batch_size`` is smaller than 1. Because this is
                        a generator function, the error surfaces when the
                        returned generator is first advanced.
    """
    # Without this guard a batch size < 1 never matches len(batch), so every
    # item would silently end up in one giant batch.
    if batch_size < 1:
        raise ValueError('batch_size should be a positive integer')
    batch = []
    for item in generator:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def __check_train_args__(train_data_reader, topology, parameters,
                         test_data_reader, event_handler, **kwargs):
    """
    Check train function's argument types.

    Both readers are *called* once here so that the returned object can be
    checked to be an iterator; that object is then discarded.

    :param train_data_reader: callable that must return an iterator.
    :param topology: must be a ``v2_topology.Topology``.
    :param parameters: must be a ``v2_parameters.Parameters``.
    :param test_data_reader: ``None`` or a callable returning an iterator.
    :param event_handler: must be callable.
    :param kwargs: any other keyword arguments; accepted and ignored.
    :raises ValueError: on the first argument that fails its check.
    """
    # ``collections.Iterator`` is deprecated since Python 3.3 and removed in
    # 3.10; use the ``collections.abc`` location, falling back for Python 2.
    try:
        from collections.abc import Iterator
    except ImportError:
        from collections import Iterator

    if not callable(train_data_reader) or not isinstance(train_data_reader(),
                                                         Iterator):
        raise ValueError('train_data_reader should be a function, '
                         'which can return a iterator')

    if test_data_reader is not None:
        if not callable(test_data_reader) or not isinstance(
                test_data_reader(), Iterator):
            raise ValueError('test_data_reader should be a function, which can '
                             'return a iterator')

    if not isinstance(topology, v2_topology.Topology):
        raise ValueError('topology should be a model config')

    if not isinstance(parameters, v2_parameters.Parameters):
        raise ValueError('parameters should be a parameter pool')

    if not callable(event_handler):
        raise ValueError('event handler should be a function')