#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ...fluid.data_feeder import check_type
from ...fluid.initializer import NumpyArrayInitializer

# Deliberately empty: `from <this module> import *` exports no names.
__all__ = list()


class Assign(NumpyArrayInitializer):
    """Initialize a parameter with a numpy array, list, tuple, or tensor.

    Args:
        value (Tensor|numpy.ndarray|list|tuple): numpy array, list, tuple, or tensor to initialize the parameter.
        name(str, optional): The default value is None. Normally there is no need for user to set this
            property. For more information, please refer to :ref:`api_guide_Name`.
            Note: this initializer does not currently use ``name``.

    Returns:
        An initializer that fills a parameter with the values of the input numpy array, list, tuple, or tensor.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # numpy array
            data_1 = paddle.ones(shape=[1, 2], dtype='float32')
            weight_attr_1 = paddle.framework.ParamAttr(
                name="linear_weight_1",
                initializer=paddle.nn.initializer.Assign(np.array([2, 2])))
            bias_attr_1 = paddle.framework.ParamAttr(
                name="linear_bias_1",
                initializer=paddle.nn.initializer.Assign(np.array([2])))
            linear_1 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_1, bias_attr=bias_attr_1)
            # linear_1.weight:  [2. 2.]
            # linear_1.bias:  [2.]

            res_1 = linear_1(data_1)
            # res_1:  [6.]

            # python list
            data_2 = paddle.ones(shape=[1, 2], dtype='float32')
            weight_attr_2 = paddle.framework.ParamAttr(
                name="linear_weight_2",
                initializer=paddle.nn.initializer.Assign([2, 2]))
            bias_attr_2 = paddle.framework.ParamAttr(
                name="linear_bias_2",
                initializer=paddle.nn.initializer.Assign([2]))
            linear_2 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_2, bias_attr=bias_attr_2)
            # linear_2.weight:  [2. 2.]
            # linear_2.bias:  [2.]

            res_2 = linear_2(data_2)
            # res_2:  [6.]

            # tensor
            data_3 = paddle.ones(shape=[1, 2], dtype='float32')
            weight_attr_3 = paddle.framework.ParamAttr(
                name="linear_weight_3",
                initializer=paddle.nn.initializer.Assign(paddle.full([2], 2)))
            bias_attr_3 = paddle.framework.ParamAttr(
                name="linear_bias_3",
                initializer=paddle.nn.initializer.Assign(paddle.full([1], 2)))
            linear_3 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_3, bias_attr=bias_attr_3)
            # linear_3.weight:  [2. 2.]
            # linear_3.bias:  [2.]

            res_3 = linear_3(data_3)
            # res_3:  [6.]
    """

    def __init__(self, value, name=None):
        import numpy

        # Validate that `value` is one of the supported input types
        # (check_type is expected to raise on anything else).
        check_type(value, 'value',
                   (numpy.ndarray, list, tuple, paddle.static.Variable),
                   'Assign')

        if isinstance(value, (list, tuple)):
            # Python sequences are turned into an ndarray, which the
            # NumpyArrayInitializer base presumably expects.
            value = numpy.array(value)
        elif isinstance(value, paddle.static.Variable):
            # TODO: when value is already a tensor, the round trip through
            # numpy could be avoided for efficiency.
            value = value.numpy()

        # `name` is accepted for the public signature but not used here.
        super(Assign, self).__init__(value)