提交 f76a2cc2 编写于 作者: M Megvii Engine Team

feat(mge/opr): add silu and gelu

GitOrigin-RevId: 75aa42947e43fc86920cba11bd768dd6ff226249
上级 f2ac4c34
......@@ -10,13 +10,13 @@ MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
}
......
......@@ -21,13 +21,13 @@ MODES = {
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
......
......@@ -410,7 +410,11 @@ pdef('Elemwise').add_enum(
Doc('NOT', 'unary: !x'),
Doc('AND', 'binary: x && y'),
Doc('OR', 'binary: x || y'),
Doc('XOR', 'binary: x ^ y')
Doc('XOR', 'binary: x ^ y'),
Doc('SILU', 'unary: x / (1 + exp(-x))'),
Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x))'),
Doc('GELU', 'unary: x Phi(x)'),
Doc('GELU_GRAD', 'binary: grad(x Phi(x))'),
)
pdef('ElemwiseMultiType').add_enum(
......
......@@ -25,6 +25,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
......@@ -64,6 +66,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
......
......@@ -69,6 +69,31 @@ namespace megdnn {
return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx;
}
//! grad of silu
__device__ __host__ inline float silu_grad(float x, float dy) {
const float one = 1.0;
float sigmoid = one / (one + expf(-x));
return dy * sigmoid * (one + x * (one - sigmoid));
}
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
return 0.5f * (1.f + erff(x / sqrtf(2.f)));
#else
//! use cuda build-in math
return ::normcdff(x);
#endif
}
//! grad of gelu
__device__ __host__ inline float gelu_grad(float x, float dy) {
//! 1/ sqrt(2 * pi)
const float coeff = 0.3989422804014327f;
float phi = coeff * expf(-0.5f * x * x);
float normcdf_v = normcdf(x);
return dy * (normcdf_v + x * phi);
}
#include "src/common/elemwise/each_mode.inl"
template<megcorePlatform_t plat, uint32_t mode, typename dtype>
......@@ -137,6 +162,8 @@ namespace megdnn {
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x * normcdf(x));
// int only
DEF_KERN(dt_bool, NOT, x ^ 1);
......@@ -207,6 +234,8 @@ namespace megdnn {
x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
#undef KERN_SIG
/* ================== ternary kernels ================== */
......
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
......@@ -48,6 +48,7 @@ __all__ = [
"deformable_psroi_pooling",
"dropout",
"embedding",
"gelu",
"hsigmoid",
"hswish",
"indexing_one_hot",
......@@ -67,6 +68,7 @@ __all__ = [
"sigmoid",
"sliding_window",
"sliding_window_transpose",
"silu",
"softmax",
"softplus",
"sync_batch_norm",
......@@ -766,6 +768,25 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
return maximum(inp, 0) + negative_slope * minimum(inp, 0)
def silu(x):
r"""
Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`.
"""
return _elwise(x, mode=Elemwise.Mode.SILU)
def gelu(x):
r"""
Applies the element-wise function:
.. math::
\text{gelu}(x) = x\Phi(x)
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
"""
return _elwise(x, mode=Elemwise.Mode.GELU)
def softplus(inp: Tensor) -> Tensor:
r"""
Applies the element-wise function:
......
......@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .activation import GELU, LeakyReLU, PReLU, ReLU, Sigmoid, SiLU, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batch_matmul_activation import BatchMatMulActivation
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
......
......@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from ..functional import leaky_relu, prelu, relu, sigmoid, softmax
from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax
from ..tensor import Parameter
from .module import Module
......@@ -92,6 +92,74 @@ class Sigmoid(Module):
return sigmoid(inputs)
class SiLU(Module):
r"""
Applies the element-wise function:
.. math::
\text{SiLU}(x) = \frac{x}{1 + \exp(-x)}
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.module as M
data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
silu = M.SiLU()
output = silu(data)
with np.printoptions(precision=6):
print(output.numpy())
Outputs:
.. testoutput::
[-0.238406 -0.268941 0. 0.731059 1.761594]
"""
def forward(self, inputs):
return silu(inputs)
class GELU(Module):
r"""
Applies the element-wise function:
.. math::
\text{GELU}(x) = x\Phi(x)
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.module as M
data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
gelu = M.GELU()
output = gelu(data)
with np.printoptions(precision=4):
print(output.numpy())
Outputs:
.. testoutput::
[-0.0455 -0.1587 0. 0.8413 1.9545]
"""
def forward(self, inputs):
return gelu(inputs)
class ReLU(Module):
r"""
Applies the element-wise function:
......
......@@ -28,6 +28,8 @@ class Elemwise(Module):
* "fuse_add_sigmoid": sigmoid(x + y)
* "fuse_add_tanh": tanh(x + y)
* "relu": x > 0 ? x : 0
* "silu": silu(x)
* "gelu": gelu(x)
* "abs": x > 0 ? x : -x
* "sigmoid": sigmoid(x)
* "exp": exp(x)
......
......@@ -144,6 +144,13 @@ def test_hswish():
np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
def test_silu():
x = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
y_np = x / (1 + np.exp(-x))
y_mge = F.silu(tensor(x)).numpy()
np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
def test_hsigmoid():
np.random.seed(42)
x = np.random.randn(100).astype("float32")
......
......@@ -145,7 +145,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) /
6.f),
};
mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR
return map;
......
......@@ -613,6 +613,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
RET(EL2(H_SWISH_GRAD, (i0 + i1), og));
case Mode::NOT:
return nullptr;
case Mode::SILU:
RET(EL2(SILU_GRAD, i0, og));
case Mode::GELU:
RET(EL2(GELU_GRAD, i0, og));
// binary
case Mode::ABS_GRAD:
......
......@@ -131,6 +131,12 @@ namespace {
std::numeric_limits<T>::digits));
}
float do_gelu_grad(float x, float y) {
float phi = 1.f / sqrtf(2.0 * M_PI) * expf(-0.5f * x * x);
float normcdf_v = 0.5f * (1.f + erff(x / sqrtf(2.f)));
return y * (normcdf_v + x * phi);
}
/* ======================= basic framework ======================= */
template<typename ctype, bool stable_sign = false>
......@@ -563,6 +569,9 @@ namespace {
}
};
template<> struct CheckerConfig<SILU_GRAD>: public NoGradCheckerConfig {};
template<> struct CheckerConfig<GELU_GRAD>: public NoGradCheckerConfig {};
/* ======================= ternary config ======================= */
template<> struct CheckerConfig<COND_LEQ_MOV>:
public BinaryInputMinGap<false> {};
......
......@@ -64,6 +64,10 @@ DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y))
DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y))
DEF_TRAIT(ATAN2, std::atan2(x, y))
DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y))
DEF_TRAIT(SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) /
(1 + std::exp(-x)) / (1 + std::exp(-x)))
DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
......@@ -56,6 +56,8 @@ DEF_TRAIT(ERFINV, do_erfinv(x))
DEF_TRAIT(ERFC, std::erfc(x))
DEF_TRAIT(ERFCINV, do_erfcinv(x))
DEF_TRAIT(H_SWISH, do_h_swish(x))
DEF_TRAIT(SILU, x / (1 + std::exp(-x)))
DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f)))))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册