diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py
index 9f7abc0f4eb4166039e50b32c30e72204b2f28d0..58fdb3fddcb923fda5a32f086f688d9a7bf45a0e 100644
--- a/python_module/megengine/functional/__init__.py
+++ b/python_module/megengine/functional/__init__.py
@@ -72,6 +72,7 @@ from .nn import (
     roi_align,
     roi_pooling,
     softmax,
+    softplus,
     warp_perspective,
 )
 from .quantized import conv_bias_activation
diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py
index adc616f10af0f4d7c22d004b455a469cee96516d..2e67647b3fd0b9134963bb42a4a63c36658bfbb2 100644
--- a/python_module/megengine/functional/nn.py
+++ b/python_module/megengine/functional/nn.py
@@ -18,7 +18,8 @@ from ..jit import barrier, mark_impure
 from ..random import uniform
 from ..utils.types import _pair, _pair_nonzero
 from .debug_param import get_conv_execution_strategy
-from .tensor import concat
+from .elemwise import exp, log
+from .tensor import concat, where
 from .utils import _decide_comp_node_and_comp_graph
 
 
@@ -267,6 +268,24 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
     )
 
 
+@wrap_io_tensor
+def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
+    r"""
+    Applies the element-wise function:
+
+    .. math::
+
+        \mathsf{softplus}(x) = \log(1 + \exp(\beta x)) / \beta.
+
+    For numerical stability, the identity function is used when :math:`\beta x > \textrm{threshold}`.
+
+    """
+    mask = beta * inp <= threshold
+    out = log(1 + exp(beta * inp)) / beta
+    out = where(mask, out, inp)
+    return out
+
+
 @wrap_io_tensor
 def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
     r"""
diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py
index ac0728bafe28b24efab0ad8b59877249dde31626..19369aadbab49e2945cfd605bb233f4a2068b419 100644
--- a/python_module/test/unit/functional/test_functional.py
+++ b/python_module/test/unit/functional/test_functional.py
@@ -439,3 +439,19 @@ def test_conv_bias():
 
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
+
+
+def test_softplus():
+    x = np.arange(1000).astype(np.float32)
+    out = F.softplus(tensor(x))
+    mask = x <= 20
+    with np.errstate(over="ignore"):
+        expected = np.where(mask, np.log(1 + np.exp(x)), x)
+    assertTensorClose(out, expected)
+    beta = 2
+    out = F.softplus(tensor(x), beta=beta, threshold=30)
+    mask = beta * x <= 30
+    # ignore overflow
+    with np.errstate(over="ignore"):
+        expected = np.where(mask, np.log(1 + np.exp(x * beta)) / beta, x)
+    assertTensorClose(out, expected)
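
For context, and not part of the patch itself: a minimal NumPy sketch of the same stabilization trick. The naive formula overflows in float32 once beta * x exceeds roughly 88 (where exp() returns inf), while for beta * x > threshold softplus already equals x to within float32 precision, so passing the input straight through is safe. The name softplus_ref is a hypothetical helper used only for illustration.

    import numpy as np

    def softplus_ref(x, beta=1.0, threshold=20.0):
        """Reference softplus: log(1 + exp(beta * x)) / beta, replaced by the
        identity where beta * x > threshold to avoid exp() overflow."""
        bx = beta * x
        with np.errstate(over="ignore"):  # exp(bx) -> inf for large bx; that lane is discarded below
            naive = np.log(1 + np.exp(bx)) / beta
        return np.where(bx <= threshold, naive, x)

    x = np.array([0.0, 1.0, 100.0], dtype=np.float32)
    print(softplus_ref(x))  # approx [0.693, 1.313, 100.]; the large input passes through unchanged
    with np.errstate(over="ignore"):
        print(np.log(1 + np.exp(x)))  # [0.693..., 1.313..., inf]; the naive formula blows up

A branch-free alternative, max(x, 0) + log1p(exp(-|beta * x|)) / beta, would avoid the mask entirely, but the thresholded form used in the patch maps directly onto the exp, log, and where primitives imported in nn.py.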