From 89bb65fd148322b5fc819418894139afe728ddc0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 26 Feb 2023 16:18:30 +0800 Subject: [PATCH] fix(dnn): fix softplus bwd kernel GitOrigin-RevId: 1f01ab5592f29ead271d02f7de15cc1c8a65df44 --- dnn/src/common/elemwise/kern_defs.cuh | 11 ++++++++++- .../test/unit/functional/test_elemwise.py | 2 +- src/opr/impl/basic_arith.cpp | 17 ++++------------- src/opr/test/basic_arith/elemwise.cpp | 8 ++++++++ .../basic_arith/elemwise_binary_trait_def.inl | 2 +- 5 files changed, 24 insertions(+), 16 deletions(-) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index a7057e7f2..e61d9a5e8 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -85,6 +85,15 @@ __device__ __host__ inline float gelu_grad(float x, float dy) { return dy * (normcdf_v + x * phi); } +//! grad of softplus +__device__ __host__ inline float softplus_grad(float x, float dy) { + float logg = -dy * expf(-fabs(x)) / (1.f + expf(-fabs(x))); + float grad0 = x > 0.f ? logg : -logg; + float relux = x < 0.f ? 0.f : x; + float grad1 = relux > 0.f ? dy : 0.f; + return grad0 + grad1; +} + __device__ __host__ inline bool feq(float a, float b) { return fabsf(a - b) < 1e-6; } @@ -287,7 +296,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f)); DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f)); DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x)); -DEF_KERN_FLOAT(SOFTPLUS_GRAD, y* expf(x) / (1.f + expf(x))); +DEF_KERN_FLOAT(SOFTPLUS_GRAD, softplus_grad(x, y)); DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y)); DEF_KERN_FLOAT( HSIGMOID_GRAD, diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index bedf4db18..ffafdefe9 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -397,7 +397,7 @@ def origin_softplus(inp: mge.tensor) -> mge.tensor: def test_subgraph_elemwise_mode(): def _test_allclose(func, ori_func): targets = np.array(2) - inp = np.random.randn(2, 256, 10, 16).astype("float32") + inp = np.random.uniform(size=(2, 16, 10, 16)).astype(np.float32) ori_inp = mge.tensor(inp) mge_inp = mge.tensor(inp) diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 1061dd1f1..536347ebc 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -559,21 +559,12 @@ MGB_IMPL_OPR_GRAD(Elemwise) { } case Mode::RELU6: RET(EL2(RELU6_GRAD, i0, og)); - case Mode::SOFTPLUS: { - auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0))); - auto logg = og * abse / (1 + abse); - auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg)); - RET(EL2(ADD, absg, EL2(SWITCH_GT0, EL1(RELU, i0), og))); - } + case Mode::SOFTPLUS: + RET(EL2(SOFTPLUS_GRAD, i0, og)); case Mode::HSIGMOID: RET(EL2(HSIGMOID_GRAD, i0, og)); - case Mode::LOGSIGMOID: { - og = EL1(NEGATE, og); - auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0))); - auto logg = og * abse / (1 + abse); - auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg)); - RET(EL2(SUB, absg, EL2(SWITCH_GT0, EL1(RELU, EL1(NEGATE, i0)), og))); - } + case Mode::LOGSIGMOID: + RET(EL2(SOFTPLUS_GRAD, -i0, og)); case Mode::SQRT: RET(og / EL1(SQRT, i0) / 2); case Mode::SQUARE: diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 7b66b0b5b..85b7734a1 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -77,6 +77,14 @@ float do_fuse_add_h_swish(float x, float y) { return z * fmaxf(fminf(z + 3.f, 6.f), 0.f) / 6.f; } +float do_softplus_grad(float x, float y) { + float logg = -y * expf(-fabs(x)) / (1.f + expf(-fabs(x))); + float grad0 = x > 0.f ? logg : -logg; + float relux = x < 0.f ? 0.f : x; + float grad1 = relux > 0.f ? y : 0.f; + return grad0 + grad1; +} + template T do_shl(T, T); // undefined template diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 1ed742db5..d99755506 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -61,7 +61,7 @@ DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y)) DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1)) DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1)) DEF_TRAIT(ATANH_GRAD, y / (1 - x * x)) -DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x))) +DEF_TRAIT(SOFTPLUS_GRAD, do_softplus_grad(x, y)) DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y)) DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f))) -- GitLab