import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from . import parameters as v2_parameters
import numpy
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter

__all__ = ['ITrainer', 'SGDTrainer', 'CompleteTrainOneBatch', 'BaseEvent']


class BaseEvent(object):
    """
    Just a marker class
    """
    pass


class CompleteTrainOneBatch(BaseEvent):
    """
    Event On One Batch Training Complete.
    """

    def __init__(self, pass_id, batch_id, cost, parameters):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        self.parameters = parameters


def default_event_handler(event):
    pass
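

# Illustrative sketch (not part of this module's API): any callable that takes
# a single BaseEvent argument can serve as an event handler, for example:
def print_cost_event_handler(event):
    if isinstance(event, CompleteTrainOneBatch):
        print 'Pass %d, Batch %d, Cost %f' % (event.pass_id, event.batch_id,
                                              event.cost)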


class ITrainer(object):
    """
    The interface of a trainer. Subclasses implement the train method.
    """

    def train(self,
              train_data_reader,
              topology,
              parameters,
              test_data_reader=None,
              event_handler=None):
        raise NotImplementedError()


class LazyParameterPool(v2_parameters.IParameterPool):
    """
    :type __gradient_machine__: api.GradientMachine
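
    Exposes the parameters of a GradientMachine as numpy arrays on demand.
    On GPU the array is a copy, written back after the event handler runs if
    it was requested with a write flag; on CPU it is an in-place view of the
    parameter buffer.

    Illustrative usage sketch ('gm' is an api.GradientMachine instance and
    'fc.w0' is a placeholder parameter name):

        pool = LazyParameterPool(gradient_machine=gm)
        weights = pool.get_parameter(
            'fc.w0', flag=v2_parameters.ParameterFlag.READ_ONLY)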
    """

    def get_parameter(self, name, flag=v2_parameters.ParameterFlag.READ_WRITE):
        param = filter(lambda x: x.getName() == name,
                       self.__gradient_machine__.getParameters())
        if len(param) == 0:
            raise ValueError("Cannot find parameter with name %s" % name)
        elif len(param) > 1:
            raise RuntimeError("Duplicated parameter name %s" % name)
        else:
            conf = param[0].getConfig().toProto()
            param = param[0].getBuf(api.PARAMETER_VALUE)
            assert isinstance(param, api.Vector)
            assert isinstance(conf, ParameterConfig)

        shape = map(int, conf.dims)
        if api.isUsingGpu():
            # GPU buffers cannot be viewed in place, so copy to a numpy array
            # and, if the caller may write to it, record that it must be
            # copied back later.
            arr = param.copyToNumpyArray().reshape(shape)
            if flag & v2_parameters.ParameterFlag.WRITE_ONLY:
                self.need_copy = True
                self.arrays[name] = arr
        else:
            # On CPU the numpy array shares memory with the parameter buffer.
            arr = param.toNumpyArrayInplace().reshape(shape)
        return arr

    def get_names(self):
        return [
            param.getName()
            for param in self.__gradient_machine__.getParameters()
        ]

    def __init__(self, gradient_machine):
        self.__gradient_machine__ = gradient_machine
        self.need_copy = False
        self.arrays = dict()


class CustomizeUpdateEquation(object):
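    """
    Wraps a plain Python callable as an update equation. For every parameter,
    the callable receives the parameter value and its gradient as numpy
    arrays (updated in place), followed by one extra numpy array per
    additional callback argument, kept as per-parameter local state.

    A minimal sketch of such a callback (the 0.01 learning rate is only an
    illustrative choice):

        def sgd_update(value, gradient):
            value -= 0.01 * gradient
    """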
    def __init__(self, callback):
        self.__callback__ = callback
        if self.__callback__.func_code.co_argcount < 2:
            raise ValueError(
                "The update equation should take at least two arguments: "
                "the parameter value first, then its gradient")

        self.local_params_count = self.__callback__.func_code.co_argcount - 2
        self.local_params = dict()

    def enable_types(self):
        return [api.PARAMETER_VALUE, api.PARAMETER_GRADIENT]

    def init(self, gradient_machine):
        assert isinstance(gradient_machine, api.GradientMachine)
        for param in gradient_machine.getParameters():
            conf = param.getConfig().toProto()
            shape = map(int, conf.dims)
            self.local_params[conf.name] = []
            for _ in xrange(self.local_params_count):
                self.local_params[conf.name].append(
                    numpy.zeros(
                        shape=shape, dtype='float32'))

    def create_local_updater(self):
        return self

    def startPass(self):
        pass

    def finishPass(self):
        pass

    def startBatch(self, batch_size):
        return api.PASS_TRAIN

    def finishBatch(self, cost):
        pass

    def update(self, param):
        conf = param.getConfig().toProto()
        shape = map(int, conf.dims)
        if not api.isUsingGpu():
            v = param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace().reshape(
                shape)
            g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
            ).reshape(shape)
            args = [v, g]
            for arg in self.local_params[conf.name]:
                args.append(arg)
            self.__callback__(*args)
        else:
            raise NotImplementedError()


class SGDTrainer(ITrainer):
    def __init__(self, update_equation):
        """
        Simple SGD Trainer.

        :param update_equation: Maybe we should give a DSL for update equation?
        """
        if callable(update_equation):
            update_equation = CustomizeUpdateEquation(update_equation)

        self.__optimizer__ = update_equation

    def train(self,
              train_data_reader,
              topology,
              parameters,
              num_passes=1,
              test_data_reader=None,
              event_handler=None,
              batch_size=32,
              data_types=None):
        """
        Training method. Will train num_passes of input data.

        :param train_data_reader:
        :param topology: Network Topology, a protobuf ModelConfig message.
        :param parameters: The parameter pools.
        :param num_passes: The total train passes.
        :param test_data_reader:
        :param event_handler: Event handler. A method will be invoked when event
                              occurred.
        :type event_handler: (BaseEvent) => None
        :param batch_size: Not important, will be removed after data refactor.
        :param data_types: Not important, will be removed after data refactor.
        :return:
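
        Illustrative usage sketch (my_update_equation, my_reader, model_config,
        pool, my_event_handler and my_data_types are placeholder names, not
        defined in this module):

            trainer = SGDTrainer(update_equation=my_update_equation)
            trainer.train(train_data_reader=my_reader,
                          topology=model_config,
                          parameters=pool,
                          num_passes=2,
                          event_handler=my_event_handler,
                          batch_size=128,
                          data_types=my_data_types)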
        """
        if event_handler is None:
            event_handler = default_event_handler

        __check_train_args__(**locals())

        gm = api.GradientMachine.createFromConfigProto(
            topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        __copy_parameter_from_pool__(gm, parameters)

        updater = self.__optimizer__.create_local_updater()
        updater.init(gm)

        gm.start()
        out_args = api.Arguments.createArguments(0)

        data_types_lists = []
        for each in topology.input_layer_names:
            if each not in data_types:
                raise ValueError("No data type specified for input layer %s" %
                                 each)
            data_types_lists.append(data_types[each])

        converter = DataProviderConverter(input_types=data_types_lists)

        for pass_id in xrange(num_passes):
            updater.startPass()
            for batch_id, data_batch in enumerate(
                    __data_reader_to_batch__(train_data_reader, batch_size,
                                             topology)):
                pass_type = updater.startBatch(len(data_batch))
                gm.forwardBackward(converter(data_batch), out_args, pass_type)
                for each_param in gm.getParameters():
                    updater.update(each_param)
                # Get cost. Use numpy to compute the average cost over this batch.
                cost_vec = out_args.getSlotValue(0)
                cost_vec = cost_vec.copyToNumpyMat()
                cost = cost_vec.sum() / len(data_batch)
                updater.finishBatch(cost)
                pool = LazyParameterPool(gradient_machine=gm)
                event_handler(
                    CompleteTrainOneBatch(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        parameters=pool))

                if pool.need_copy:
                    __copy_parameter_from_lazy_pool__(gm, pool)

            updater.finishPass()
        gm.finish()


def __data_reader_to_batch__(reader, batch_size, topology):
    """
    This function is not important, and will be removed when data refactored.
    """

    def input_reorder(func):
        for item in func():
            retv = []
            for __layer_name__ in topology.input_layer_names:
                retv.append(item[__layer_name__])
            yield retv

    return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)


def __generator_to_batch__(generator, batch_size):
    """
    This function is not important, and will be removed when data refactored.
    """
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val


def __copy_parameter_from_lazy_pool__(gm, pool):
    """
    Copy parameters that the event handler wrote into the lazy pool back into
    the gradient machine.
    """
    assert isinstance(pool, LazyParameterPool)
    for each_param_name in pool.arrays.keys():
        param = filter(lambda x: x.getName() == each_param_name,
                       gm.getParameters())
        assert len(param) == 1
        param = param[0]
        param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(pool.arrays[
            each_param_name].flatten().astype('float32'))


def __copy_parameter_from_pool__(gm, pool):
    """

    :param gm:
    :type gm: api.GradientMachine
    :param pool:
    :type pool: v2_parameters.IParameterPool
    :return:
    """
    assert isinstance(pool, v2_parameters.IParameterPool)
    for each_param in gm.getParameters():
        name = each_param.getName()
        param = pool.get_parameter(name, v2_parameters.ParameterFlag.READ_ONLY)
        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(param.flatten(
        ).astype('float32'))


def __check_train_args__(train_data_reader, topology, parameters,
                         test_data_reader, event_handler, **kwargs):
    """
    Check train function's argument types
    """
    if not callable(train_data_reader) or not isinstance(train_data_reader(),
                                                         collections.Iterator):
        raise ValueError('train_data_reader should be a function '
                         'which returns an iterator')

    if test_data_reader is not None:
        if not callable(test_data_reader) or not isinstance(
                test_data_reader(), collections.Iterator):
            raise ValueError('test_data_reader should be a function '
                             'which returns an iterator')

    if not isinstance(topology, ModelConfig):
        raise ValueError('topology should be a model config')

    if not isinstance(parameters, v2_parameters.IParameterPool):
        raise ValueError('parameters should be a parameter pool')

    if not callable(event_handler):
        raise ValueError('event handler should be a function')