//! grad of softplus
/*!
 * Returns dy * sigmoid(x), the gradient of softplus(x) = log(1 + exp(x))
 * scaled by the incoming gradient dy.
 *
 * The only exponential evaluated is exp(-|x|); its argument is never
 * positive, so the value is at most 1 and cannot overflow however large |x|
 * is. (The direct form dy * exp(x) / (1 + exp(x)) turns into inf / inf for
 * large positive x.)
 *
 * \param x forward input of softplus
 * \param dy gradient flowing in from the output
 */
__device__ __host__ inline float softplus_grad(float x, float dy) {
    // evaluated once and reused for both numerator and denominator
    const float e = expf(-fabsf(x));
    // logg == -dy * sigmoid(-|x|)
    const float logg = -dy * e / (1.f + e);
    // x > 0 : the result is dy - dy * sigmoid(-x) == dy * sigmoid(x)
    // x <= 0: -logg is already dy * sigmoid(x) (x == 0 yields dy / 2)
    const float grad0 = x > 0.f ? logg : -logg;
    const float grad1 = x > 0.f ? dy : 0.f;
    return grad0 + grad1;
}
ctype(0) : y)); DEF_KERN_FLOAT( HSIGMOID_GRAD, diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index bedf4db186ae7df4f8e60e9871d8e5499424f083..ffafdefe9ef415386847ab02d465a33102e6b6a2 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -397,7 +397,7 @@ def origin_softplus(inp: mge.tensor) -> mge.tensor: def test_subgraph_elemwise_mode(): def _test_allclose(func, ori_func): targets = np.array(2) - inp = np.random.randn(2, 256, 10, 16).astype("float32") + inp = np.random.uniform(size=(2, 16, 10, 16)).astype(np.float32) ori_inp = mge.tensor(inp) mge_inp = mge.tensor(inp) diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 1061dd1f1029b53d64a7a44066fb27a4cc4ad361..536347ebc46c7925e41ea4e4a9c11c7fc3bd9626 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -559,21 +559,12 @@ MGB_IMPL_OPR_GRAD(Elemwise) { } case Mode::RELU6: RET(EL2(RELU6_GRAD, i0, og)); - case Mode::SOFTPLUS: { - auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0))); - auto logg = og * abse / (1 + abse); - auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg)); - RET(EL2(ADD, absg, EL2(SWITCH_GT0, EL1(RELU, i0), og))); - } + case Mode::SOFTPLUS: + RET(EL2(SOFTPLUS_GRAD, i0, og)); case Mode::HSIGMOID: RET(EL2(HSIGMOID_GRAD, i0, og)); - case Mode::LOGSIGMOID: { - og = EL1(NEGATE, og); - auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0))); - auto logg = og * abse / (1 + abse); - auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg)); - RET(EL2(SUB, absg, EL2(SWITCH_GT0, EL1(RELU, EL1(NEGATE, i0)), og))); - } + case Mode::LOGSIGMOID: + RET(EL2(SOFTPLUS_GRAD, -i0, og)); case Mode::SQRT: RET(og / EL1(SQRT, i0) / 2); case Mode::SQUARE: diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 7b66b0b5b5f48dafc50a5e60db604b53e0716535..85b7734a1654073114e786f842308fc5ef08db0d 100644 
//! reference implementation of the softplus gradient: y * sigmoid(x)
/*!
 * Uses the same sequence of float operations as the softplus_grad kernel
 * helper. Only exp(-|x|) is evaluated; its argument is never positive, so
 * the exponential is at most 1 and cannot overflow for large |x|.
 *
 * \param x forward input of softplus
 * \param y gradient flowing in from the output
 */
float do_softplus_grad(float x, float y) {
    // evaluated once and reused for both numerator and denominator
    const float e = expf(-fabsf(x));
    // logg == -y * sigmoid(-|x|)
    const float logg = -y * e / (1.f + e);
    // x > 0 : the result is y - y * sigmoid(-x) == y * sigmoid(x)
    // x <= 0: -logg is already y * sigmoid(x) (x == 0 yields y / 2)
    const float grad0 = x > 0.f ? logg : -logg;
    const float grad1 = x > 0.f ? y : 0.f;
    return grad0 + grad1;
}