parameters.py 11.2 KB
Newer Older
Y
Yu Yang 已提交
1
import numpy as np
Q
qiaolongfei 已提交
2
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
X
xuwei06 已提交
3
import paddle.trainer.config_parser as cp
Y
Yu Yang 已提交
4 5 6
import struct
import tarfile
import cStringIO
Q
qiaolongfei 已提交
7
from topology import Topology
Q
qiaolongfei 已提交
8

Y
Yu Yang 已提交
9
__all__ = ['Parameters', 'create']
Y
Yu Yang 已提交
10 11


Q
qiaolongfei 已提交
12
def create(layers):
Y
Yu Yang 已提交
13
    """
Q
qiaolongfei 已提交
14
    Create parameter pool by topology.
Y
Yu Yang 已提交
15

Q
qiaolongfei 已提交
16
    :param layers:
Y
Yu Yang 已提交
17
    :return:
Y
Yu Yang 已提交
18
    """
Q
qiaolongfei 已提交
19
    topology = Topology(layers)
Q
qiaolongfei 已提交
20
    pool = Parameters()
X
xuwei06 已提交
21
    initializers = cp.g_parameter_initializer_map
Q
qiaolongfei 已提交
22
    for param in topology.proto().parameters:
Q
qiaolongfei 已提交
23
        pool.__append_config__(param)
X
xuwei06 已提交
24 25
        if param.name in initializers:
            pool[param.name] = initializers[param.name](param.name)
Y
Yu Yang 已提交
26
    return pool
Y
Yu Yang 已提交
27 28


Y
Yu Yang 已提交
29
class Parameters(object):
Y
Yu Yang 已提交
30
    """
Y
Yu Yang 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
    Parameters is a dictionary contains Paddle's parameter. The key of
    Parameters is the name of parameter. The value of Parameters is a plain
    :code:`numpy.ndarry` .

    Basically usage is

    ..  code-block:: python

        data = paddle.layers.data(...)
        ...
        out = paddle.layers.fc(...)

        parameters = paddle.parameters.create(out)

        parameter_names = parameters.names()
        fc_mat = parameters.get('fc')
        print fc_mat
Y
Yu Yang 已提交
48 49
    """

Y
Yu Yang 已提交
50
    def __init__(self):
Y
Yu Yang 已提交
51 52
        self.__param_conf__ = dict()
        self.__gradient_machines__ = []
53
        self.__tmp_params__ = dict()
Y
Yu Yang 已提交
54

Y
Yu Yang 已提交
55 56 57 58 59 60 61 62 63 64
    def __append_config__(self, param_conf):
        """
        Append a parameter configuration. It used to initialize Parameters and
        should be invoked only in paddle.parameters.create

        :param param_conf: The parameter configuration in protobuf
        :type param_conf: ParameterConfig
        :return: Nothing
        """

Y
Yu Yang 已提交
65 66 67 68 69 70 71 72 73
        if not isinstance(param_conf, ParameterConfig):
            raise ValueError("param_conf must be paddle.proto.ParameterConfig")

        if param_conf.name in self.__param_conf__:
            raise ValueError("duplicated parameter %s" % param_conf.name)

        self.__param_conf__[param_conf.name] = param_conf

    def keys(self):
Y
Yu Yang 已提交
74 75
        """
        keys are the names of each parameter.
Y
Yu Yang 已提交
76

Y
Yu Yang 已提交
77 78 79
        :return: list of parameter name
        :rtype: list
        """
Y
Yu Yang 已提交
80 81 82
        return self.__param_conf__.keys()

    def names(self):
Y
Yu Yang 已提交
83 84
        """
        names of each parameter.
Y
Yu Yang 已提交
85

Y
Yu Yang 已提交
86 87 88
        :return: list of parameter name
        :rtype: list
        """
Y
Yu Yang 已提交
89 90 91
        return self.keys()

    def has_key(self, key):
Y
Yu Yang 已提交
92 93
        """
        has_key return true if there are such parameter name == key
Y
Yu Yang 已提交
94

Y
Yu Yang 已提交
95 96 97 98
        :param key: Parameter name
        :type key: basestring
        :return: True if contains such key
        """
Y
Yu Yang 已提交
99 100
        return key in self.__param_conf__.keys()

Y
Yu Yang 已提交
101
    def __iter__(self):
Y
Yu Yang 已提交
102 103 104 105 106 107 108 109 110 111 112 113
        """
        Return an iterator of parameter name. It is used by `for loop`
        or `in` operator.

        ..  code-block:: python

            parameters = paddle.parameters.create(...)
            if "fc_param" in parameters:
                print 'OK'
        :return: an iterator of parameter name
        :rtype: iterator
        """
Y
Yu Yang 已提交
114 115
        return iter(self.__param_conf__)

Y
Yu Yang 已提交
116
    def __getitem__(self, key):
Y
Yu Yang 已提交
117 118 119 120 121 122 123 124 125
        """
        Get parameter by parameter name. It uses Python dict syntax.

        :note: It will always copy the parameter from C++ side.
        :param key: Parameter name
        :type key: basestring
        :return: parameter value
        :rtype: np.ndarray
        """
