未验证 提交 00e08ce0 编写于 作者: W wawltor 提交者: GitHub

add the sigmoid, Sigmoid for the api 2.0 (#26171)

Update the sigmoid, Sigmoid layer for the api2.0
上级 f8ca7201
...@@ -274,6 +274,7 @@ def generate_activation_fn(op_type): ...@@ -274,6 +274,7 @@ def generate_activation_fn(op_type):
return output return output
func.__name__ = op_type func.__name__ = op_type
func.__module__ = "paddle.fluid.layers"
func.__doc__ = _generate_doc_string_( func.__doc__ = _generate_doc_string_(
op_proto, op_proto,
additional_args_lines=[ additional_args_lines=[
......
...@@ -23,10 +23,15 @@ from paddle.utils import deprecated ...@@ -23,10 +23,15 @@ from paddle.utils import deprecated
__activations_noattr__ = [ __activations_noattr__ = [
'sigmoid', 'sigmoid',
'logsigmoid', 'logsigmoid',
'tanh_shrink',
'softplus',
'softsign',
]
__unary_func__ = [
'exp', 'exp',
'tanh', 'tanh',
'atan', 'atan',
'tanh_shrink',
'sqrt', 'sqrt',
'rsqrt', 'rsqrt',
'abs', 'abs',
...@@ -34,15 +39,13 @@ __activations_noattr__ = [ ...@@ -34,15 +39,13 @@ __activations_noattr__ = [
'floor', 'floor',
'cos', 'cos',
'acos', 'acos',
'asin',
'sin', 'sin',
'sinh', 'sinh',
'asin',
'cosh', 'cosh',
'round', 'round',
'reciprocal', 'reciprocal',
'square', 'square',
'softplus',
'softsign',
] ]
__all__ = [] __all__ = []
...@@ -58,9 +61,18 @@ globals()['_scale'] = generate_layer_fn('scale') ...@@ -58,9 +61,18 @@ globals()['_scale'] = generate_layer_fn('scale')
globals()['_elementwise_div'] = generate_layer_fn('elementwise_div') globals()['_elementwise_div'] = generate_layer_fn('elementwise_div')
__all__ += __activations_noattr__ __all__ += __activations_noattr__
__all__ += __unary_func__
for _OP in set(__activations_noattr__): for _OP in set(__activations_noattr__):
globals()[_OP] = generate_activation_fn(_OP) func = generate_activation_fn(_OP)
func = deprecated(
since="2.0.0", update_to="paddle.nn.functional.%s" % (_OP))(func)
globals()[_OP] = func
for _OP in set(__unary_func__):
func = generate_activation_fn(_OP)
func = deprecated(since="2.0.0", update_to="paddle.%s" % (_OP))(func)
globals()[_OP] = func
add_sample_code(globals()["sigmoid"], r""" add_sample_code(globals()["sigmoid"], r"""
Examples: Examples:
......
...@@ -1440,9 +1440,9 @@ class TestNNReluAPI(unittest.TestCase): ...@@ -1440,9 +1440,9 @@ class TestNNReluAPI(unittest.TestCase):
y_t[y_t > 0] = 1 y_t[y_t > 0] = 1
return y_t * dy return y_t * dy
def check_api(self, place=fluid.CPUPlace(), inplace=False): def check_api(self, place=fluid.CPUPlace()):
main_program = Program() main_program = Program()
myrelu = nn.ReLU(inplace) myrelu = nn.ReLU()
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape) x = fluid.data(name='x', shape=self.x_shape)
x.stop_gradient = False x.stop_gradient = False
...@@ -1465,8 +1465,7 @@ class TestNNReluAPI(unittest.TestCase): ...@@ -1465,8 +1465,7 @@ class TestNNReluAPI(unittest.TestCase):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0)) places.append(fluid.CUDAPlace(0))
for place in places: for place in places:
for inplace in [True, False]: self.check_api(place)
self.check_api(place, inplace)
class TestNNFunctionalReluAPI(unittest.TestCase): class TestNNFunctionalReluAPI(unittest.TestCase):
...@@ -1491,71 +1490,5 @@ class TestNNFunctionalReluAPI(unittest.TestCase): ...@@ -1491,71 +1490,5 @@ class TestNNFunctionalReluAPI(unittest.TestCase):
self.assertTrue(np.allclose(out[0], self.y)) self.assertTrue(np.allclose(out[0], self.y))
class TestNNSigmoidAPI(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [10, 15]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)
def ref_forward(self, x):
return 1 / (1 + np.exp(-x))
def ref_backward(self, y, dy):
return dy * y * (1 - y)
def check_api(self, place=fluid.CPUPlace(), inplace=False):
main_program = Program()
mysigmoid = nn.Sigmoid(inplace)
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
x.stop_gradient = False
y = mysigmoid(x)
fluid.backward.append_backward(fluid.layers.mean(y))
exe = fluid.Executor(place)
out = exe.run(main_program,
feed={'x': self.x},
fetch_list=[y, y.grad_name, x.grad_name])
self.assertTrue(np.allclose(out[0], self.y))
self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = mysigmoid(x)
self.assertTrue(np.allclose(y.numpy(), self.y))
def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
for inplace in [True, False]:
self.check_api(place, inplace)
class TestNNFunctionalSigmoidAPI(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [10, 15]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)
def ref_forward(self, x):
return 1 / (1 + np.exp(-x))
def test_check_api(self):
main_program = Program()
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = F.sigmoid(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], self.y))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit, erf
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as functional
class TestNNSigmoidAPI(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [10, 15]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)
def ref_forward(self, x):
return 1 / (1 + np.exp(-x))
def ref_backward(self, y, dy):
return dy * y * (1 - y)
def check_static_api(self, place):
paddle.enable_static()
main_program = paddle.static.Program()
mysigmoid = nn.Sigmoid(name="api_sigmoid")
with paddle.static.program_guard(main_program):
x = paddle.nn.data(name='x', shape=self.x_shape)
x.stop_gradient = False
y = mysigmoid(x)
fluid.backward.append_backward(paddle.mean(y))
exe = paddle.static.Executor(place)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], self.y))
self.assertTrue(y.name.startswith("api_sigmoid"))
def check_dynamic_api(self, place):
paddle.disable_static(place)
x = paddle.to_variable(self.x)
mysigmoid = nn.Sigmoid()
y = mysigmoid(x)
self.assertTrue(np.allclose(y.numpy(), self.y))
def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_dynamic_api(place)
self.check_static_api(place)
class TestNNFunctionalSigmoidAPI(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [10, 15]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)
def ref_forward(self, x):
return 1 / (1 + np.exp(-x))
def check_static_api(self, place):
paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.nn.data(name='x', shape=self.x_shape)
y = functional.sigmoid(x, name="api_sigmoid")
exe = paddle.static.Executor(fluid.CPUPlace())
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], self.y))
def check_dynamic_api(self):
paddle.disable_static()
x = paddle.to_variable(self.x)
y = functional.sigmoid(x)
self.assertTrue(np.allclose(y.numpy(), self.y))
def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_static_api(place)
self.check_dynamic_api()
...@@ -29,6 +29,7 @@ from ...fluid.layers import softplus #DEFINE_ALIAS ...@@ -29,6 +29,7 @@ from ...fluid.layers import softplus #DEFINE_ALIAS
from ...fluid.layers import softshrink #DEFINE_ALIAS from ...fluid.layers import softshrink #DEFINE_ALIAS
from ...fluid.layers import softsign #DEFINE_ALIAS from ...fluid.layers import softsign #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS from ...fluid.layers import swish #DEFINE_ALIAS
from ...fluid.layers import sigmoid #DEFINE_ALIAS
from ...fluid.layers import tanh_shrink #DEFINE_ALIAS from ...fluid.layers import tanh_shrink #DEFINE_ALIAS
from ...fluid.layers import thresholded_relu #DEFINE_ALIAS from ...fluid.layers import thresholded_relu #DEFINE_ALIAS
...@@ -48,12 +49,12 @@ __all__ = [ ...@@ -48,12 +49,12 @@ __all__ = [
'relu', 'relu',
'relu6', 'relu6',
'selu', 'selu',
'sigmoid',
'soft_relu', 'soft_relu',
'softmax', 'softmax',
'softplus', 'softplus',
'softshrink', 'softshrink',
'softsign', 'softsign',
'sigmoid',
'swish', 'swish',
'tanh_shrink', 'tanh_shrink',
'thresholded_relu', 'thresholded_relu',
...@@ -296,67 +297,6 @@ def relu(input, inplace=False, name=None): ...@@ -296,67 +297,6 @@ def relu(input, inplace=False, name=None):
return outs return outs
def sigmoid(input, inplace=False, name=None):
"""
:alias_main: paddle.nn.functional.sigmoid
:alias: paddle.nn.functional.sigmoid,paddle.nn.functional.activation.sigmoid
Sigmoid Activation.
.. math:
output = \frac{1}{1 + e^{-input}}
Parameters:
input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64.
inplace (bool, optional): If inplace is True, the input and output are the same variable.
Otherwise, the input and output of are different variables. Default: False. Note that if x is
more than one OPs' input, inplace must be False.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Output of sigmoid operator, a Tensor with shape same as input
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.nn.functional as functional
import numpy as np
# In the static graph mode
input = fluid.data(name="input", shape=[None, 4])
output = functional.sigmoid(input)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
output_data = exe.run(feed={"input": input_data},
fetch_list=[output])
print(output_data) # [0.7310586, 0.880797, 0.95257413, 0.98201376]
# In the dynamic graph mode
with fluid.dygraph.guard():
input = fluid.dygraph.to_variable(input_data)
output = functional.sigmoid(input)
print(output) # [0.7310586, 0.880797, 0.95257413, 0.98201376]
"""
if in_dygraph_mode():
if inplace:
warnings.warn(
"Inplace on sigmoid is not allowed and will be discarded in dygraph mode currently."
)
return core.ops.sigmoid(input)
check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
'sigmoid')
helper = LayerHelper("sigmoid", **locals())
outputs = helper.create_variable_for_type_inference(input.dtype)
helper.append_op(
type='sigmoid', inputs={'X': [input]}, outputs={'Out': outputs})
return outputs
def softmax(x, axis=-1, name=None): def softmax(x, axis=-1, name=None):
""" """
This operator implements the softmax layer. The calculation process is as follows: This operator implements the softmax layer. The calculation process is as follows:
......
...@@ -28,7 +28,7 @@ __all__ = [ ...@@ -28,7 +28,7 @@ __all__ = [
from ...fluid.dygraph import layers from ...fluid.dygraph import layers
from ...fluid import core from ...fluid import core
from ...fluid.framework import in_dygraph_mode from ...fluid.framework import in_dygraph_mode
from .. import functional from .. import functional as F
class Hardshrink(layers.Layer): class Hardshrink(layers.Layer):
...@@ -75,7 +75,7 @@ class Hardshrink(layers.Layer): ...@@ -75,7 +75,7 @@ class Hardshrink(layers.Layer):
self._name = name self._name = name
def forward(self, x): def forward(self, x):
return functional.hardshrink(x, self._threshold, self._name) return F.hardshrink(x, self._threshold, self._name)
class HSigmoid(layers.Layer): class HSigmoid(layers.Layer):
...@@ -202,7 +202,7 @@ class HSigmoid(layers.Layer): ...@@ -202,7 +202,7 @@ class HSigmoid(layers.Layer):
[C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype) [C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype)
def forward(self, input, label, path_table=None, path_code=None): def forward(self, input, label, path_table=None, path_code=None):
out = functional.hsigmoid( out = F.hsigmoid(
input, input,
label, label,
self.weight, self.weight,
...@@ -253,7 +253,7 @@ class ReLU(layers.Layer): ...@@ -253,7 +253,7 @@ class ReLU(layers.Layer):
self._inplace = inplace self._inplace = inplace
def forward(self, input): def forward(self, input):
return functional.relu(input, self._inplace) return F.relu(input, self._inplace)
class LeakyReLU(layers.Layer): class LeakyReLU(layers.Layer):
...@@ -293,52 +293,47 @@ class LeakyReLU(layers.Layer): ...@@ -293,52 +293,47 @@ class LeakyReLU(layers.Layer):
self._name = name self._name = name
def forward(self, x): def forward(self, x):
return functional.leaky_relu(x, self._alpha, self._name) return F.leaky_relu(x, self._alpha, self._name)
class Sigmoid(layers.Layer): class Sigmoid(layers.Layer):
""" """
:alias_main: paddle.nn.Sigmoid this interface is used to construct a callable object of the ``Sigmoid`` class. This layer calcluate the `sigmoid` of input x.
:alias: paddle.nn.Sigmoid,paddle.nn.layer.Sigmoid,paddle.nn.layer.activation.Sigmoid
.. math::
Sigmoid Activation. output = \\frac{1}{1 + e^{-x}}
.. math: Parameters:
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
output = \frac{1}{1 + e^{-input}} Shape:
x: N-D tensor, available dtype is float16, float32, float64.
Parameters:
inplace (bool, optional): If inplace is True, the input and output
are the same variable. Otherwise, the input and output
are different variables. Default False. Note that if x is
more than one OPs' input, inplace must be False.
Returns: Returns:
None A callable object of Sigmoid.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid
import paddle.nn as nn
import numpy as np import numpy as np
input = fluid.data(name="input", shape=[None, 4]) import paddle
output = nn.Sigmoid()(input)
place = fluid.CPUPlace() paddle.disable_static()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32') input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32')
output_data = exe.run(feed={"input": input_data}, m = paddle.nn.Sigmoid()
fetch_list=[output]) x = paddle.to_variable(input_data)
print(output_data) # [0.7310586, 0.880797, 0.95257413, 0.98201376] output = m(x)
print(output.numpy()) # [0.7310586, 0.880797, 0.95257413, 0.98201376]
""" """
def __init__(self, inplace=False): def __init__(self, name=None):
super(Sigmoid, self).__init__() super(Sigmoid, self).__init__()
self._inplace = inplace self.name = name
def forward(self, input): def forward(self, x):
return functional.sigmoid(input, self._inplace) return F.sigmoid(x, self.name)
class LogSoftmax(layers.Layer): class LogSoftmax(layers.Layer):
...@@ -394,4 +389,4 @@ class LogSoftmax(layers.Layer): ...@@ -394,4 +389,4 @@ class LogSoftmax(layers.Layer):
self._axis = axis self._axis = axis
def forward(self, input): def forward(self, input):
return functional.log_softmax(input, self._axis) return F.log_softmax(input, self._axis)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册