parameters.py 9.7 KB
Newer Older
Y
Yu Yang 已提交
1
import numpy as np
Y
Yu Yang 已提交
2
import py_paddle.swig_paddle as api
Q
qiaolongfei 已提交
3
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
Y
Yu Yang 已提交
4 5 6
import struct
import tarfile
import cStringIO
Q
qiaolongfei 已提交
7
from topology import Topology
Q
qiaolongfei 已提交
8

Y
Yu Yang 已提交
9
# Public API of this module: the Parameters container and its factory.
__all__ = ['Parameters', 'create']
Y
Yu Yang 已提交
10 11


Q
qiaolongfei 已提交
12
def create(layers):
    """
    Build a parameter pool from the topology of the given layers.

    :param layers: handed to ``Topology`` to construct the network topology.
    :return: a pool holding the configuration of every parameter listed in
        the topology's protobuf.
    :rtype: Parameters
    """
    network = Topology(layers)
    parameters = Parameters()
    # Register each parameter configuration declared by the topology.
    for conf in network.proto().parameters:
        parameters.__append_config__(conf)
    return parameters
Y
Yu Yang 已提交
23 24


Y
Yu Yang 已提交
25
class Parameters(object):
    """
    Parameters is a dictionary that contains Paddle's parameters. The key of
    Parameters is the name of a parameter. The value of Parameters is a plain
    :code:`numpy.ndarray` .

    Basic usage is

    ..  code-block:: python

        data = paddle.layers.data(...)
        ...
        out = paddle.layers.fc(...)

        parameters = paddle.parameters.create(out)

        parameter_names = parameters.names()
        fc_mat = parameters.get('fc')
        print fc_mat
    """

    def __init__(self):
        # parameter name -> ParameterConfig protobuf message
        self.__param_conf__ = dict()
        # gradient machines registered through append_gradient_machine
        self.__gradient_machines__ = []
        # (name, value) pairs set while no gradient machine was attached;
        # they are copied into every machine that gets attached later.
        self.__tmp_params__ = []

    def __append_config__(self, param_conf):
        """
        Append a parameter configuration. It is used to initialize Parameters
        and should be invoked only in paddle.parameters.create

        :param param_conf: The parameter configuration in protobuf
        :type param_conf: ParameterConfig
        :return: Nothing
        :raises ValueError: if param_conf is not a ParameterConfig or a
            parameter with the same name is already registered.
        """

        if not isinstance(param_conf, ParameterConfig):
            raise ValueError("param_conf must be paddle.proto.ParameterConfig")

        if param_conf.name in self.__param_conf__:
            raise ValueError("duplicated parameter %s" % param_conf.name)

        self.__param_conf__[param_conf.name] = param_conf

    def keys(self):
        """
        keys are the names of each parameter.

        :return: list of parameter name
        :rtype: list
        """
        return self.__param_conf__.keys()

    def names(self):
        """
        names of each parameter. Same as :code:`keys()`.

        :return: list of parameter name
        :rtype: list
        """
        return self.keys()

    def has_key(self, key):
        """
        has_key returns true if there is a parameter whose name == key

        :param key: Parameter name
        :type key: basestring
        :return: True if contains such key
        """
        # Membership test on the dict itself instead of scanning a key list.
        return key in self.__param_conf__

    def __iter__(self):
        """
        Return an iterator of parameter name. It is used by `for loop`
        or `in` operator.

        ..  code-block:: python

            parameters = paddle.parameters.create(...)
            if "fc_param" in parameters:
                print 'OK'

        :return: an iterator of parameter name
        :rtype: iterator
        """
        return iter(self.__param_conf__)

    def __getitem__(self, key):
        """
        Get parameter by parameter name. It uses Python dict syntax.

        While no gradient machine is attached, the most recent value stored
        through :code:`__setitem__` is returned (as a copy). If nothing was
        stored yet, a new array of the configured shape is returned whose
        content is uninitialized.

        Once a gradient machine is attached, the value is copied from the
        first attached machine.

        :note: It will always copy the parameter from C++ side.
        :param key: Parameter name
        :type key: basestring
        :return: parameter value
        :rtype: np.ndarray
        :raises ValueError: if key is not a string or names no parameter.
        """
        shape = self.get_shape(key)

        if not self.__gradient_machines__:
            # Walk backwards so that the latest assignment wins.
            for name, value in reversed(self.__tmp_params__):
                if name == key:
                    return value.copy()
            # Nothing assigned yet: hand out a fresh (uninitialized) array.
            return np.ndarray(shape=shape, dtype=np.float32)

        # for simplify implementation now, we always copy from C++, and only
        # the first gradient machine is consulted.
        param = __get_parameter_in_gradient_machine__(
            self.__gradient_machines__[0], key)
        assert isinstance(param, api.Parameter)
        val = param.getBuf(api.PARAMETER_VALUE)
        assert isinstance(val, api.Vector)
        # NOTE(review): the copied array is returned as-is and is not
        # reshaped to `shape` -- confirm whether callers expect that layout.
        return val.copyToNumpyArray()

    def get_shape(self, key):
        """
        get shape of the parameter.

        :param key: parameter name
        :type key: basestring
        :return: parameter's shape
        :rtype: tuple
        :raises ValueError: if key is not a string or names no parameter.
        """
        if not isinstance(key, basestring):
            raise ValueError("parameter name should be string")
        if not self.has_key(key):
            raise ValueError("No such parameter %s" % key)
        conf = self.__param_conf__[key]
        return tuple(map(int, conf.dims))

    def __setitem__(self, key, value):
        """
        Set parameter by parameter name & value. It uses Python dict syntax.

        The value is converted to float32. While no gradient machine is
        attached the value is kept on the Python side; otherwise it is copied
        into every attached gradient machine.

        :note: It will always copy the parameter to C++ side.
        :param key: Parameter name
        :type key: basestring
        :param value: Parameter matrix.
        :type value: np.ndarray
        :return: Nothing
        :raises ValueError: if value is not an ndarray, key is invalid, or
            the shape of value differs from the configured shape.
        """

        if not isinstance(value, np.ndarray):
            raise ValueError("value must be numpy.ndarray")
        value = value.astype(dtype=np.float32)
        shape = self.get_shape(key)
        if value.shape != shape:
            raise ValueError("Value shape mismatch, expect %s, got %s" %
                             (shape, value.shape))

        if not self.__gradient_machines__:
            self.__tmp_params__.append((key, value))
        else:
            for each_gradient_machine in self.__gradient_machines__:
                __copy_parameter_to_gradient_machine__(each_gradient_machine,
                                                       key, value)

    def get(self, parameter_name):
        """
        Get parameter by parameter name.

        :note: It will always copy the parameter from C++ side.
        :param parameter_name: parameter name
        :type parameter_name: basestring
        :return: The parameter matrix.
        :rtype: np.ndarray
        """
        return self.__getitem__(key=parameter_name)

    def set(self, parameter_name, value):
        """
        Set parameter by parameter name & matrix.

        :param parameter_name: parameter name
        :type parameter_name: basestring
        :param value: parameter matrix
        :type value: np.ndarray
        :return: Nothing.
        """
        self.__setitem__(key=parameter_name, value=value)

    def append_gradient_machine(self, gradient_machine):
        """
        append gradient machine to parameters. This method is used internally
        in Trainer.train.

        Values that were set before any gradient machine was attached are
        copied into the given machine first.

        :param gradient_machine: Paddle C++ GradientMachine object.
        :type gradient_machine: api.GradientMachine
        :return: Nothing
        :raises ValueError: if gradient_machine is not an api.GradientMachine
        """

        if not isinstance(gradient_machine, api.GradientMachine):
            raise ValueError("gradient_machine should be api.GradientMachine")

        for name, val in self.__tmp_params__:
            try:
                __copy_parameter_to_gradient_machine__(gradient_machine, name,
                                                       val)
            except ValueError:
                # If no such parameter in gradient machine, then don't copy
                pass

        self.__gradient_machines__.append(gradient_machine)

    def __getstate__(self):
        """
        Build the picklable state: the serialized configuration and the
        current value of every parameter.

        :return: dict with keys 'conf' and 'params'
        :rtype: dict
        """
        params = {}
        for name in self.names():
            params[name] = self.get(name)

        param_conf = {}
        for name in self.__param_conf__:
            conf = self.__param_conf__[name]
            assert isinstance(conf, ParameterConfig)
            param_conf[name] = conf.SerializeToString()

        return {'conf': param_conf, 'params': params}

    def serialize(self, name, f):
        """
        Write one parameter to f: a header packed as "IIQ" (0, 4, number of
        elements) followed by the float32 values.

        :param name: parameter name
        :type name: basestring
        :param f: writable binary file-like object
        :type f: file
        :return: Nothing
        """
        param = self.get(name)
        # ndarray.size is the element count and also covers an empty shape.
        f.write(struct.pack("IIQ", 0, 4, param.size))
        param = param.astype(np.float32)
        f.write(param.tobytes())

    def deserialize(self, name, f):
        """
        Read one parameter from f and store it under name. The first 16
        bytes are skipped as header; everything that remains in f is
        interpreted as float32 values.

        :param name: parameter name
        :type name: basestring
        :param f: readable binary file-like object
        :type f: file
        :return: Nothing
        """
        f.read(16)  # header
        # frombuffer only needs f.read(), so in-memory buffers work as well
        # as real files.
        arr = np.frombuffer(f.read(), dtype=np.float32)
        self.set(name, arr.reshape(self.get_shape(name)))

    def serialize_to_tar(self, f):
        """
        Write every parameter into a tar archive, one member per parameter
        named after it, using the format produced by :code:`serialize`.

        :param f: writable binary file-like object receiving the archive
        :type f: file
        :return: Nothing
        """
        tar = tarfile.TarFile(fileobj=f, mode='w')
        try:
            for nm in self.names():
                buf = cStringIO.StringIO()
                self.serialize(nm, buf)
                tarinfo = tarfile.TarInfo(name=nm)
                buf.seek(0)
                tarinfo.size = len(buf.getvalue())
                tar.addfile(tarinfo, buf)
        finally:
            # Closing finalizes the archive (end-of-archive blocks); f itself
            # was supplied by the caller and stays open.
            tar.close()

    def __setstate__(self, obj):
        """
        Restore the state produced by :code:`__getstate__`.

        :param obj: dict with keys 'conf' and 'params'
        :type obj: dict
        :return: Nothing
        """
        Parameters.__init__(self)

        def __impl__(conf, params):
            # Configurations first: set() needs the shapes they carry.
            for name in conf:
                p = ParameterConfig()
                p.ParseFromString(conf[name])
                self.__append_config__(p)
            for name in params:
                shape = self.get_shape(name)
                self.set(name, params[name].reshape(shape))

        __impl__(**obj)
