diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index d700397cfaf2acb797a7235730dabec79ebe6562..a1ead2aef63f7b186ed2d5e8a6598349ae50509d 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -253,6 +253,38 @@ class TestLayer(LayerTest): self.assertTrue(np.allclose(static_ret, dy_ret_value)) + def test_leakyrelu(self): + inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32') + with self.static_graph(): + t = layers.data(name='t', shape=[10, 10], dtype='float32') + ret = layers.leaky_relu(t, alpha=0.01) + static_ret = self.get_static_graph_result( + feed={'t': inputs}, fetch_list=[ret])[0] + + with self.dynamic_graph(): + lrelu = paddle.nn.LeakyReLU(alpha=0.01) + dy_ret = lrelu(base.to_variable(inputs)) + dy_ret_value = dy_ret.numpy() + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + + def test_pad2d(self): + with self.static_graph(): + t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32') + ret = layers.pad2d(t, paddings=[1, 1, 1, 1]) + static_ret = self.get_static_graph_result( + feed={'t': np.ones( + [3, 3, 5, 5], dtype='float32')}, + fetch_list=[ret])[0] + + with self.dynamic_graph(): + t = np.ones([3, 3, 5, 5], dtype='float32') + my_pad2d = paddle.nn.Pad2D(paddings=1) + dy_ret = my_pad2d(base.to_variable(t)) + dy_ret_value = dy_ret.numpy() + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + def test_matmul(self): with self.static_graph(): t = layers.data(name='t', shape=[3, 3], dtype='float32') diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 607c47c9a8c53a392a69a00329fe2359324620de..e074ca66bb1d3700cc2e50db2b1439e991113f39 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -53,12 +53,14 @@ from .input import data #DEFINE_ALIAS # from .input import Input #DEFINE_ALIAS # from .layer.activation import PReLU #DEFINE_ALIAS from .layer.activation import ReLU #DEFINE_ALIAS +from .layer.activation import LeakyReLU #DEFINE_ALIAS from .layer.activation import Sigmoid #DEFINE_ALIAS # from .layer.activation import Softmax #DEFINE_ALIAS from .layer.activation import LogSoftmax #DEFINE_ALIAS from .layer.activation import HSigmoid #DEFINE_ALIAS from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS +from .layer.common import Pad2D #DEFINE_ALIAS from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index cac6afd615465eb0e9c6452032af20bfbeaeb612..4963ac360804f88dad9677e1dd9c05a5231c89b9 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -29,12 +29,14 @@ from .activation import * from .norm import * # from .activation import PReLU #DEFINE_ALIAS from .activation import ReLU #DEFINE_ALIAS +from .activation import LeakyReLU #DEFINE_ALIAS from .activation import Sigmoid #DEFINE_ALIAS # from .activation import Softmax #DEFINE_ALIAS from .activation import LogSoftmax #DEFINE_ALIAS from .activation import HSigmoid #DEFINE_ALIAS from .common import BilinearTensorProduct #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS +from .common import Pad2D #DEFINE_ALIAS from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS from .common import UpSample #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index b30b651b79a501c36f6cd58234a96f62acdd1b1c..02a1d297e83ea4f21b3f1a9cb85b950e5959dc08 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -17,6 +17,7 @@ __all__ = [ # 'PReLU', 'ReLU', + 'LeakyReLU', 'Sigmoid', # 'Softmax', 'LogSoftmax', @@ -207,6 +208,50 @@ class ReLU(layers.Layer): return functional.relu(input, self._inplace) +class LeakyReLU(layers.Layer): + """ + :alias_main: paddle.nn.LeakyReLU + :alias: paddle.nn.LeakyReLU,paddle.nn.layer.LeakyReLU,paddle.nn.layer.activation.LeakyReLU + + Leaky ReLU Activation. + + .. math: + + out = max(x, alpha * x) + + Parameters: + alpha (float, optional): Slope of the activation function at x < 0. Default: 0.01. + inplace (bool, optional): If inplace is True, the input and output of + ``LeakyReLU`` are the same variable. Otherwise, the input and output of + ``LeakyReLU`` are different variables. Default False. Note that if x is + more than one OPs' input, inplace must be False. Default: False. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.nn as nn + import numpy as np + + data = np.array([-2, 0, 1]).astype('float32') + lrelu = nn.LeakyReLU() + with fluid.dygraph.guard(): + data = fluid.dygraph.to_variable(data) + res = lrelu(data) # [-0.02, 0, 1] + """ + + def __init__(self, alpha=1e-2, inplace=False): + super(LeakyReLU, self).__init__() + self._alpha = alpha + self._inplace = inplace + + def forward(self, input): + return functional.leaky_relu(input, self._alpha, self._inplace) + + class Sigmoid(layers.Layer): """ :alias_main: paddle.nn.Sigmoid diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 94841bbe2e700c986e8cc8eca3b68e96dcb7add9..8125e528b195b28024915ed9c20b922bd6224a5e 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -20,7 +20,10 @@ from ...fluid.dygraph import Linear #DEFINE_ALIAS from ...fluid.dygraph import layers from .. import functional as F -__all__ = ['BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample'] +__all__ = [ + 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', + 'Pad2D' +] class UpSample(layers.Layer): @@ -248,3 +251,93 @@ class UpSample(layers.Layer): data_format=self.data_format) return out + + +class Pad2D(layers.Layer): + """ + :alias_main: paddle.nn.Pad2D + :alias: paddle.nn.Pad2D,paddle.nn.layer.Pad2D,paddle.nn.layer.common.Pad2D + + This interface is used to construct a callable object of the ``Pad2D`` class. + The Pad2D layer pads the input tensor boundaries according to 'paddings' and 'mode'. + If mode is 'reflect', paddings[0] and paddings[1] must be no greater + than height-1. And the width dimension has the same condition. + + Parameters: + paddings (int | List[int32]): The padding size. If padding is a int, uses the same + padding in all boundaries, if padding is a List, it must contain four integers, + (padding_top, padding_bottom, padding_left, padding_right). + Default is [0, 0, 0, 0]. + mode (str): Three modes: 'constant' (default), 'reflect', 'edge' . + When in 'constant' mode, this op uses a constant value to pad the input tensor. + When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor. + When in 'edge' mode, uses input boundaries to pad the input tensor. + Default is 'constant' + pad_value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0 + data_format (str): An string from: "NHWC", "NCHW". Specify the data format of + the input data. + Default is "NCHW" + + Returns: + None + + Examples: + .. code-block:: text + + Input = [[[[1., 2., 3.], + [4., 5., 6.]]]] + + Case 0: + paddings = [0, 1, 2, 3], + mode = 'constant' + pad_value = 0 + Out = [[[[0., 0., 1., 2., 3., 0., 0., 0.], + [0., 0., 4., 5., 6., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.]]]] + + Case 1: + paddings = [0, 1, 2, 1], + mode = 'reflect' + Out = [[[[3., 2., 1., 2., 3., 2.], + [6., 5., 4., 5., 6., 5.], + [3., 2., 1., 2., 3., 2.]]]] + + Case 2: + paddings = [0, 1, 2, 1], + mode = 'edge' + Out = [[[[1., 1., 1., 2., 3., 3.], + [4., 4., 4., 5., 6., 6.], + [4., 4., 4., 5., 6., 6.]]]] + + Code Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.nn as nn + import numpy as np + data = np.ones((2, 2, 2, 2)).astype('float32') + my_pad = nn.Pad2D(paddings=[1, 1, 1, 1]) + with fluid.dygraph.guard(): + data = fluid.dygraph.to_variable(data) + result = my_pad(data) + """ + + def __init__(self, + paddings=0, + mode='constant', + pad_value=0.0, + data_format="NCHW"): + super(Pad2D, self).__init__() + self._mode = mode + self._pad_value = pad_value + self._data_format = data_format + self._paddings = [paddings] * 4 if isinstance(paddings, + int) else paddings + + def forward(self, input): + return F.pad2d( + input, + paddings=self._paddings, + mode=self._mode, + pad_value=self._pad_value, + data_format=self._data_format)