Unverified commit b8d07501 authored by: H hong19860320, committed by: GitHub

Add Sigmoid and sigmoid op in paddle.nn and paddle.nn.functional (#23334)

Parent 9b06dd86
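For context, a minimal sketch of how the two new entry points could be used (adapted from the docstring examples added below; the array values are illustrative only):

import numpy as np
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as functional

input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(input_data)
    y_layer = nn.Sigmoid()(x)          # layer form added in paddle.nn
    y_func = functional.sigmoid(x)     # functional form added in paddle.nn.functional
    print(y_layer.numpy())  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
    print(y_func.numpy())   # same values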
......@@ -1071,5 +1071,71 @@ class TestNNFunctionalReluAPI(unittest.TestCase):
        self.assertTrue(np.allclose(out[0], self.y))


class TestNNSigmoidAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [10, 15]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
        self.y = self.ref_forward(self.x)

    def ref_forward(self, x):
        return 1 / (1 + np.exp(-x))

    def ref_backward(self, y, dy):
        # Analytic gradient of sigmoid: dx = dy * y * (1 - y), where y = sigmoid(x).
        return dy * y * (1 - y)

    def check_api(self, place=fluid.CPUPlace(), inplace=False):
        # Build a static-graph program that applies the new nn.Sigmoid layer.
        main_program = Program()
        mysigmoid = nn.Sigmoid(inplace)
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            x.stop_gradient = False
            y = mysigmoid(x)
            fluid.backward.append_backward(fluid.layers.mean(y))

        # Run the program and compare forward output and input gradient
        # against the NumPy references.
        exe = fluid.Executor(place)
        out = exe.run(main_program,
                      feed={'x': self.x},
                      fetch_list=[y, y.grad_name, x.grad_name])
        self.assertTrue(np.allclose(out[0], self.y))
        self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))

        # Also check the forward output in dygraph mode.
        with fluid.dygraph.guard(place):
            x = fluid.dygraph.to_variable(self.x)
            y = mysigmoid(x)
            self.assertTrue(np.allclose(y.numpy(), self.y))

    def test_check_api(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            for inplace in [True, False]:
                self.check_api(place, inplace)
class TestNNFunctionalSigmoidAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [10, 15]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
        self.y = self.ref_forward(self.x)

    def ref_forward(self, x):
        return 1 / (1 + np.exp(-x))

    def test_check_api(self):
        # Check the functional form, paddle.nn.functional.sigmoid, in static-graph mode.
        main_program = Program()
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            y = functional.sigmoid(x)
        exe = fluid.Executor(fluid.CPUPlace())
        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
        self.assertTrue(np.allclose(out[0], self.y))


if __name__ == "__main__":
    unittest.main()
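The ref_backward helper in the test above relies on the identity d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)). A quick NumPy-only finite-difference check of that identity (a sketch for illustration, not part of the change):

import numpy as np

x = np.random.uniform(-1, 1, [10, 15]).astype(np.float64)
sig = lambda t: 1 / (1 + np.exp(-t))      # same reference as ref_forward
analytic = sig(x) * (1 - sig(x))          # analytic gradient used by ref_backward

eps = 1e-6
numeric = (sig(x + eps) - sig(x - eps)) / (2 * eps)  # central finite difference
print(np.allclose(analytic, numeric, atol=1e-8))     # expected: True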
......@@ -82,7 +82,7 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS
# from .layer.norm import SpectralNorm #DEFINE_ALIAS
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
# from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
# from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
# from .layer.rnn import RNNCell #DEFINE_ALIAS
......@@ -192,7 +192,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
from .functional.activation import relu #DEFINE_ALIAS
# from .functional.activation import relu6 #DEFINE_ALIAS
# from .functional.activation import selu #DEFINE_ALIAS
# from .functional.activation import sigmoid #DEFINE_ALIAS
from .functional.activation import sigmoid #DEFINE_ALIAS
# from .functional.activation import soft_relu #DEFINE_ALIAS
# from .functional.activation import softmax #DEFINE_ALIAS
# from .functional.activation import softplus #DEFINE_ALIAS
......
......@@ -118,7 +118,7 @@ from . import activation
from .activation import relu #DEFINE_ALIAS
# from .activation import relu6 #DEFINE_ALIAS
# from .activation import selu #DEFINE_ALIAS
# from .activation import sigmoid #DEFINE_ALIAS
from .activation import sigmoid #DEFINE_ALIAS
# from .activation import soft_relu #DEFINE_ALIAS
# from .activation import softmax #DEFINE_ALIAS
# from .activation import softplus #DEFINE_ALIAS
......
......@@ -16,6 +16,7 @@ import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from ...fluid import core
from ...fluid.data_feeder import check_variable_and_dtype
# TODO: define activation functions of neural network
__all__ = [
......@@ -34,7 +35,7 @@ __all__ = [
    'relu',
    # 'relu6',
    # 'selu',
    # 'sigmoid',
    'sigmoid',
    # 'soft_relu',
    # 'softmax',
    # 'softplus',
......@@ -94,6 +95,64 @@ def relu(input, inplace=False, name=None):
    return outs
def sigmoid(input, inplace=False, name=None):
    """
    Sigmoid Activation.

    .. math::

        output = \\frac{1}{1 + e^{-input}}

    Parameters:
        input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64.
        inplace (bool, optional): If inplace is True, the input and output are the same variable.
            Otherwise, the input and output are different variables. Default: False. Note that if the
            input is used by more than one OP, inplace must be False.
        name (str, optional): The default value is None. Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name` .

    Returns:
        Output of sigmoid operator, a Tensor with the same shape as the input.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn.functional as functional
            import numpy as np

            # In the static graph mode
            input = fluid.data(name="input", shape=[None, 4])
            output = functional.sigmoid(input)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
            output_data = exe.run(feed={"input": input_data},
                                  fetch_list=[output])
            print(output_data)  # [0.7310586, 0.880797, 0.95257413, 0.98201376]

            # In the dynamic graph mode
            with fluid.dygraph.guard():
                input = fluid.dygraph.to_variable(input_data)
                output = functional.sigmoid(input)
                print(output.numpy())  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
    """
    if in_dygraph_mode():
        if inplace:
            warnings.warn(
                "Inplace on sigmoid is not allowed and will be discarded in dygraph mode currently."
            )
        return core.ops.sigmoid(input)

    check_variable_and_dtype(input, 'X', ['float16', 'float32', 'float64'],
                             'sigmoid')
    helper = LayerHelper("sigmoid", **locals())
    outputs = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='sigmoid', inputs={'X': [input]}, outputs={'Out': outputs})
    return outputs


def log_softmax(input, axis=None, dtype=None, name=None):
    """
    This operator implements the log_softmax layer. The calculation process is as follows:
......
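One behavioral detail worth noting from the implementation above: in dygraph mode the inplace argument is discarded with a warning and a fresh tensor is returned. A small sketch illustrating this, assuming a working dygraph environment (the array values are illustrative only):

import warnings
import numpy as np
import paddle.fluid as fluid
import paddle.nn.functional as functional

x_np = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_np)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        y = functional.sigmoid(x, inplace=True)  # inplace is discarded in dygraph mode
    print(len(caught) >= 1)                                  # expected: True (a warning was raised)
    print(np.allclose(y.numpy(), 1 / (1 + np.exp(-x_np))))   # expected: True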
......@@ -21,7 +21,7 @@ from .. import functional
__all__ = [
    # 'PReLU',
    'ReLU',
    # 'Sigmoid',
    'Sigmoid',
    # 'Softmax',
    'LogSoftmax',
]
......@@ -66,6 +66,48 @@ class ReLU(layers.Layer):
        return functional.relu(input, self._inplace)


class Sigmoid(layers.Layer):
    """
    Sigmoid Activation.

    .. math::

        output = \\frac{1}{1 + e^{-input}}

    Parameters:
        inplace (bool, optional): If inplace is True, the input and output
            are the same variable. Otherwise, the input and output
            are different variables. Default: False. Note that if the input
            is used by more than one OP, inplace must be False.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn as nn
            import numpy as np

            input = fluid.data(name="input", shape=[None, 4])
            output = nn.Sigmoid()(input)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
            output_data = exe.run(feed={"input": input_data},
                                  fetch_list=[output])
            print(output_data)  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
    """

    def __init__(self, inplace=False):
        super(Sigmoid, self).__init__()
        self._inplace = inplace

    def forward(self, input):
        return functional.sigmoid(input, self._inplace)


class LogSoftmax(layers.Layer):
    """
    This operator implements the log_softmax layer. The calculation process is as follows:
......