Y
Yu Yang 已提交
126
        import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
127 128 129 130
        shape = self.get_shape(key)

        if len(self.__gradient_machines__) == 0:
            # create new parameter in python numpy.
131 132 133 134
            if key in self.__tmp_params__:
                return self.__tmp_params__[key]
            else:
                return np.ndarray(shape=shape, dtype=np.float32)
Y
Yu Yang 已提交
135 136 137 138 139 140 141 142
        else:
            for each_gradient_machine in self.__gradient_machines__:
                param = __get_parameter_in_gradient_machine__(
                    each_gradient_machine, key)
                # for simplify implementation now, we always copy from C++
                assert isinstance(param, api.Parameter)
                val = param.getBuf(api.PARAMETER_VALUE)
                assert isinstance(val, api.Vector)
Y
Yu Yang 已提交
143 144
                val = val.copyToNumpyArray()
                return val
Y
Yu Yang 已提交
145 146 147 148 149
                # else continue

            raise RuntimeError("Unexpected branch")

    def get_shape(self, key):
Y
Yu Yang 已提交
150 151
        """
        get shape of the parameter.
Y
Yu Yang 已提交
152

Y
Yu Yang 已提交
153 154 155 156 157
        :param key: parameter name
        :type key: basestring
        :return: parameter's shape
        :rtype: tuple
        """
Y
Yu Yang 已提交
158 159 160 161 162
        if not isinstance(key, basestring):
            raise ValueError("parameter name should be string")
        if not self.has_key(key):
            raise ValueError("No such parameter %s" % key)
        conf = self.__param_conf__[key]
D
dangqingqing 已提交
163 164
        dims = conf.dims if conf.dims else (1, conf.size)
        return tuple(map(int, dims))
Y
Yu Yang 已提交
165 166

    def __setitem__(self, key, value):
Y
Yu Yang 已提交
167 168 169 170 171 172 173 174 175 176 177
        """
        Set parameter by parameter name & value. It use Python dict syntax.

        :note: It will always copy the parameter to C++ side.
        :param key: Parameter name
        :type key: basestring
        :param value: Parameter matrix.
        :type value: np.ndarray
        :return: Nothing
        """

Y
Yu Yang 已提交
178 179 180 181
        if not isinstance(value, np.ndarray):
            raise ValueError("Must return ndarray")
        value = value.astype(dtype=np.float32)
        shape = self.get_shape(key)
Y
Yu Yang 已提交
182
        if value.shape != shape:
Y
Yu Yang 已提交
183 184 185 186
            raise ValueError("Value shape mismatch, expect %s, should %s" %
                             (shape, value.shape))

        if len(self.__gradient_machines__) == 0:
187
            self.__tmp_params__[key] = value
Y
Yu Yang 已提交
188 189 190 191 192
        else:
            for each_gradient_machine in self.__gradient_machines__:
                __copy_parameter_to_gradient_machine__(each_gradient_machine,
                                                       key, value)

Y
Yu Yang 已提交
193
    def get(self, parameter_name):
Y
Yu Yang 已提交
194 195 196 197 198 199 200 201 202
        """
        Get parameter by parameter name.

        :note: It will always copy the parameter from C++ side.
        :param parameter_name: parameter name
        :type parameter_name: basestring
        :return: The parameter matrix.
        :rtype: np.ndarray
        """
Y
Yu Yang 已提交
203 204 205
        return self.__getitem__(key=parameter_name)

    def set(self, parameter_name, value):
Y
Yu Yang 已提交
206 207
        """
        Set parameter by parameter name & matrix.
Y
Yu Yang 已提交
208

Y
Yu Yang 已提交
209 210 211 212 213 214
        :param parameter_name: parameter name
        :type parameter_name: basestring
        :param value: parameter matrix
        :type value: np.ndarray
        :return: Nothing.
        """
Y
Yu Yang 已提交
215 216
        self.__setitem__(key=parameter_name, value=value)

Y
Yu Yang 已提交
217
    def append_gradient_machine(self, gradient_machine):
Y
Yu Yang 已提交
218 219 220 221 222 223 224 225
        """
        append gradient machine to parameters. This method is used internally in
        Trainer.train.

        :param gradient_machine: Paddle C++ GradientMachine object.
        :type gradient_machine: api.GradientMachine
        :return:
        """
Y
Yu Yang 已提交
226
        import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
227 228 229 230
        if not isinstance(gradient_machine, api.GradientMachine):
            raise ValueError("gradient_machine should be api.GradientMachine")

        if len(self.__tmp_params__) != 0:
231
            for name, val in self.__tmp_params__.iteritems():
Y
Yu Yang 已提交
232 233 234 235 236 237
                try:
                    __copy_parameter_to_gradient_machine__(gradient_machine,
                                                           name, val)
                except ValueError:
                    # If no such parameter in gradient machine, then don't copy
                    pass
238 239

        self.__gradient_machines__.append(gradient_machine)
Y
Yu Yang 已提交
240

