diff --git a/python/paddle/fluid/tests/unittests/test_adamax_api.py b/python/paddle/fluid/tests/unittests/test_adamax_api.py
index f6946dc80b5e55b2e7149f357fe0600916a4fe9f..5a33e11d2862c037639b1643a2e44ff81a757053 100644
--- a/python/paddle/fluid/tests/unittests/test_adamax_api.py
+++ b/python/paddle/fluid/tests/unittests/test_adamax_api.py
@@ -26,7 +26,7 @@ class TestAdamaxAPI(unittest.TestCase):
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_variable(value)
-        linear = paddle.nn.Linear(13, 5, dtype="float32")
+        linear = paddle.nn.Linear(13, 5)
         adam = paddle.optimizer.Adamax(
             learning_rate=0.01,
             parameters=linear.parameters(),
diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py
index ddb70d6e6400c8e7ae71cabf92ce8060e220a7da..0a7cf54e2e0f15e51ba1b6f7526837f53c7cc2e0 100644
--- a/python/paddle/fluid/tests/unittests/test_adamw_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -23,7 +23,7 @@ class TestAdamWOp(unittest.TestCase):
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_variable(value)
-        linear = paddle.nn.Linear(13, 5, dtype="float32")
+        linear = paddle.nn.Linear(13, 5)
         adam = paddle.optimizer.AdamW(
             learning_rate=0.01,
             parameters=linear.parameters(),
@@ -38,7 +38,7 @@
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_variable(value)
-        linear = paddle.nn.Linear(13, 5, dtype="float32")
+        linear = paddle.nn.Linear(13, 5)
         adam = paddle.optimizer.AdamW(
             learning_rate=0.0,
             parameters=linear.parameters(),
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
index b15ad911ee79d47011be6eaa4bde62ba71c55c0e..f61d1ab888a51b2ebe4d1205b30fb84dfa4e7aeb 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
@@ -40,9 +40,8 @@ class LeNetDygraph(fluid.dygraph.Layer):
         if num_classes > 0:
             self.fc = nn.Sequential(
                 nn.Linear(400, 120),
-                nn.Linear(120, 84),
-                nn.Linear(
-                    84, 10, act=classifier_activation))
+                nn.Linear(120, 84), nn.Linear(84, 10),
+                nn.Softmax())  # TODO: accept any activation

     def forward(self, inputs):
         x = self.features(inputs)
diff --git a/python/paddle/fluid/tests/unittests/test_linear.py b/python/paddle/fluid/tests/unittests/test_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d07a80da15dbfd35ffdedbcb09e82d59a84486e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_linear.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from op_test import OpTest
+import paddle
+from paddle import fluid, nn
+import paddle.fluid.dygraph as dg
+import paddle.nn.functional as F
+import paddle.fluid.initializer as I
+
+
+class LinearTestCase(unittest.TestCase):
+    def setUp(self):
+        self.dtype = 'float32'
+        self.input = np.ones((3, 1, 2)).astype(self.dtype)
+        self.weight = np.ones((2, 2)).astype(self.dtype)
+        self.bias = np.ones((2)).astype(self.dtype)
+        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+
+    def functional(self, place):
+        paddle.disable_static(place)
+        input = paddle.to_tensor(self.input)
+        weight = paddle.to_tensor(self.weight)
+        bias = paddle.to_tensor(self.bias)
+        out = F.linear(input, weight, bias)
+        return out.numpy()
+
+    def paddle_nn_layer(self, place):
+        paddle.disable_static(place)
+        input = paddle.to_tensor(self.input)
+        weight_attr = fluid.ParamAttr(
+            name="linear_weight",
+            learning_rate=1.0,
+            trainable=False,
+            regularizer=None,
+            initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0))
+        bias_attr = fluid.ParamAttr(
+            name="linear_bias",
+            learning_rate=1.0,
+            trainable=False,
+            regularizer=None,
+            initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0))
+        linear = paddle.nn.Linear(
+            2, 2, weight_attr=weight_attr, bias_attr=bias_attr)
+        y = linear(input)
+        return y.numpy()
+
+    def numpy_cal(self):
+        res = np.matmul(self.input, self.weight) + self.bias
+        return res
+
+    def test_linear(self, place=paddle.CPUPlace()):
+        # The functional op, the layer, and the numpy reference
+        # should all agree.
+        res_f = self.functional(place)
+        res_nn = self.paddle_nn_layer(place)
+        res_np = self.numpy_cal()
+        np.testing.assert_array_almost_equal(res_f, res_nn)
+        np.testing.assert_array_almost_equal(res_nn, res_np)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
index 0f225758ced3bf7d6fd821be09f2dbf11ff1cc6d..f7b9d4214d36a422a3ec94dc410e58c6c827ef4c 100644
--- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
@@ -228,7 +228,7 @@ class TestRMSPropV2(unittest.TestCase):
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
-        linear = paddle.nn.Linear(13, 5, dtype="float32")
+        linear = paddle.nn.Linear(13, 5)
         # This can be any optimizer supported by dygraph.
         adam = paddle.optimizer.RMSProp(
             learning_rate=0.01,
diff --git a/python/paddle/incubate/hapi/tests/test_model.py b/python/paddle/incubate/hapi/tests/test_model.py
index 96c432e1bfd8f3620c705be62c4e9d90c2709fa5..8e0c051ee8c39c032dcc05afa466b493e1498a86 100644
--- a/python/paddle/incubate/hapi/tests/test_model.py
+++ b/python/paddle/incubate/hapi/tests/test_model.py
@@ -23,7 +23,7 @@ import shutil
 import tempfile

 from paddle import fluid
-from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax
 from paddle.fluid.dygraph.base import to_variable

 import paddle.incubate.hapi as hapi
@@ -53,10 +53,8 @@ class LeNetDygraph(fluid.dygraph.Layer):

         if num_classes > 0:
             self.fc = Sequential(
-                Linear(400, 120),
-                Linear(120, 84),
-                Linear(
-                    84, 10, act=classifier_activation))
+                Linear(400, 120), Linear(120, 84), Linear(84, 10),
+                Softmax())  # TODO: accept any activation

     def forward(self, inputs):
         x = self.features(inputs)
@@ -83,10 +81,8 @@ class LeNetDeclarative(fluid.dygraph.Layer):

         if num_classes > 0:
             self.fc = Sequential(
-                Linear(400, 120),
-                Linear(120, 84),
-                Linear(
-                    84, 10, act=classifier_activation))
+                Linear(400, 120), Linear(120, 84), Linear(84, 10),
+                Softmax())  # TODO: accept any activation

     @declarative
     def forward(self, inputs):
@@ -320,10 +316,12 @@ class TestModel(unittest.TestCase):
 class MyModel(fluid.dygraph.Layer):
     def __init__(self, classifier_activation='softmax'):
         super(MyModel, self).__init__()
-        self._fc = Linear(20, 10, act=classifier_activation)
+        self._fc = Linear(20, 10)
+        self._act = Softmax()  # TODO: accept any activation

     def forward(self, x):
         y = self._fc(x)
+        y = self._act(y)
         return y
diff --git a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
index 26ec53014b1c3b113a0e1ee82f3b9edfe9f48a3f..6df9b31217aae78c43de8d29956a8b2def99055b 100644
--- a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
+++ b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
@@ -22,7 +22,7 @@ import shutil
 import tempfile

 from paddle import fluid
-from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax

 from paddle.incubate.hapi.utils import uncombined_weight_to_state_dict

@@ -43,10 +43,8 @@ class LeNetDygraph(fluid.dygraph.Layer):

         if num_classes > 0:
             self.fc = Sequential(
-                Linear(400, 120),
-                Linear(120, 84),
-                Linear(
-                    84, 10, act=classifier_activation))
+                Linear(400, 120), Linear(120, 84), Linear(84, 10),
+                Softmax())  # TODO: accept any activation

     def forward(self, inputs):
         x = self.features(inputs)
diff --git a/python/paddle/incubate/hapi/vision/models/lenet.py b/python/paddle/incubate/hapi/vision/models/lenet.py
index dc7b094de0f26e04b9f07d011d3ce492950df269..169f70562f6edfe1773a1c8d75004c25831cedcb 100644
--- a/python/paddle/incubate/hapi/vision/models/lenet.py
+++ b/python/paddle/incubate/hapi/vision/models/lenet.py
@@ -13,7 +13,7 @@
 #limitations under the License.
 import paddle.fluid as fluid
-from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax

 __all__ = ['LeNet']

@@ -50,10 +50,8 @@ class LeNet(fluid.dygraph.Layer):

         if num_classes > 0:
             self.fc = Sequential(
-                Linear(400, 120),
-                Linear(120, 84),
-                Linear(
-                    84, 10, act=classifier_activation))
+                Linear(400, 120), Linear(120, 84), Linear(84, 10),
+                Softmax())  # TODO: accept any activation

     def forward(self, inputs):
         x = self.features(inputs)
diff --git a/python/paddle/incubate/hapi/vision/models/vgg.py b/python/paddle/incubate/hapi/vision/models/vgg.py
index 30f6e120b2502113045b3583686360f4ed2c32ac..4352a768eb7206ca30acead580a64a7d04b7701b 100644
--- a/python/paddle/incubate/hapi/vision/models/vgg.py
+++ b/python/paddle/incubate/hapi/vision/models/vgg.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import paddle.fluid as fluid
-from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU
+from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU, Softmax
 from paddle.fluid.dygraph.container import Sequential

 from ...download import get_weights_path_from_url
@@ -37,7 +37,8 @@ class Classifier(fluid.dygraph.Layer):
         super(Classifier, self).__init__()
         self.linear1 = Linear(512 * 7 * 7, 4096)
         self.linear2 = Linear(4096, 4096)
-        self.linear3 = Linear(4096, num_classes, act=classifier_activation)
+        self.linear3 = Linear(4096, num_classes)
+        self.act = Softmax()  # TODO: accept any activation

     def forward(self, x):
         x = self.linear1(x)
@@ -46,7 +47,8 @@
         x = self.linear2(x)
         x = fluid.layers.relu(x)
         x = fluid.layers.dropout(x, 0.5)
-        out = self.linear3(x)
+        x = self.linear3(x)
+        out = self.act(x)
         return out

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 645b2115650a109a9e84a64a531685b491e9b1b5..76063458d44de3000ad7c1af08376c07e0209c27 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -115,6 +115,7 @@ from .layer.extension import RowConv  #DEFINE_ALIAS
 # from .layer.learning_rate import NoamDecay  #DEFINE_ALIAS
 # from .layer.learning_rate import PiecewiseDecay  #DEFINE_ALIAS
 # from .layer.learning_rate import PolynomialDecay  #DEFINE_ALIAS
+from .layer.common import Linear  #DEFINE_ALIAS
 # from .layer.loss import NCELoss  #DEFINE_ALIAS
 from .layer.loss import BCEWithLogitsLoss  #DEFINE_ALIAS
 from .layer.loss import CrossEntropyLoss  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 75e2da4cf7e92f91d021f4e962c00925b426a89d..414e70853eb7163230ab2db987fc19c58e168f19 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -75,6 +75,7 @@ from .common import interpolate  #DEFINE_ALIAS
 from .common import bilinear  #DEFINE_ALIAS
 from .conv import conv1d  #DEFINE_ALIAS
 from .conv import conv_transpose1d  #DEFINE_ALIAS
+from .common import linear  #DEFINE_ALIAS
 from .conv import conv2d  #DEFINE_ALIAS
 from .conv import conv_transpose2d  #DEFINE_ALIAS
 from .conv import conv3d  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 6a462b53b753cf3040d474947c480e7fa2530138..8408e224d87370ca25d3738d8182c381d19b707b 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -17,7 +17,8 @@ import paddle
 from ...fluid.framework import in_dygraph_mode, default_main_program
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat
-
+from ...fluid.layers import core
+from ...fluid import dygraph_utils
 # TODO: define the common functions to build a neural network
 from ...fluid.layers import label_smooth  #DEFINE_ALIAS
 from ...fluid import one_hot  #DEFINE_ALIAS
@@ -30,6 +31,8 @@ from ...fluid.layers import elementwise_mul  #DEFINE_ALIAS
 from ...tensor import clip
 from ...tensor import sum
 from ...tensor import sqrt
+from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
+from ...fluid.framework import _varbase_creator

 #from ...fluid.layers import fc  #DEFINE_ALIAS
 from ...fluid.layers import pad_constant_like  #DEFINE_ALIAS
@@ -46,6 +49,7 @@ __all__ = [
     #       'embedding',
     #       'fc',
     'label_smooth',
+    'linear',
     'one_hot',
     'pad',
     'pad_constant_like',
@@ -1348,3 +1352,83 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8):
     n12 = sqrt(clip(w1 * w2, min=eps * eps))
     cos_sim = w12 / n12
     return cos_sim
+
+
+def linear(x, weight, bias=None, name=None):
+    """
+
+    Fully-connected linear transformation op.
+
+    .. math::
+
+        Out = XW + b
+
+    where :math:`X` is the input Tensor and :math:`W` and :math:`b` are the weight and bias respectively.
+
+    The linear op multiplies the input tensor by the weight matrix and
+    produces an output Tensor of shape [N, *, output_dim],
+    where N is the batch size, `*` means any number of additional dimensions,
+    and output_dim is the last dimension of ``weight``.
+    If ``bias`` is not None, the bias is added to the output.
+
+    Args:
+        x(Tensor): Input tensor. Its data type is float16, float32 or float64.
+        weight(Tensor): Weight tensor. Its data type is float16, float32 or float64.
+        bias(Tensor|None, optional): Bias tensor. Its data type is float16, float32 or float64. If it is set to None, no bias will be added to the output units.
+        name(str|None, optional): For detailed information, please refer to :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        Output tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.nn.functional as F
+
+            input = np.ones((3, 1, 2), dtype=np.float32)
+            weight = np.ones((2, 2), dtype=np.float32)
+            bias = np.ones((2,), dtype=np.float32)
+            place = paddle.CPUPlace()
+            paddle.disable_static(place)
+            input = paddle.to_tensor(input)
+            weight = paddle.to_tensor(weight)
+            bias = paddle.to_tensor(bias)
+            out = F.linear(input, weight, bias)
+            print(out)  # all elements are 3.0, shape is [3, 1, 2]
+
+    """
+    if in_dygraph_mode():
+        pre_bias = _varbase_creator(dtype=x.dtype)
+        core.ops.matmul(x, weight, pre_bias, 'transpose_X', False,
+                        'transpose_Y', False, "alpha", 1)
+        return dygraph_utils._append_bias_in_dygraph(
+            pre_bias, bias, axis=len(x.shape) - 1)
+    else:
+        helper = LayerHelper('linear', **locals())
+        dtype = x.dtype
+
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                                 'linear')
+        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+
+        inputs = {'X': [x], 'Y': [weight]}
+        attrs = {
+            'transpose_X': False,
+            'transpose_Y': False,
+            'alpha': 1,
+        }
+        tmp = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type='matmul', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
+        if bias is not None:
+            res = helper.create_variable_for_type_inference(dtype)
+            helper.append_op(
+                type='elementwise_add',
+                inputs={'X': [tmp],
+                        'Y': [bias]},
+                outputs={'Out': [res]},
+                attrs={'axis': len(x.shape) - 1})
+        else:
+            res = tmp
+        return res
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 9f32c1365a39d4e528acb88fa4e8b408feb3153a..a1e6508c67d96e9f6cc077efe6e61d708674b057 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -16,7 +16,6 @@
 from ...fluid.dygraph import BilinearTensorProduct  #DEFINE_ALIAS
 from ...fluid.dygraph import Pool2D  #DEFINE_ALIAS
 from ...fluid.dygraph import Embedding  #DEFINE_ALIAS
-from ...fluid.dygraph import Linear  #DEFINE_ALIAS
 from ...fluid.dygraph import Flatten  #DEFINE_ALIAS
 from ...fluid.dygraph import layers
 from .. import functional as F
@@ -49,6 +48,88 @@ __all__ = [
 ]


+class Linear(layers.Layer):
+    """
+
+    Fully-connected linear transformation layer:
+
+    .. math::
+
+        Out = XW + b
+
+    where :math:`X` is the input Tensor and :math:`W` and :math:`b` are the weight and bias respectively.
+
+    The Linear layer takes a single ``Tensor`` input, multiplies it by the
+    weight matrix, and produces an output Tensor of shape [N, *, `out_features`],
+    where N is the batch size and `*` means any number of additional dimensions.
+    If ``bias_attr`` is not False, a bias parameter will be created and added to the output.
+
+    Parameters:
+        in_features(int): The number of input units of this layer.
+        out_features(int): The number of output units of this layer.
+        weight_attr(ParamAttr, optional): The parameter attribute for the learnable
+            weights(Parameter) of this layer. Default: None.
+        bias_attr(ParamAttr, optional): The attribute for the bias
+            of this layer. If it is set to False, no bias will be added to the output units.
+            If it is set to None, the bias is initialized to zero. Default: None.
+        name(str|None): For detailed information, please refer to :ref:`api_guide_Name`. Default: None.
+
+    Attributes:
+        **weight** (Parameter): the learnable weights of this layer.
+
+        **bias** (Parameter or None): the learnable bias of this layer.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle import nn
+            import numpy as np
+
+            data = np.ones((3, 1, 2), np.float32)
+            place = paddle.CPUPlace()
+            paddle.disable_static(place)
+            data = paddle.to_tensor(data)
+            weight_attr = paddle.framework.ParamAttr(
+                name="linear_weight",
+                learning_rate=1.0,
+                trainable=False,
+                regularizer=None,
+                initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0))
+            bias_attr = paddle.framework.ParamAttr(
+                name="linear_bias",
+                learning_rate=1.0,
+                trainable=False,
+                regularizer=None,
+                initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0))
+            linear = nn.Linear(2, 2, weight_attr=weight_attr, bias_attr=bias_attr)
+            res = linear(data)  # all elements are 3.0, shape is [3, 1, 2]
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 bias_attr=None,
+                 name=None):
+        super(Linear, self).__init__()
+        self._dtype = self._helper.get_default_dtype()
+        self._weight_attr = weight_attr
+        self._bias_attr = bias_attr
+        self.name = name
+        self.weight = self.create_parameter(
+            shape=[in_features, out_features],
+            attr=self._weight_attr,
+            dtype=self._dtype,
+            is_bias=False)
+        self.bias = self.create_parameter(
+            shape=[out_features],
+            attr=self._bias_attr,
+            dtype=self._dtype,
+            is_bias=True)
+
+    def forward(self, input):
+        out = F.linear(
+            x=input, weight=self.weight, bias=self.bias, name=self.name)
+        return out
+
+
 class UpSample(layers.Layer):
     """
     This op resizes a batch of images.
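
Usage sketch: with this change, ``Linear`` no longer takes an ``act`` argument; the activation is composed as a separate layer, and ``paddle.nn.functional.linear`` is the functional counterpart of the ``paddle.nn.Linear`` layer. A minimal example of both forms under the dygraph mode used in the tests above (shapes and variable names here are illustrative only)::

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()

    x = paddle.to_tensor(np.ones((3, 1, 2), dtype="float32"))

    # Functional form: weight has shape [in_features, out_features].
    w = paddle.to_tensor(np.ones((2, 4), dtype="float32"))
    b = paddle.to_tensor(np.ones((4,), dtype="float32"))
    y_func = F.linear(x, w, b)  # shape [3, 1, 4], all elements 3.0

    # Layer form: the old Linear(2, 4, act='softmax') becomes an explicit
    # Linear followed by a Softmax layer.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(2, 4),
        paddle.nn.Softmax())
    y_layer = model(x)

    # Same output shape; values differ from y_func only because the
    # layer's weights are randomly initialized by default.
    print(y_layer.shape)  # [3, 1, 4]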