diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index fe68bcf322d05823907d896fa88b9322145c19f9..3a27186bdea0385ed1a8bd95502972176beadc61 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -27,9 +27,10 @@ import numbers
 import logging
 
 __all__ = [
-    'Conv2D', 'Conv3D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit',
-    'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose',
-    'Conv3DTranspose', 'GroupNorm', 'SpectralNorm', 'TreeConv'
+    'Conv2D', 'Conv3D', 'Pool2D', 'FC', 'Linear', 'BatchNorm', 'Embedding',
+    'GRUUnit', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct',
+    'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', 'SpectralNorm',
+    'TreeConv'
 ]
 
@@ -873,6 +874,101 @@ class Pool2D(layers.Layer):
         return pool_out
 
 
+class Linear(layers.Layer):
+    """
+    Fully-connected linear transformation layer:
+
+    .. math::
+
+        Out = Act({XW + b})
+
+    where :math:`X` is the input Tensor, and :math:`W` and :math:`b` are weight and bias respectively.
+
+    Different from the FC layer, the Linear layer takes only one ``Tensor`` input.
+    It multiplies the input tensor by the weight matrix and produces an output
+    Tensor of shape [N, *, `output_dim`], where N is the batch size and `*`
+    means any number of additional dimensions.
+    If ``bias_attr`` is not False, a bias parameter will be created and added to the output.
+    Finally, if ``act`` is not None, it will be applied to the output as well.
+
+    Parameters:
+        input_dim(int): The number of input units in this layer.
+        output_dim(int): The number of output units in this layer.
+        param_attr(ParamAttr, optional): The parameter attribute for the learnable
+            weight(Parameter) of this layer. Default: None.
+        bias_attr(ParamAttr or bool, optional): The attribute for the bias of this layer.
+            If it is set to False, no bias will be added to the output.
+            If it is set to None, the bias is initialized to zero. Default: None.
+        act(str, optional): Activation to be applied to the output of this layer. Default: None.
+        dtype(str, optional): Dtype used for the weight; it can be "float32" or "float64". Default: "float32".
+
+    Attributes:
+        **weight** (Parameter): the learnable weight of this layer.
+
+        **bias** (Parameter or None): the learnable bias of this layer.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+          from paddle.fluid.dygraph.base import to_variable
+          import paddle.fluid as fluid
+          from paddle.fluid.dygraph import Linear
+          import numpy as np
+
+          data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
+          with fluid.dygraph.guard():
+              linear = Linear(32, 64)
+              data = to_variable(data)
+              res = linear(data)  # [30, 10, 64]
+    """
+
+    def __init__(self,
+                 input_dim,
+                 output_dim,
+                 param_attr=None,
+                 bias_attr=None,
+                 act=None,
+                 dtype="float32"):
+        super(Linear, self).__init__()
+        self._act = act
+        self._dtype = dtype
+        self.weight = self.create_parameter(
+            shape=[input_dim, output_dim],
+            attr=param_attr,
+            dtype=dtype,
+            is_bias=False)
+        self.bias = self.create_parameter(
+            shape=[output_dim], attr=bias_attr, dtype=dtype, is_bias=True)
+
+    def forward(self, input):
+        tmp = self._helper.create_variable_for_type_inference(self._dtype)
+        self._helper.append_op(
+            type="matmul",
+            inputs={"X": input,
+                    "Y": self.weight},
+            outputs={"Out": tmp},
+            attrs={
+                "transpose_X": False,
+                "transpose_Y": False,
+                "alpha": 1,
+            })
+        if self.bias is not None:
+            pre_activation = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type='elementwise_add',
+                inputs={'X': [tmp],
+                        'Y': [self.bias]},
+                outputs={'Out': [pre_activation]},
+                attrs={'axis': len(input.shape) - 1})
+        else:
+            pre_activation = tmp
+        return self._helper.append_activation(pre_activation, act=self._act)
+
+
 class FC(layers.Layer):
     """
     This interface is used to construct a callable object of the ``FC`` class.
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 1a196a9804e41df2f09ff7de3ad0b4e148d15641..641cf8b75fa963fd62ceff03e7d0765ffdbc6700 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -110,6 +110,41 @@ class TestLayer(LayerTest):
         ret = custom(x, do_fc2=True)
         self.assertTrue(np.array_equal(ret.numpy().shape, [3, 1]))
 
+    def test_linear(self):
+        inp = np.ones([3, 32, 32], dtype='float32')
+        with self.static_graph():
+            t = layers.data(
+                name='data',
+                shape=[3, 32, 32],
+                dtype='float32',
+                append_batch_size=False)
+            linear = nn.Linear(
+                32, 4, bias_attr=fluid.initializer.ConstantInitializer(value=1))
+            ret = linear(t)
+            static_ret = self.get_static_graph_result(
+                feed={'data': inp}, fetch_list=[ret])[0]
+        with self.dynamic_graph():
+            t = base.to_variable(inp)
+            linear = nn.Linear(
+                32, 4, bias_attr=fluid.initializer.ConstantInitializer(value=1))
+            dy_ret = linear(t)
+            dy_ret_value = dy_ret.numpy()
+
+        self.assertTrue(np.array_equal(static_ret, dy_ret_value))
+
+        inp = np.ones([3, 32], dtype='float32')
+        with self.dynamic_graph():
+            t = base.to_variable(inp)
+            linear = nn.Linear(32, 4, bias_attr=False)
+            dy_ret = linear(t)
+            dy_ret_value = dy_ret.numpy()
+        with self.dynamic_graph():
+            t = base.to_variable(inp)
+            fc = nn.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1)
+            dy_ret2 = fc(t)
+            dy_ret_value2 = dy_ret2.numpy()
+        self.assertTrue(np.array_equal(dy_ret_value, dy_ret_value2))
+
     def test_fc(self):
         inp = np.ones([3, 32, 32], dtype='float32')
         with self.static_graph():
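
Reviewer's note: the new `forward` composes a `matmul` over the last dimension with an
`elementwise_add` that broadcasts the bias along the last axis (`axis = len(input.shape) - 1`),
followed by the optional activation. Below is a minimal NumPy sketch of that computation,
with the activation omitted; the helper name and sample shapes are illustrative only and
not part of the patch:

    import numpy as np

    def linear_forward_reference(x, weight, bias=None):
        """NumPy sketch of Linear.forward: Out = x @ W (+ b).

        x:      [N, *, input_dim]; weight: [input_dim, output_dim];
        bias:   [output_dim] or None. Broadcasting the bias over the
        leading dimensions mirrors elementwise_add with axis set to
        the last axis of the input.
        """
        out = np.matmul(x, weight)  # matmul with transpose_X/Y False, alpha 1
        if bias is not None:
            out = out + bias        # bias added along the last axis
        return out

    # Shapes mirror the docstring example: [30, 10, 32] -> [30, 10, 64]
    x = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
    w = np.random.uniform(-1, 1, [32, 64]).astype('float32')
    b = np.ones([64], dtype='float32')
    assert linear_forward_reference(x, w, b).shape == (30, 10, 64)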