Y
Yu Yang 已提交
241 242 243 244 245 246 247 248 249 250 251 252
    def serialize(self, name, f):
        """

        :param name:
        :param f:
        :type f: file
        :return:
        """
        param = self.get(name)
        size = reduce(lambda a, b: a * b, param.shape)
        f.write(struct.pack("IIQ", 0, 4, size))
        param = param.astype(np.float32)
253
        f.write(param.tostring())
Y
Yu Yang 已提交
254 255 256 257 258 259 260 261 262 263

    def deserialize(self, name, f):
        """

        :param name:
        :param f:
        :type f: file
        :return:
        """
        f.read(16)  # header
Y
Yu Yang 已提交
264
        arr = np.frombuffer(f.read(), dtype=np.float32)
Y
Yu Yang 已提交
265 266
        self.set(name, arr.reshape(self.get_shape(name)))

Y
Yu Yang 已提交
267
    def to_tar(self, f):
Y
Yu Yang 已提交
268 269 270 271 272 273 274 275 276
        tar = tarfile.TarFile(fileobj=f, mode='w')
        for nm in self.names():
            buf = cStringIO.StringIO()
            self.serialize(nm, buf)
            tarinfo = tarfile.TarInfo(name=nm)
            buf.seek(0)
            tarinfo.size = len(buf.getvalue())
            tar.addfile(tarinfo, buf)

Y
Yu Yang 已提交
277 278 279 280 281 282 283 284 285 286
            conf = self.__param_conf__[nm]
            confStr = conf.SerializeToString()
            tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm)
            tarinfo.size = len(confStr)
            buf = cStringIO.StringIO(confStr)
            buf.seek(0)
            tar.addfile(tarinfo, fileobj=buf)

    @staticmethod
    def from_tar(f):
D
dangqingqing 已提交
287 288 289 290 291 292 293 294 295 296 297 298
        """
        Create a `Parameters` object from the given file. And
        the `Parameters` only contains the parameters in this
        file. It is adapted the parameters are same in the
        defined network and the given file. For example, it
        can be used in the inference.

        :param f: the initialized model file.
        :type f: tar file
        :return: A Parameters object.
        :rtype: Parameters.
        """
Y
Yu Yang 已提交
299 300 301 302 303 304 305 306 307 308 309 310 311 312
        params = Parameters()
        tar = tarfile.TarFile(fileobj=f, mode='r')
        for finfo in tar:
            assert isinstance(finfo, tarfile.TarInfo)
            if finfo.name.endswith('.protobuf'):
                f = tar.extractfile(finfo)
                conf = ParameterConfig()
                conf.ParseFromString(f.read())
                params.__append_config__(conf)

        for param_name in params.names():
            f = tar.extractfile(param_name)
            params.deserialize(param_name, f)
        return params
Y
Yu Yang 已提交
313

314
    def init_from_tar(self, f):
D
dangqingqing 已提交
315 316 317 318 319 320 321 322 323
        """
        Different from `from_tar`, this interface can be used to
        init partial network parameters from another saved model.

        :param f: the initialized model file.
        :type f: tar file
        :return: Nothing.
        """

324
        tar_param = Parameters.from_tar(f)
325 326 327 328
        for pname in tar_param.names():
            if pname in self.names():
                self.set(pname, tar_param.get(pname))

Y
Yu Yang 已提交
329 330 331

def __get_parameter_in_gradient_machine__(gradient_machine, name):
    """
Y
Yu Yang 已提交
332

Y
Yu Yang 已提交
333 334 335 336 337 338 339 340
    :param gradient_machine:
    :type gradient_machine: api.GradientMachine
    :param name:
    :return:
    :rtype: api.Parameter
    """
    params = filter(lambda p: p.getName() == name,
                    gradient_machine.getParameters())
Y
Yu Yang 已提交
341

Y
Yu Yang 已提交
342 343 344 345 346 347
    if len(params) == 0:
        raise ValueError("No such parameter")
    elif len(params) > 1:
        raise ValueError("Unexpected branch")
    else:
        return params[0]
Y
Yu Yang 已提交
348 349


Y
Yu Yang 已提交
350
def __copy_parameter_to_gradient_machine__(gradient_machine, name, arr):
Y
Yu Yang 已提交
351
    """
Y
Yu Yang 已提交
352
    Copy a python ndarray into the gradient machine.
Y
Yu Yang 已提交
353

Y
Yu Yang 已提交
354 355 356 357 358
    :param gradient_machine:
    :type gradient_machine: api.GradientMachine
    :param name:
    :param arr:
    :type arr: np.ndarray
Y
Yu Yang 已提交
359
    :return:
Y
Yu Yang 已提交
360
    :rtype: api.Parameter
Y
Yu Yang 已提交
361
    """
Y
Yu Yang 已提交
362
    import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
363 364 365 366
    param = __get_parameter_in_gradient_machine__(gradient_machine, name)
    vec = param.getBuf(api.PARAMETER_VALUE)
    assert isinstance(vec, api.Vector)
    vec.copyFromNumpyArray(arr.flatten())