Y
Yu Yang 已提交
289 290 291 292


def __get_parameter_in_gradient_machine__(gradient_machine, name):
    """
    Look up the parameter called name inside a gradient machine.

    :param gradient_machine: machine whose parameters are searched
    :type gradient_machine: api.GradientMachine
    :param name: parameter name, compared against each parameter's getName()
    :return: the single parameter carrying that name
    :rtype: api.Parameter
    :raises ValueError: if no parameter, or more than one, has that name
    """
    # A list (not a lazy filter object) so that len() and indexing work.
    params = [
        p for p in gradient_machine.getParameters() if p.getName() == name
    ]

    if len(params) == 0:
        raise ValueError("No such parameter")
    elif len(params) > 1:
        raise ValueError("Unexpected branch")
    else:
        return params[0]
Y
Yu Yang 已提交
309 310


Y
Yu Yang 已提交
311
def __copy_parameter_to_gradient_machine__(gradient_machine, name, arr):
    """
    Copy a python ndarray into the value buffer of a parameter that lives in
    the gradient machine.

    :param gradient_machine: machine that owns the target parameter
    :type gradient_machine: api.GradientMachine
    :param name: name of the parameter to overwrite
    :param arr: new value; it is flattened before being copied
    :type arr: np.ndarray
    :return: Nothing
    """
    target = __get_parameter_in_gradient_machine__(gradient_machine, name)
    value_buf = target.getBuf(api.PARAMETER_VALUE)
    assert isinstance(value_buf, api.Vector)
    flat = arr.flatten()
    value_buf.copyFromNumpyArray(flat)