# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.nn import Layer
from paddle.incubate.nn import functional as F


class FusedLinear(Layer):
    r"""
    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies input tensor with the weight
    (a 2-D tensor of shape :math:`[in\_features, out\_features]` , or
    :math:`[out\_features, in\_features]` when ``transpose_weight`` is True)
    and produces an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    The computation is delegated to
    ``paddle.incubate.nn.functional.fused_linear``.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None and the weight will be
            initialized to zero. For detailed information, please refer to
            paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        transpose_weight (bool, optional): Whether the `weight` Tensor is stored
            with shape :math:`[out\_features, in\_features]` and transposed before
            multiplication. Default: False.
        name (str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Attributes:
        **weight** (Parameter): the learnable weight of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` .

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedLinear

            x = paddle.randn([3, 4])
            linear = FusedLinear(4, 5)
            y = linear(x)
            print(y.shape) # [3, 5]
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        transpose_weight=False,
        name=None,
    ):
        super().__init__()
        # With transpose_weight the weight is stored as [out, in]; forward()
        # passes the flag through so the functional op transposes it.
        if transpose_weight:
            weight_shape = [out_features, in_features]
        else:
            weight_shape = [in_features, out_features]
        # Both parameters use the layer helper's default dtype.
        dtype = self._helper.get_default_dtype()
        self.weight = self.create_parameter(
            shape=weight_shape, attr=weight_attr, dtype=dtype, is_bias=False
        )
        self.bias = self.create_parameter(
            shape=[out_features], attr=bias_attr, dtype=dtype, is_bias=True
        )
        self.transpose_weight = transpose_weight
        self.name = name

    def forward(self, input):
        """Apply the fused linear transform to ``input``.

        Args:
            input (Tensor): Tensor whose last dimension is ``in_features``
                (see the class-level Shape section).

        Returns:
            Tensor: the result of ``F.fused_linear`` called with this layer's
            weight, bias, ``transpose_weight`` flag and ``name``.
        """
        return F.fused_linear(
            input, self.weight, self.bias, self.transpose_weight, self.name
        )