提交 5eca4da3 编写于 作者: M Megvii Engine Team

feat(mge/functional): add softplus function

GitOrigin-RevId: 26c1ab74161e07089f846ddb4a610d6279d9571f
上级 855c49ca
......@@ -72,6 +72,7 @@ from .nn import (
roi_align,
roi_pooling,
softmax,
softplus,
warp_perspective,
)
from .quantized import conv_bias_activation
......
......@@ -18,7 +18,8 @@ from ..jit import barrier, mark_impure
from ..random import uniform
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .tensor import concat
from .elemwise import exp, log
from .tensor import concat, where
from .utils import _decide_comp_node_and_comp_graph
......@@ -267,6 +268,24 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
)
@wrap_io_tensor
def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
    r"""
    Applies the elementwise function:

    .. math::
        \mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta.

    For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
    """
    # Compute the scaled input once; it is needed both for the stability
    # mask and for the softplus expression itself.
    scaled = beta * inp
    # Above the threshold exp() overflows, so fall back to the identity
    # branch there; the overflowed values are discarded by `where`.
    return where(scaled <= threshold, log(1 + exp(scaled)) / beta, inp)
@wrap_io_tensor
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
r"""
......
......@@ -439,3 +439,19 @@ def test_conv_bias():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
def test_softplus():
    """Check F.softplus against a numpy reference, covering both the
    softplus branch and the identity branch above the threshold."""
    data = np.arange(1000).astype(np.float32)

    # Default parameters: beta=1, threshold=20.
    actual = F.softplus(tensor(data))
    # exp() overflows for large inputs, but those lanes are masked out.
    with np.errstate(over="ignore"):
        reference = np.where(data <= 20, np.log(1 + np.exp(data)), data)
    assertTensorClose(actual, reference)

    # Custom beta and threshold.
    b, t = 2, 30
    actual = F.softplus(tensor(data), beta=b, threshold=t)
    with np.errstate(over="ignore"):
        reference = np.where(b * data <= t, np.log(1 + np.exp(data * b)) / b, data)
    assertTensorClose(actual, reference)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册