diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 70f6ed7fa2d5d43004177bf5270f93bca7526526..ec458ee7957c6a496ad4bd5579fe0b3c8a72069d 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,6 +41,7 @@ std::map> op_ins_map = { {"fake_quantize_dequantize_moving_average_abs_max", {"X", "InScale", "InAccum", "InState"}}, {"nll_loss", {"X", "Label", "Weight"}}, + {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, {"gather", {"X", "Index", "Axis"}}, }; diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_api.py b/python/paddle/fluid/tests/unittests/test_bilinear_api.py new file mode 100644 index 0000000000000000000000000000000000000000..24eae4797de85f371ed62e78c85b160f698ee9eb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bilinear_api.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from op_test import OpTest + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np + + +class TestBilinearAPI(unittest.TestCase): + def test_api(self): + with fluid.program_guard(fluid.default_startup_program(), + fluid.default_main_program()): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + + data1 = fluid.data(name='X1', shape=[5, 5], dtype='float32') + data2 = fluid.data(name='X2', shape=[5, 4], dtype='float32') + + layer1 = np.random.random((5, 5)).astype('float32') + layer2 = np.random.random((5, 4)).astype('float32') + + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + ret = bilinear(data1, data2) + + exe.run(fluid.default_startup_program()) + ret_fetch = exe.run(feed={'X1': layer1, + 'X2': layer2}, + fetch_list=[ret.name]) + self.assertEqual(ret_fetch[0].shape, (5, 1000)) + + +class TestBilinearAPIDygraph(unittest.TestCase): + def test_api(self): + paddle.disable_static() + layer1 = np.random.random((5, 5)).astype('float32') + layer2 = np.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + ret = bilinear(paddle.to_tensor(layer1), paddle.to_tensor(layer2)) + self.assertEqual(ret.shape, [5, 1000]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 290622450a958d81d5c2bb2ed683ee2c271a5ad2..b262b945267c72136141d98b2b0d5b911bd08ca9 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -88,6 +88,7 @@ from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS from .layer.common import Flatten #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS +from .layer.common import Bilinear #DEFINE_ALIAS from .layer.common import Dropout #DEFINE_ALIAS from .layer.common import Dropout2D #DEFINE_ALIAS from .layer.common import Dropout3D #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 53f5954279311b2765d1acb14ed66193c2c2f52d..905849360e116341bdf366c6b600875af2357b11 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -72,6 +72,7 @@ from .common import unfold #DEFINE_ALIAS # from .common import bilinear_tensor_product #DEFINE_ALIAS from .common import assign #DEFINE_ALIAS from .common import interpolate #DEFINE_ALIAS +from .common import bilinear #DEFINE_ALIAS from .conv import conv1d #DEFINE_ALIAS from .conv import conv_transpose1d #DEFINE_ALIAS from .conv import conv2d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index bf404d54b7d1518d935c0716c71e7084a8276f19..cff108ec6a9a8666e0aa51ba0414fd885777f1a7 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -33,6 +33,8 @@ from ...tensor import sqrt #from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS +from ...fluid.framework import in_dygraph_mode +from ...fluid import core, dygraph_utils from ...fluid import core, layers from ...fluid.data_feeder import check_variable_and_dtype @@ -52,6 +54,7 @@ __all__ = [ # 'bilinear_tensor_product', 'assign', 'interpolate', + 'bilinear', 'cosine_similarity', ] @@ -460,6 +463,70 @@ def interpolate(input, return out +def bilinear(x1, x2, weight, bias=None, name=None): + """ + + This layer performs bilinear on two inputs. + + .. math:: + out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1 + out = out + b + + In this formula: + - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features]. + - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features]. + - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features]. + - :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features]. + - :math:`b`: the learned bias, shape is [1, out_features]. + - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`. + + Parameters: + x1 (Tensor): the first input tensor, it's data type should be float32, float64. + x2 (Tensor): the second input tensor, it's data type should be float32, float64. + weight (Parameter): The learnable weights of this layer, shape is [out_features, in1_features, in2_features]. + bias (Parameter, optional): The learnable bias(Bias) of this layer, shape is [1, out_features]. If it is set to None, no bias will be added to the output units. The default value is None. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Returns: + Variable: A 2-D Tensor of shape [batch_size, out_features]. + + Examples: + .. code-block:: python + + import paddle + import numpy + import paddle.nn.functional as F + + paddle.disable_static() + x1 = numpy.random.random((5, 5)).astype('float32') + x2 = numpy.random.random((5, 4)).astype('float32') + w = numpy.random.random((1000, 5, 4)).astype('float32') + b = numpy.random.random((1, 1000)).astype('float32') + + result = F.bilinear(paddle.to_tensor(x1), paddle.to_tensor(x2), paddle.to_tensor(w), paddle.to_tensor(b)) # result shape [5, 1000] + + """ + + if in_dygraph_mode(): + return core.ops.bilinear_tensor_product(x1, x2, weight, bias) + + check_variable_and_dtype(x1, 'x1', ['float32', 'float64'], 'bilinear') + check_variable_and_dtype(x2, 'x2', ['float32', 'float64'], 'bilinear') + + inputs = {"X": x1, "Y": x2, "Weight": weight} + if bias is not None: + inputs["Bias"] = bias + + helper = LayerHelper("bilinear", **locals()) + out = helper.create_variable_for_type_inference(dtype=x1.dtype) + + helper.append_op( + type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out}) + + return out + + def dropout(x, p=0.5, axis=None, diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 1a96a3738af96b5abe8f44d1bb5178855251f7b9..8a73cfb8ccda15505d8668eff6776aac387c134f 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -23,11 +23,27 @@ from .. import functional as F from ...fluid.framework import _dygraph_tracer __all__ = [ - 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', - 'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d', - 'ReflectionPad2d', 'ReplicationPad2d', 'ConstantPad2d', 'ZeroPad2d', - 'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity', 'Dropout', - 'Dropout2D', 'Dropout3D', 'AlphaDropout' + 'BilinearTensorProduct', + 'Pool2D', + 'Embedding', + 'Linear', + 'UpSample', + 'Pad2D', + 'ReflectionPad1d', + 'ReplicationPad1d', + 'ConstantPad1d', + 'ReflectionPad2d', + 'ReplicationPad2d', + 'ConstantPad2d', + 'ZeroPad2d', + 'ConstantPad3d', + 'ReplicationPad3d', + 'CosineSimilarity', + 'Dropout', + 'Dropout2D', + 'Dropout3D', + 'Bilinear', + 'AlphaDropout', ] @@ -338,6 +354,94 @@ class Pad2D(layers.Layer): data_format=self._data_format) +class Bilinear(layers.Layer): + """ + + This layer performs bilinear on two inputs. + + .. math:: + out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,size-1 + out = out + b + + In this formula: + - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features]. + - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features]. + - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features]. + - :math:`out_{i}`: the i-th element of out, shape is [batch_size, out_features]. + - :math:`b`: the learned bias, shape is [1, out_features]. + - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`. + + Parameters: + in1_features (int): The dimension of each first input(`x1`). + in2_features (int): The dimension of each second input(`x2`). + out_features (int): The dimension of output of this layer. + weight_attr (ParamAttr, optional): The parameter attribute for the learnable w, parameters/weights of + this layer. The default value is None. + bias_attr (ParamAttr, optional): The parameter attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. The default value is None. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Attribute: + **weight** (Parameter): the learnable weights of this layer. + + **bias** (Parameter): the learnable bias of this layer. + + Returns: + Variable: A 2-D Tensor of shape [batch_size, out_features]. + + Examples: + .. code-block:: python + + import paddle + import numpy + + paddle.disable_static() + layer1 = numpy.random.random((5, 5)).astype('float32') + layer2 = numpy.random.random((5, 4)).astype('float32') + bilinear = paddle.nn.Bilinear( + in1_features=5, in2_features=4, out_features=1000) + result = bilinear(paddle.to_tensor(layer1), + paddle.to_tensor(layer2)) # result shape [5, 1000] + + """ + + def __init__(self, + in1_features, + in2_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + super(Bilinear, self).__init__() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._name = name + self._in1_features = in1_features + self._in2_features = in2_features + self._out_features = out_features + self._dtype = self._helper.get_default_dtype() + + weight_shape = [ + self._out_features, self._in1_features, self._in2_features + ] + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=weight_shape, + dtype=self._dtype, + is_bias=False) + bias_shape = [1, self._out_features] + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=bias_shape, + dtype=self._dtype, + is_bias=True) + + def forward(self, x1, x2): + return F.bilinear(x1, x2, self.weight, self.bias, self._name) + + class Dropout(layers.Layer): """ Dropout is a regularization technique for reducing overfitting by preventing