parameters.py 3.1 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
import numpy as np

from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig

# Names exported by `from <module> import *`. NumpyParameterPool is not listed
# here; instances of it are obtained through `create`.
__all__ = ['IParameterPool', 'create', 'ParameterFlag']


class ParameterFlag(object):
    """
    Access-mode flags for IParameterPool.get_parameter.

    When the writable bit is requested, operations on the returned numpy
    array are also applied to the Paddle parameter; this is slower in GPU
    mode.

    The values are bit flags, so READ_WRITE is the union of the two
    single-bit modes.
    """
    READ_ONLY = 1
    WRITE_ONLY = 1 << 1
    READ_WRITE = WRITE_ONLY | READ_ONLY


class IParameterPool(object):
    """
    Abstract interface of a parameter pool.

    A parameter pool behaves like a dictionary of parameters keyed by name.
    Users read, modify or customize a parameter value through
    `get_parameter`. Both methods raise NotImplementedError here and must be
    overridden by concrete pools.

    ..  code-block:: python

        pool = paddle.parameters.create(topo1, topo2)

        embedding = pool.get_parameter("embedding")
        assert isinstance(embedding, numpy.ndarray)
        print(embedding[1:])
    """

    def get_parameter(self, name, flag=ParameterFlag.READ_WRITE):
        """
        Look up a parameter value by its name.

        :param name: name of the parameter to fetch.
        :type name: basestring
        :param flag: access mode of the returned value (readable and/or
                     writable), one of the ParameterFlag constants.
        :type flag: int
        :return: the value of the parameter.
        :rtype: np.ndarray
        :raise NotImplementedError: always, in this base interface.
        """
        raise NotImplementedError

    def get_names(self):
        """
        List the names of all parameters held by the pool.

        :return: all parameter names.
        :rtype: list
        :raise NotImplementedError: always, in this base interface.
        """
        raise NotImplementedError


class NumpyParameterPool(IParameterPool):
    """
    Parameter pool that keeps each parameter as a float32 numpy array.

    Parameter configurations are registered with `append`. The array backing
    a parameter is allocated lazily on the first `get_parameter` call and then
    cached, so later calls for the same name return the same array object.
    """

    def __init__(self):
        # name -> ParameterConfig registered through `append`.
        self.__param_configs__ = dict()
        # name -> np.ndarray, or None while the array is not allocated yet.
        self.__params__ = dict()

    def append(self, conf):
        """
        Register a parameter configuration under `conf.name`.

        Appending a configuration whose name is already registered replaces
        the stored configuration and drops any array cached for that name.

        :param conf: configuration of the parameter.
        :type conf: ParameterConfig
        :raise ValueError: if `conf` is not a ParameterConfig or is not
                           initialized.
        """
        if not isinstance(conf, ParameterConfig):
            raise ValueError("conf must be ParameterConfig")

        if not conf.IsInitialized():
            raise ValueError("conf is not initialized")

        self.__param_configs__[conf.name] = conf
        self.__params__[conf.name] = None

    def get_config(self, name):
        """
        Return the configuration registered for a parameter.

        :param name: parameter name.
        :return: the ParameterConfig passed to `append` for this name.
        :raise ValueError: if no parameter with this name was appended.
        """
        if name not in self.__param_configs__:
            raise ValueError("parameter %s is not appended" % name)

        return self.__param_configs__[name]

    def get_parameter(self, name, *args, **kwargs):
        """
        Return the numpy array of a parameter, allocating it on first use.

        Extra positional and keyword arguments (such as the `flag` of
        IParameterPool.get_parameter) are accepted and ignored.

        :param name: parameter name.
        :return: float32 array shaped by the `dims` of the parameter config.
                 A freshly allocated array is uninitialized.
        :rtype: np.ndarray
        :raise ValueError: if the parameter was not appended, or its
                           configuration has empty `dims`.
        """
        if name not in self.__params__:
            raise ValueError("parameter %s is not appended" % name)

        param = self.__params__[name]
        if param is None:
            shape = self.__param_configs__[name].dims
            if len(shape) == 0:
                raise ValueError("parameter %s is no shape" % name)
            # np.empty allocates without initializing, exactly like the
            # direct ndarray constructor, but is the documented way to do so.
            param = np.empty(
                shape=[int(item) for item in shape], dtype='float32')
            self.__params__[name] = param
        return param

    def get_names(self):
        """
        Return the names of all appended parameters.

        :return: a new list of parameter names.
        :rtype: list
        """
        # Build a real list: on Python 3 `dict.keys()` is a live view, not the
        # list promised by IParameterPool.get_names.
        return list(self.__param_configs__)


def create(*topologies):
    """
    Build a parameter pool holding the parameters of the given topologies.

    Every entry of each topology's `parameters` field is appended to one
    new NumpyParameterPool, in argument order.

    :param topologies: model configurations to collect parameters from.
    :type topologies: ModelConfig
    :return: the pool containing the parameters of all topologies.
    :rtype: NumpyParameterPool
    :raise ValueError: if an argument is not a ModelConfig; errors raised by
                       the pool's `append` are propagated as well.
    """
    result = NumpyParameterPool()
    for model_conf in topologies:
        if isinstance(model_conf, ModelConfig):
            for param_conf in model_conf.parameters:
                result.append(param_conf)
        else:
            raise ValueError(
                'create must pass a topologies which type is ModelConfig')
    return result