未验证 提交 fc0da42b 编写于 作者: C ceci3 提交者: GitHub

add Pad2D and LeakyReLU (#25177)

* add Pad2D and Leaky_ReLU, test=develop

* update,test=develop

* change name,test=develop

* add unittest and redine docs,test=develop
上级 d0a921ba
......@@ -253,6 +253,38 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_ret_value))
def test_leakyrelu(self):
inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32')
with self.static_graph():
t = layers.data(name='t', shape=[10, 10], dtype='float32')
ret = layers.leaky_relu(t, alpha=0.01)
static_ret = self.get_static_graph_result(
feed={'t': inputs}, fetch_list=[ret])[0]
with self.dynamic_graph():
lrelu = paddle.nn.LeakyReLU(alpha=0.01)
dy_ret = lrelu(base.to_variable(inputs))
dy_ret_value = dy_ret.numpy()
self.assertTrue(np.allclose(static_ret, dy_ret_value))
def test_pad2d(self):
with self.static_graph():
t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32')
ret = layers.pad2d(t, paddings=[1, 1, 1, 1])
static_ret = self.get_static_graph_result(
feed={'t': np.ones(
[3, 3, 5, 5], dtype='float32')},
fetch_list=[ret])[0]
with self.dynamic_graph():
t = np.ones([3, 3, 5, 5], dtype='float32')
my_pad2d = paddle.nn.Pad2D(paddings=1)
dy_ret = my_pad2d(base.to_variable(t))
dy_ret_value = dy_ret.numpy()
self.assertTrue(np.allclose(static_ret, dy_ret_value))
def test_matmul(self):
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
......
......@@ -53,12 +53,14 @@ from .input import data #DEFINE_ALIAS
# from .input import Input #DEFINE_ALIAS
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
# from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
from .layer.common import BilinearTensorProduct #DEFINE_ALIAS
from .layer.common import Pool2D #DEFINE_ALIAS
from .layer.common import Pad2D #DEFINE_ALIAS
from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
......
......@@ -29,12 +29,14 @@ from .activation import *
from .norm import *
# from .activation import PReLU #DEFINE_ALIAS
from .activation import ReLU #DEFINE_ALIAS
from .activation import LeakyReLU #DEFINE_ALIAS
from .activation import Sigmoid #DEFINE_ALIAS
# from .activation import Softmax #DEFINE_ALIAS
from .activation import LogSoftmax #DEFINE_ALIAS
from .activation import HSigmoid #DEFINE_ALIAS
from .common import BilinearTensorProduct #DEFINE_ALIAS
from .common import Pool2D #DEFINE_ALIAS
from .common import Pad2D #DEFINE_ALIAS
from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS
from .common import UpSample #DEFINE_ALIAS
......
......@@ -17,6 +17,7 @@
__all__ = [
# 'PReLU',
'ReLU',
'LeakyReLU',
'Sigmoid',
# 'Softmax',
'LogSoftmax',
......@@ -207,6 +208,50 @@ class ReLU(layers.Layer):
return functional.relu(input, self._inplace)
class LeakyReLU(layers.Layer):
"""
:alias_main: paddle.nn.LeakyReLU
:alias: paddle.nn.LeakyReLU,paddle.nn.layer.LeakyReLU,paddle.nn.layer.activation.LeakyReLU
Leaky ReLU Activation.
.. math:
out = max(x, alpha * x)
Parameters:
alpha (float, optional): Slope of the activation function at x < 0. Default: 0.01.
inplace (bool, optional): If inplace is True, the input and output of
``LeakyReLU`` are the same variable. Otherwise, the input and output of
``LeakyReLU`` are different variables. Default False. Note that if x is
more than one OPs' input, inplace must be False. Default: False.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.nn as nn
import numpy as np
data = np.array([-2, 0, 1]).astype('float32')
lrelu = nn.LeakyReLU()
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
res = lrelu(data) # [-0.02, 0, 1]
"""
def __init__(self, alpha=1e-2, inplace=False):
super(LeakyReLU, self).__init__()
self._alpha = alpha
self._inplace = inplace
def forward(self, input):
return functional.leaky_relu(input, self._alpha, self._inplace)
class Sigmoid(layers.Layer):
"""
:alias_main: paddle.nn.Sigmoid
......
......@@ -20,7 +20,10 @@ from ...fluid.dygraph import Linear #DEFINE_ALIAS
from ...fluid.dygraph import layers
from .. import functional as F
__all__ = ['BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample']
__all__ = [
'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample',
'Pad2D'
]
class UpSample(layers.Layer):
......@@ -248,3 +251,93 @@ class UpSample(layers.Layer):
data_format=self.data_format)
return out
class Pad2D(layers.Layer):
"""
:alias_main: paddle.nn.Pad2D
:alias: paddle.nn.Pad2D,paddle.nn.layer.Pad2D,paddle.nn.layer.common.Pad2D
This interface is used to construct a callable object of the ``Pad2D`` class.
The Pad2D layer pads the input tensor boundaries according to 'paddings' and 'mode'.
If mode is 'reflect', paddings[0] and paddings[1] must be no greater
than height-1. And the width dimension has the same condition.
Parameters:
paddings (int | List[int32]): The padding size. If padding is a int, uses the same
padding in all boundaries, if padding is a List, it must contain four integers,
(padding_top, padding_bottom, padding_left, padding_right).
Default is [0, 0, 0, 0].
mode (str): Three modes: 'constant' (default), 'reflect', 'edge' .
When in 'constant' mode, this op uses a constant value to pad the input tensor.
When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
When in 'edge' mode, uses input boundaries to pad the input tensor.
Default is 'constant'
pad_value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0
data_format (str): An string from: "NHWC", "NCHW". Specify the data format of
the input data.
Default is "NCHW"
Returns:
None
Examples:
.. code-block:: text
Input = [[[[1., 2., 3.],
[4., 5., 6.]]]]
Case 0:
paddings = [0, 1, 2, 3],
mode = 'constant'
pad_value = 0
Out = [[[[0., 0., 1., 2., 3., 0., 0., 0.],
[0., 0., 4., 5., 6., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]]
Case 1:
paddings = [0, 1, 2, 1],
mode = 'reflect'
Out = [[[[3., 2., 1., 2., 3., 2.],
[6., 5., 4., 5., 6., 5.],
[3., 2., 1., 2., 3., 2.]]]]
Case 2:
paddings = [0, 1, 2, 1],
mode = 'edge'
Out = [[[[1., 1., 1., 2., 3., 3.],
[4., 4., 4., 5., 6., 6.],
[4., 4., 4., 5., 6., 6.]]]]
Code Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.nn as nn
import numpy as np
data = np.ones((2, 2, 2, 2)).astype('float32')
my_pad = nn.Pad2D(paddings=[1, 1, 1, 1])
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
result = my_pad(data)
"""
def __init__(self,
paddings=0,
mode='constant',
pad_value=0.0,
data_format="NCHW"):
super(Pad2D, self).__init__()
self._mode = mode
self._pad_value = pad_value
self._data_format = data_format
self._paddings = [paddings] * 4 if isinstance(paddings,
int) else paddings
def forward(self, input):
return F.pad2d(
input,
paddings=self._paddings,
mode=self._mode,
pad_value=self._pad_value,
data_format=self._data_format)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册