From 67d33303adeb4075aab8ab5496c6251810a6e433 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 9 Oct 2020 20:57:12 +0800 Subject: [PATCH] feat(mge): remove fast_tanh GitOrigin-RevId: 7ad5f7ecce158c471920ffc62e700ee726029077 --- imperative/python/megengine/functional/elemwise.py | 10 ---------- imperative/python/megengine/module/elemwise.py | 2 +- .../python/test/unit/functional/test_elemwise.py | 8 -------- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 8757b7764..3b71291c8 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -33,7 +33,6 @@ __all__ = [ "equal", "exp", "expm1", - "fast_tanh", "floor", "floor_div", "greater", @@ -369,15 +368,6 @@ def atanh(x): return log1p(2 * x / (1 - x)) / 2 -def fast_tanh(x): - r"""Element-wise `fast tanh`; this is an approximation: - - .. math:: - \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x) - """ - return _elwise(x, mode="fast_tanh") - - # bit-twiddling functions diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index 3542359d6..9bc05fbfc 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -34,7 +34,7 @@ class Elemwise(Module): * "EXP": exp(x) * "TANH": tanh(x) * "FUSE_MUL_ADD3": x * y + z - * "FAST_TANH": fast_tanh(x) + * "FAST_TANH": x * (27. + x * x) / (27. + 9. * x * x) * "NEGATE": -x * "ACOS": acos(x) * "ASIN": asin(x) diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index f0b51a56f..30421dd8c 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -113,14 +113,6 @@ def test_atanh(): np.testing.assert_almost_equal(y_np, y_mge, decimal=5) -def test_fast_tanh(): - np.random.seed(42) - x = np.random.randn(100).astype("float32") - y_np = x * (27.0 + x * x) / (27.0 + 9.0 * x * x) - y_mge = F.fast_tanh(tensor(x)).numpy() - np.testing.assert_almost_equal(y_np, y_mge, decimal=6) - - def test_hswish(): np.random.seed(42) x = np.random.randn(100).astype("float32") -- GitLab