import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from . import parameters as v2_parameters
from . import optimizer as v2_optimizer
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter

__all__ = ['ITrainer', 'SGDTrainer', 'CompleteTrainOneBatch', 'BaseEvent']


class BaseEvent(object):
    """
    Base marker class for all training events.
    """
    pass


class CompleteTrainOneBatch(BaseEvent):
    """
    Event fired when training on one batch has completed.
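
    A typical event handler might look like the sketch below (the printed
    message format is illustrative only)::

        def event_handler(event):
            if isinstance(event, CompleteTrainOneBatch):
                print "Pass %d, Batch %d, Cost %f" % (
                    event.pass_id, event.batch_id, event.cost)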
    """

    def __init__(self, pass_id, batch_id, cost, parameters):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        self.parameters = parameters


def default_event_handler(event):
    pass


class ITrainer(object):
    def train(self,
              train_data_reader,
              topology,
              parameters,
              test_data_reader=None,
              event_handler=None):
        raise NotImplementedError()


class LazyParameterPool(v2_parameters.IParameterPool):
    """
    LazyParameterPool stores a reference to a GradientMachine. The user can
    invoke `get_parameter` when needed, but the operation is lazy: a parameter
    is only fetched from the GPU or the Parameter Server when `get_parameter`
    is invoked. Also, passing a writable flag causes an extra host-to-device
    copy after the parameter has been read or modified.

    This class is not exposed to users; it should be treated as a normal
    IParameterPool.

    See IParameterPool for usage documentation.
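
    A minimal sketch of reading a parameter (the parameter name ``'w'`` is
    purely illustrative)::

        pool = event.parameters  # inside an event handler
        w = pool.get_parameter('w', v2_parameters.ParameterFlag.READ_ONLY)
        # ``w`` is a numpy array shaped according to the parameter's dims.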

    :type __gradient_machine__: api.GradientMachine
    """

    def get_parameter(self, name, flag=v2_parameters.ParameterFlag.READ_WRITE):
        param = filter(lambda x: x.getName() == name,
                       self.__gradient_machine__.getParameters())
        if len(param) == 0:
            raise ValueError("Cannot find parameter with name %s" % name)
        elif len(param) > 1:
            raise RuntimeError("Unexpected branch")
        else:
            conf = param[0].getConfig().toProto()
            param = param[0].getBuf(api.PARAMETER_VALUE)
            assert isinstance(param, api.Vector)
            assert isinstance(conf, ParameterConfig)

        shape = map(int, conf.dims)
        if api.isUsingGpu():
            # On GPU the value buffer lives on the device, so copy it to host.
            arr = param.copyToNumpyArray().reshape(shape)
            if flag & v2_parameters.ParameterFlag.WRITE_ONLY:
                # Remember the host copy so it can be written back to the
                # device after the user has modified it.
                self.need_copy = True
                self.arrays[name] = arr
        else:
            # On CPU the numpy array shares memory with the parameter buffer,
            # so modifications take effect in place.
            arr = param.toNumpyArrayInplace().reshape(shape)
        return arr

    def get_names(self):
        return [
            param.getName()
            for param in self.__gradient_machine__.getParameters()
        ]

    def __init__(self, gradient_machine):
        self.__gradient_machine__ = gradient_machine
        self.need_copy = False
        self.arrays = dict()


class SGDTrainer(ITrainer):
    def __init__(self, update_equation):
        """
        Simple SGD Trainer.

        :param update_equation: The optimizer used to update parameters, an
                                instance of paddle.v2.optimizer.Optimizer.
                                Maybe we should give a DSL for the update
                                equation?
        """
        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise ValueError("update_equation must be an instance of "
                             "paddle.v2.optimizer.Optimizer")
        self.__optimizer__ = update_equation

    def train(self,
              train_data_reader,
              topology,
              parameters,
              num_passes=1,
              test_data_reader=None,
              event_handler=None,
              batch_size=32,
              data_types=None):
        """
        Training method. Trains the network for num_passes passes over the
        input data.

        :param train_data_reader: A callable that returns an iterator over the
                                  training data.
        :param topology: Network Topology, a protobuf ModelConfig message.
        :param parameters: The parameter pool.
        :param num_passes: The total number of training passes.
        :param test_data_reader: A callable that returns an iterator over the
                                 test data, or None.
        :param event_handler: Event handler. A method that will be invoked when
                              an event occurs.
        :type event_handler: (BaseEvent) => None
        :param batch_size: Number of samples per batch. Not important, will be
                           removed after data refactor.
        :param data_types: A dict mapping input layer names to their data
                           types. Not important, will be removed after data
                           refactor.
        :return: None
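
        A hedged usage sketch (``my_optimizer``, ``my_reader``, ``model_config``,
        ``my_parameter_pool`` and ``input_types`` are placeholders supplied by
        the caller, not names defined in this module)::

            trainer = SGDTrainer(update_equation=my_optimizer)
            trainer.train(train_data_reader=my_reader,
                          topology=model_config,
                          parameters=my_parameter_pool,
                          num_passes=2,
                          data_types=input_types)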
        """
        if event_handler is None:
            event_handler = default_event_handler

        __check_train_args__(**locals())

        gm = api.GradientMachine.createFromConfigProto(
            topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        __copy_parameter_from_pool__(gm, parameters)

        updater = self.__optimizer__.create_local_updater()
        updater.init(gm)

        gm.start()
        out_args = api.Arguments.createArguments(0)

        # Collect the data type of every input layer, in topology order.
        data_types_lists = []
        for each in topology.input_layer_names:
            if each not in data_types:
                raise ValueError("No data type given for input layer %s" %
                                 each)
            data_types_lists.append(data_types[each])

        converter = DataProviderConverter(input_types=data_types_lists)

        for pass_id in xrange(num_passes):
            updater.startPass()
            for batch_id, data_batch in enumerate(
                    __data_reader_to_batch__(train_data_reader, batch_size,
                                             topology)):
                # Run a combined forward/backward pass for this batch, then
                # let the updater apply the gradients to every parameter.
                pass_type = updater.startBatch(len(data_batch))
                gm.forwardBackward(converter(data_batch), out_args, pass_type)
                for each_param in gm.getParameters():
                    updater.update(each_param)
                # Get cost. We use numpy to calculate total cost for this batch.
                cost_vec = out_args.getSlotValue(0)
                cost_vec = cost_vec.copyToNumpyMat()
                cost = cost_vec.sum() / len(data_batch)
                updater.finishBatch(cost)
                pool = LazyParameterPool(gradient_machine=gm)
                event_handler(
                    CompleteTrainOneBatch(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        parameters=pool))

                # If the handler modified parameters on GPU, write them back.
                if pool.need_copy:
                    __copy_parameter_from_lazy_pool__(gm, pool)

            updater.finishPass()
        gm.finish()


def __data_reader_to_batch__(reader, batch_size, topology):
    """
    This function is not important, and will be removed once the data layer is
    refactored.
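
    A hedged illustration (``'image'`` and ``'label'`` are hypothetical input
    layer names)::

        def reader():
            yield {'image': [0.1, 0.2], 'label': [1]}
            yield {'image': [0.3, 0.4], 'label': [0]}

        # With input_layer_names == ['image', 'label'] this produces batches
        # of [image, label] lists, e.g.
        # [[[0.1, 0.2], [1]], [[0.3, 0.4], [0]]] for batch_size=2.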
    """

    def input_reorder(func):
        for item in func():
            retv = []
            for __layer_name__ in topology.input_layer_names:
                retv.append(item[__layer_name__])
            yield retv

    return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)


def __generator_to_batch__(generator, batch_size):
    """
    This function is not important, and will be removed once the data layer is
    refactored.
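
    For example, a five-item iterator batched with ``batch_size=2`` yields two
    full batches and one partial one::

        >>> list(__generator_to_batch__(iter(range(5)), batch_size=2))
        [[0, 1], [2, 3], [4]]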
    """
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val


def __copy_parameter_from_lazy_pool__(gm, pool):
    """
    Write parameters that were modified through a LazyParameterPool back into
    the GradientMachine.
    """
    assert isinstance(pool, LazyParameterPool)
    for each_param_name in pool.arrays.keys():
        param = filter(lambda x: x.getName() == each_param_name,
                       gm.getParameters())
        assert len(param) == 1
        param = param[0]
        param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(pool.arrays[
            each_param_name].flatten().astype('float32'))


def __copy_parameter_from_pool__(gm, pool):
    """

    :param gm:
    :type gm: api.GradientMachine
    :param pool:
    :type pool: v2_parameters.IParameterPool
    :return: None
    """
    assert isinstance(pool, v2_parameters.IParameterPool)
    for each_param in gm.getParameters():
        name = each_param.getName()
        param = pool.get_parameter(name, v2_parameters.ParameterFlag.READ_ONLY)
        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(param.flatten(
        ).astype('float32'))


def __check_train_args__(train_data_reader, topology, parameters,
                         test_data_reader, event_handler, **kwargs):
    """
    Check train function's argument types
    """
    if not callable(train_data_reader) or not isinstance(train_data_reader(),
                                                         collections.Iterator):
        raise ValueError('train_data_reader should be a function '
                         'that returns an iterator')

    if test_data_reader is not None:
        if not callable(test_data_reader) or not isinstance(
                test_data_reader(), collections.Iterator):
            raise ValueError('test_data_reader should be a function '
                             'that returns an iterator')

    if not isinstance(topology, ModelConfig):
        raise ValueError('topology should be a ModelConfig protobuf message')

    if not isinstance(parameters, v2_parameters.IParameterPool):
        raise ValueError('parameters should be a parameter pool')

    if not callable(event_handler):
        raise ValueError('event handler should be a function')