未验证 提交 4dfc375a 编写于 作者: X Xiaoyu Zhang 提交者: GitHub

add elu module (#4924)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 8125b43f
......@@ -9,6 +9,7 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.Tanh
.. autofunction:: oneflow.experimental.tanh
.. autofunction:: oneflow.experimental.Tensor.tanh
.. autofunction:: oneflow.experimental.nn.ELU
.. autofunction:: oneflow.experimental.nn.GELU
.. autofunction:: oneflow.experimental.gelu
.. autofunction:: oneflow.experimental.Tensor.gelu
......
......@@ -208,6 +208,55 @@ def tanh_op(x):
return Tanh()(x)
@oneflow_export("nn.ELU")
@experimental_api
class ELU(Module):
    r"""Applies the Exponential Linear Unit function element-wise:

    .. math::

        \text{ELU}(x) = \begin{cases}
            x                        & \text{ if } x \gt 0 \\
            \alpha * (\exp(x) - 1)   & \text{ if } x \le 0 \\
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``.
            In-place mode is not supported yet; passing ``True`` raises an
            ``AssertionError``.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    For example:

    .. code-block:: python

        import oneflow.experimental as flow

        m = flow.nn.ELU()
        input = flow.randn(2)
        output = m(input)

    """

    def __init__(self, alpha: float = 1.0, inplace: bool = False):
        super().__init__()
        # In-place computation is not implemented for this op yet. Kept as an
        # assert (rather than a raise) to preserve the AssertionError type the
        # original code emits for callers that may catch it.
        assert not inplace, "ELU not support inplace equal true now!"
        # Build the "elu" user op once at construction; `alpha` is baked into
        # the op as a static attribute, so forward() only feeds the input.
        self._op = (
            flow.builtin_op("elu")
            .Input("in")
            .Attr("alpha", alpha)
            .Output("out")
            .Build()
        )

    def forward(self, x):
        """Apply ELU to tensor ``x`` and return the output tensor."""
        res = self._op(x)[0]
        return res
@oneflow_export("nn.GELU")
@experimental_api
class GELU(Module):
......
......@@ -99,6 +99,28 @@ class TestTanhModule(flow.unittest.TestCase):
test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestELUModule(flow.unittest.TestCase):
    """Checks ``flow.nn.ELU`` against a NumPy reference implementation."""

    def test_elu(test_case):
        # Default alpha (1.0): ELU(x) = x for x > 0, else exp(x) - 1.
        arr = np.random.randn(2, 3, 4, 5)
        expected = np.where(arr > 0, arr, 1.0 * (np.exp(arr) - 1))
        elu = flow.nn.ELU()
        actual = elu(flow.Tensor(arr))
        test_case.assertTrue(
            np.allclose(actual.numpy(), expected, rtol=1e-4, atol=1e-4)
        )

    def test_elu_alpha(test_case):
        # Non-default alpha only rescales the negative branch.
        arr = np.random.randn(2, 3, 4, 5)
        expected = np.where(arr > 0, arr, 1.2 * (np.exp(arr) - 1))
        elu = flow.nn.ELU(alpha=1.2)
        actual = elu(flow.Tensor(arr))
        test_case.assertTrue(
            np.allclose(actual.numpy(), expected, rtol=1e-4, atol=1e-4)
        )
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册