diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index cb498ce94d438fe59e2643ac2afcdc053ae907b9..976b0cbefff1f076c4d766c4df7860fe06cd1b18 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1071,5 +1071,71 @@ class TestNNFunctionalReluAPI(unittest.TestCase):
         self.assertTrue(np.allclose(out[0], self.y))
 
 
+class TestNNSigmoidAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [10, 15]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.y = self.ref_forward(self.x)
+
+    def ref_forward(self, x):
+        return 1 / (1 + np.exp(-x))
+
+    def ref_backward(self, y, dy):
+        return dy * y * (1 - y)
+
+    def check_api(self, place=fluid.CPUPlace(), inplace=False):
+        main_program = Program()
+        mysigmoid = nn.Sigmoid(inplace)
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            x.stop_gradient = False
+            y = mysigmoid(x)
+            fluid.backward.append_backward(fluid.layers.mean(y))
+        exe = fluid.Executor(place)
+        out = exe.run(main_program,
+                      feed={'x': self.x},
+                      fetch_list=[y, y.grad_name, x.grad_name])
+        self.assertTrue(np.allclose(out[0], self.y))
+        self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))
+
+        with fluid.dygraph.guard(place):
+            x = fluid.dygraph.to_variable(self.x)
+            y = mysigmoid(x)
+            self.assertTrue(np.allclose(y.numpy(), self.y))
+
+    def test_check_api(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            for inplace in [True, False]:
+                self.check_api(place, inplace)
+
+
+class TestNNFunctionalSigmoidAPI(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.x_shape = [10, 15]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.y = self.ref_forward(self.x)
+
+    def ref_forward(self, x):
+        return 1 / (1 + np.exp(-x))
+
+    def test_check_api(self):
+        main_program = Program()
+        with fluid.program_guard(main_program):
+            x = fluid.data(name='x', shape=self.x_shape)
+            y = functional.sigmoid(x)
+        exe = fluid.Executor(fluid.CPUPlace())
+        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
+        self.assertTrue(np.allclose(out[0], self.y))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 5caa20104ddcdbcd1f58542718f182a84d13c175..e8e0aa3fd1f559fa6f8d2b2e4e20c7b77dbd2826 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -82,7 +82,7 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS
 # from .layer.norm import SpectralNorm #DEFINE_ALIAS
 # from .layer.activation import PReLU #DEFINE_ALIAS
 from .layer.activation import ReLU #DEFINE_ALIAS
-# from .layer.activation import Sigmoid #DEFINE_ALIAS
+from .layer.activation import Sigmoid #DEFINE_ALIAS
 # from .layer.activation import Softmax #DEFINE_ALIAS
 from .layer.activation import LogSoftmax #DEFINE_ALIAS
 # from .layer.rnn import RNNCell #DEFINE_ALIAS
@@ -192,7 +192,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
 from .functional.activation import relu #DEFINE_ALIAS
 # from .functional.activation import relu6 #DEFINE_ALIAS
 # from .functional.activation import selu #DEFINE_ALIAS
-# from .functional.activation import sigmoid #DEFINE_ALIAS
+from .functional.activation import sigmoid #DEFINE_ALIAS
 # from .functional.activation import soft_relu #DEFINE_ALIAS
 # from .functional.activation import softmax #DEFINE_ALIAS
 # from .functional.activation import softplus #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 784a2ec9d235be651f62ebe972a8e679bb08b95f..1a66af1882db3a6f4d1900f0b708e062dfcb9228 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -118,7 +118,7 @@ from . import activation
 from .activation import relu #DEFINE_ALIAS
 # from .activation import relu6 #DEFINE_ALIAS
 # from .activation import selu #DEFINE_ALIAS
-# from .activation import sigmoid #DEFINE_ALIAS
+from .activation import sigmoid #DEFINE_ALIAS
 # from .activation import soft_relu #DEFINE_ALIAS
 # from .activation import softmax #DEFINE_ALIAS
 # from .activation import softplus #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 900f1aa33c12a442f5d45a2bb6caa133a3bbe643..8a236867b601462ac69d966d9147f9fabffa5b9a 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -16,6 +16,7 @@ import warnings
 from ...fluid.layer_helper import LayerHelper
 from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
 from ...fluid import core
+from ...fluid.data_feeder import check_variable_and_dtype
 
 # TODO: define activation functions of neural network
 __all__ = [
@@ -34,7 +35,7 @@ __all__ = [
     'relu',
     # 'relu6',
     # 'selu',
-    # 'sigmoid',
+    'sigmoid',
     # 'soft_relu',
     # 'softmax',
     # 'softplus',
@@ -94,6 +95,64 @@ def relu(input, inplace=False, name=None):
     return outs
 
 
+def sigmoid(input, inplace=False, name=None):
+    r"""
+    Sigmoid Activation.
+
+    .. math::
+
+        output = \frac{1}{1 + e^{-input}}
+
+    Parameters:
+        input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64.
+        inplace (bool, optional): If inplace is True, the input and output are the same variable.
+            Otherwise, the input and output are different variables. Default: False. Note that if the input
+            is used by more than one OP, inplace must be False.
+        name (str, optional): The default value is None. Normally there is no need for users to set this property.
+            For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        The output of the sigmoid operator, a Tensor with the same shape as the input.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          import paddle.nn.functional as functional
+          import numpy as np
+          # In the static graph mode
+          input = fluid.data(name="input", shape=[None, 4])
+          output = functional.sigmoid(input)
+          place = fluid.CPUPlace()
+          exe = fluid.Executor(place)
+          exe.run(fluid.default_startup_program())
+          input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
+          output_data = exe.run(feed={"input": input_data},
+                                fetch_list=[output])
+          print(output_data)  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
+          # In the dynamic graph mode
+          with fluid.dygraph.guard():
+              input = fluid.dygraph.to_variable(input_data)
+              output = functional.sigmoid(input)
+              print(output)  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
+    """
+
+    if in_dygraph_mode():
+        if inplace:
+            warnings.warn(
+                "The inplace option is not supported for sigmoid in dygraph mode and will be ignored."
+            )
+        return core.ops.sigmoid(input)
+
+    check_variable_and_dtype(input, 'X', ['float16', 'float32', 'float64'],
+                             'sigmoid')
+    helper = LayerHelper("sigmoid", **locals())
+    outputs = helper.create_variable_for_type_inference(input.dtype)
+    helper.append_op(
+        type='sigmoid', inputs={'X': [input]}, outputs={'Out': outputs})
+    return outputs
+
+
 def log_softmax(input, axis=None, dtype=None, name=None):
     """
     This operator implements the log_softmax layer. The calculation process is as follows:
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index f94c8c7f9809d8740cc35a7764356132c04237b7..ad465b9f0b86e54ba1f41f500233bb3c637ee0cd 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -21,7 +21,7 @@ from .. import functional
 __all__ = [
     # 'PReLU',
     'ReLU',
-    # 'Sigmoid',
+    'Sigmoid',
     # 'Softmax',
     'LogSoftmax',
 ]
@@ -66,6 +66,48 @@ class ReLU(layers.Layer):
         return functional.relu(input, self._inplace)
 
 
+class Sigmoid(layers.Layer):
+    r"""
+    Sigmoid Activation.
+
+    .. math::
+
+        output = \frac{1}{1 + e^{-input}}
+
+    Parameters:
+        inplace (bool, optional): If inplace is True, the input and output
+            are the same variable. Otherwise, the input and output
+            are different variables. Default: False. Note that if the input
+            is used by more than one OP, inplace must be False.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          import paddle.nn as nn
+          import numpy as np
+          input = fluid.data(name="input", shape=[None, 4])
+          output = nn.Sigmoid()(input)
+          place = fluid.CPUPlace()
+          exe = fluid.Executor(place)
+          exe.run(fluid.default_startup_program())
+          input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
+          output_data = exe.run(feed={"input": input_data},
+                                fetch_list=[output])
+          print(output_data)  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
+    """
+
+    def __init__(self, inplace=False):
+        super(Sigmoid, self).__init__()
+        self._inplace = inplace
+
+    def forward(self, input):
+        return functional.sigmoid(input, self._inplace)
+
+
 class LogSoftmax(layers.Layer):
     """
     This operator implements the log_softmax layer. The calculation process is as follows:
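
Below is a minimal usage sketch, not part of the patch, showing the two entry points this diff enables (paddle.nn.Sigmoid and paddle.nn.functional.sigmoid) in dygraph mode. It assumes a Paddle build that already includes these changes; the printed values are the ones quoted in the docstrings above.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.nn as nn
    import paddle.nn.functional as functional

    x_np = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(x_np)

        # Layer form: construct the module once, then call it on a variable.
        y_layer = nn.Sigmoid()(x)

        # Functional form: a stateless call on the same variable.
        y_func = functional.sigmoid(x)

        # Both paths compute 1 / (1 + exp(-x)) elementwise.
        print(y_layer.numpy())  # [0.7310586, 0.880797, 0.95257413, 0.98201376]
        print(y_func.numpy())   # same values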