trainer.py 4.2 KB
Newer Older
Y
Yu Yang 已提交
1
import collections
Y
Yu Yang 已提交
2

Y
Yu Yang 已提交
3
import py_paddle.swig_paddle as api
Q
qiaolongfei 已提交
4
from paddle.proto.ModelConfig_pb2 import ModelConfig
D
dangqingqing 已提交
5
from data_feeder import DataFeeder
Y
Yu Yang 已提交
6

Q
qiaolongfei 已提交
7 8
from . import event as v2_event
from . import layer as v2_layer
Y
Yu Yang 已提交
9 10 11
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters

Y
Yu Yang 已提交
12
__all__ = ['ITrainer', 'SGD']
Y
Yu Yang 已提交
13 14 15


def default_event_handler(event):
Y
Yu Yang 已提交
16 17 18 19 20 21 22
    """
    Default event handler. It will print some log and save mode.

    TODO(yuyang18): Complete it!
    :param event:
    :return:
    """
Y
Yu Yang 已提交
23 24 25 26
    pass


class ITrainer(object):
Y
Yu Yang 已提交
27 28 29 30
    """
    The interface of Trainer. The only exposed method is `train`.
    """

Y
Yu Yang 已提交
31
    def train(self, reader, topology, parameters, event_handler=None):
Y
Yu Yang 已提交
32 33 34
        """
        train method.

Y
Yu Yang 已提交
35
        :param reader:
Y
Yu Yang 已提交
36 37 38 39 40 41
        :param topology:
        :param parameters:
        :param event_handler:
        :return:
        """

Y
Yu Yang 已提交
42 43 44
        raise NotImplementedError()


Y
Yu Yang 已提交
45
class SGD(ITrainer):
Y
Yu Yang 已提交
46
    def __init__(self, update_equation):
Y
Yu Yang 已提交
47 48 49
        """
        Simple SGD Trainer.

Y
Yu Yang 已提交
50 51
        :param update_equation: The optimizer object.
        :type update_equation: v2_optimizer.Optimizer
Y
Yu Yang 已提交
52
        """
Y
Yu Yang 已提交
53 54 55
        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise ValueError("update equation parameter must be "
                             "paddle.v2.optimizer.Optimizer")
Y
Yu Yang 已提交
56 57 58
        self.__optimizer__ = update_equation

    def train(self,
Y
Yu Yang 已提交
59
              reader,
Y
Yu Yang 已提交
60 61 62 63
              topology,
              parameters,
              num_passes=1,
              event_handler=None,
D
dangqingqing 已提交
64 65
              data_types=None,
              reader_dict=None):
Y
Yu Yang 已提交
66 67 68
        """
        Training method. Will train num_passes of input data.

Y
Yu Yang 已提交
69
        :param reader:
Q
qiaolongfei 已提交
70
        :param topology: Network Topology, use one or more Layers to represent it.
Y
Yu Yang 已提交
71 72 73 74 75 76 77 78
        :param parameters: The parameter pools.
        :param num_passes: The total train passes.
        :param event_handler: Event handler. A method will be invoked when event
                              occurred.
        :type event_handler: (BaseEvent) => None
        :param data_types: Not important, will be removed after data refactor.
        :return:
        """
Y
Yu Yang 已提交
79 80 81
        if event_handler is None:
            event_handler = default_event_handler

Q
qiaolongfei 已提交
82 83
        topology = v2_layer.parse_network(topology)

Y
Yu Yang 已提交
84 85 86 87 88
        __check_train_args__(**locals())

        gm = api.GradientMachine.createFromConfigProto(
            topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
Y
Yu Yang 已提交
89
        parameters.append_gradient_machine(gm)
Y
Yu Yang 已提交
90 91 92 93

        updater = self.__optimizer__.create_local_updater()
        updater.init(gm)

Y
Yu Yang 已提交
94 95 96
        gm.start()
        out_args = api.Arguments.createArguments(0)

D
dangqingqing 已提交
97
        feeder = DataFeeder(data_types, reader_dict)
Y
Yu Yang 已提交
98 99 100

        for pass_id in xrange(num_passes):
            updater.startPass()
Y
Yu Yang 已提交
101
            for batch_id, data_batch in enumerate(reader()):
Y
Yu Yang 已提交
102
                pass_type = updater.startBatch(len(data_batch))
D
dangqingqing 已提交
103
                gm.forwardBackward(feeder(data_batch), out_args, pass_type)
Y
Yu Yang 已提交
104 105 106 107 108 109 110 111
                for each_param in gm.getParameters():
                    updater.update(each_param)
                # Get cost. We use numpy to calculate total cost for this batch.
                cost_vec = out_args.getSlotValue(0)
                cost_vec = cost_vec.copyToNumpyMat()
                cost = cost_vec.sum() / len(data_batch)
                updater.finishBatch(cost)
                event_handler(
Y
Yu Yang 已提交
112
                    v2_event.EndIteration(
Y
Yu Yang 已提交
113
                        pass_id=pass_id, batch_id=batch_id, cost=cost))
Y
Yu Yang 已提交
114 115 116 117 118

            updater.finishPass()
        gm.finish()


Y
Yu Yang 已提交
119
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
Y
Yu Yang 已提交
120 121 122
    """
    Check train function's argument types
    """
Y
Yu Yang 已提交
123
    if not callable(reader) or not isinstance(reader(), collections.Iterator):
Y
Yu Yang 已提交
124 125
        raise TypeError('train_data_reader should be a function, '
                        'which can return a iterator')
Y
Yu Yang 已提交
126 127

    if not isinstance(topology, ModelConfig):
Y
Yu Yang 已提交
128
        raise TypeError('topology should be a model config')
Y
Yu Yang 已提交
129

Y
Yu Yang 已提交
130
    if not isinstance(parameters, v2_parameters.Parameters):
Y
Yu Yang 已提交
131
        raise TypeError('parameters should be a parameter pool')
Y
Yu Yang 已提交
132 133

    if not callable(event_handler):
Y
Yu Yang 已提交
134
        raise TypeError('event handler should be a function')