Commit f76a2cc2 authored by Megvii Engine Team

feat(mge/opr): add silu and gelu

GitOrigin-RevId: 75aa42947e43fc86920cba11bd768dd6ff226249
Parent f2ac4c34
@@ -10,13 +10,13 @@ MODES = {
     1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
         'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
         'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
-        'ERFCINV', 'H_SWISH'],
+        'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
     2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
         'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
         'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
         'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
         'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
-        'FUSE_ADD_H_SWISH'],
+        'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
     3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
 }
......
@@ -21,13 +21,13 @@ MODES = {
     (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
                    'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
                    'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
-                   'ERFCINV', 'H_SWISH'],
+                   'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
     (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
                    'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
                    'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
                    'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
                    'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
-                   'FUSE_ADD_H_SWISH'],
+                   'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
     (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
     (1, 'BOOL'): ['NOT'],
     (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
......
@@ -410,7 +410,11 @@ pdef('Elemwise').add_enum(
     Doc('NOT', 'unary: !x'),
     Doc('AND', 'binary: x && y'),
     Doc('OR', 'binary: x || y'),
-    Doc('XOR', 'binary: x ^ y')
+    Doc('XOR', 'binary: x ^ y'),
+    Doc('SILU', 'unary: x / (1 + exp(-x))'),
+    Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x)))'),
+    Doc('GELU', 'unary: x Phi(x)'),
+    Doc('GELU_GRAD', 'binary: grad(x Phi(x))'),
 )

 pdef('ElemwiseMultiType').add_enum(
......
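For reference, the two new unary modes in math form; Phi is written out via erf, matching the kernels further down in this diff:

```latex
\mathrm{silu}(x) = \frac{x}{1 + e^{-x}}, \qquad
\mathrm{gelu}(x) = x\,\Phi(x), \qquad
\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
```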
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
...@@ -64,6 +66,8 @@ ...@@ -64,6 +66,8 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
......
@@ -69,6 +69,31 @@ namespace megdnn {
         return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx;
     }

+//! grad of silu
+__device__ __host__ inline float silu_grad(float x, float dy) {
+    const float one = 1.f;
+    float sigmoid = one / (one + expf(-x));
+    return dy * sigmoid * (one + x * (one - sigmoid));
+}
+
+__device__ __host__ inline float normcdf(float x) {
+#if MEGDNN_CC_HOST
+    return 0.5f * (1.f + erff(x / sqrtf(2.f)));
+#else
+    //! use the CUDA built-in math function
+    return ::normcdff(x);
+#endif
+}
+
+//! grad of gelu
+__device__ __host__ inline float gelu_grad(float x, float dy) {
+    //! 1 / sqrt(2 * pi)
+    const float coeff = 0.3989422804014327f;
+    float phi = coeff * expf(-0.5f * x * x);
+    float normcdf_v = normcdf(x);
+    return dy * (normcdf_v + x * phi);
+}
+
 #include "src/common/elemwise/each_mode.inl"

 template<megcorePlatform_t plat, uint32_t mode, typename dtype>
@@ -137,6 +162,8 @@ namespace megdnn {
     DEF_KERN_FLOAT(ERFC, erfcf(x));
     DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
     DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
+    DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
+    DEF_KERN_FLOAT(GELU, x * normcdf(x));

     // int only
     DEF_KERN(dt_bool, NOT, x ^ 1);
@@ -207,6 +234,8 @@ namespace megdnn {
         x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));
     DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
+    DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
+    DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));

 #undef KERN_SIG

 /* ================== ternary kernels ================== */
......
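As a sanity check on the two gradient helpers above, here is a small NumPy sketch (not part of the commit; `scipy.special.erf` stands in for the erf used on the host path) comparing the formulas against central finite differences:

```python
import numpy as np
from scipy.special import erf  # reference erf for Phi(x)

def silu(x):
    return x / (1.0 + np.exp(-x))

def silu_grad(x, dy):
    # mirrors the CUDA/host helper: dy * s * (1 + x * (1 - s))
    s = 1.0 / (1.0 + np.exp(-x))
    return dy * s * (1.0 + x * (1.0 - s))

def gelu(x):
    return x * 0.5 * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_grad(x, dy):
    # mirrors the helper: dy * (Phi(x) + x * phi(x))
    phi = np.exp(-0.5 * x * x) / np.sqrt(2.0 * np.pi)  # standard normal pdf
    Phi = 0.5 * (1.0 + erf(x / np.sqrt(2.0)))          # standard normal cdf
    return dy * (Phi + x * phi)

x, eps = np.linspace(-3.0, 3.0, 7), 1e-4
for f, g in [(silu, silu_grad), (gelu, gelu_grad)]:
    fd = (f(x + eps) - f(x - eps)) / (2 * eps)  # central difference
    assert np.allclose(g(x, 1.0), fd, atol=1e-5)
```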
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
@@ -48,6 +48,7 @@ __all__ = [
     "deformable_psroi_pooling",
     "dropout",
     "embedding",
+    "gelu",
     "hsigmoid",
     "hswish",
     "indexing_one_hot",
@@ -67,6 +68,7 @@ __all__ = [
     "sigmoid",
     "sliding_window",
     "sliding_window_transpose",
+    "silu",
     "softmax",
     "softplus",
     "sync_batch_norm",
@@ -766,6 +768,25 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
     return maximum(inp, 0) + negative_slope * minimum(inp, 0)


+def silu(x):
+    r"""
+    Applies the element-wise Sigmoid Linear Unit function, i.e. ``x * sigmoid(x)``.
+    """
+    return _elwise(x, mode=Elemwise.Mode.SILU)
+
+
+def gelu(x):
+    r"""
+    Applies the element-wise function:
+
+    .. math::
+        \text{gelu}(x) = x\Phi(x)
+
+    where :math:`\Phi(x)` is the cumulative distribution function of the
+    standard Gaussian distribution.
+    """
+    return _elwise(x, mode=Elemwise.Mode.GELU)
+
+
 def softplus(inp: Tensor) -> Tensor:
     r"""
     Applies the element-wise function:
......
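A quick usage check of the two new functional ops; the expected values below are the same ones used in the module docstrings later in this diff (assumes a MegEngine build that includes this commit):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.array([-2, -1, 0, 1, 2], dtype=np.float32))
print(F.silu(x).numpy())  # approx. [-0.2384 -0.2689  0.      0.7311  1.7616]
print(F.gelu(x).numpy())  # approx. [-0.0455 -0.1587  0.      0.8413  1.9545]
```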
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
+from .activation import GELU, LeakyReLU, PReLU, ReLU, Sigmoid, SiLU, Softmax
 from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
 from .batch_matmul_activation import BatchMatMulActivation
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
......
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

-from ..functional import leaky_relu, prelu, relu, sigmoid, softmax
+from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax
 from ..tensor import Parameter
 from .module import Module

@@ -92,6 +92,74 @@ class Sigmoid(Module):
         return sigmoid(inputs)

+class SiLU(Module):
+    r"""
+    Applies the element-wise function:
+
+    .. math::
+        \text{SiLU}(x) = \frac{x}{1 + \exp(-x)}
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        data = mge.tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float32))
+        silu = M.SiLU()
+        output = silu(data)
+        with np.printoptions(precision=6):
+            print(output.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [-0.238406 -0.268941 0. 0.731059 1.761594]
+    """
+
+    def forward(self, inputs):
+        return silu(inputs)
+
+
+class GELU(Module):
+    r"""
+    Applies the element-wise function:
+
+    .. math::
+        \text{GELU}(x) = x\Phi(x)
+
+    where :math:`\Phi(x)` is the cumulative distribution function of the
+    standard Gaussian distribution.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        data = mge.tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float32))
+        gelu = M.GELU()
+        output = gelu(data)
+        with np.printoptions(precision=4):
+            print(output.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [-0.0455 -0.1587 0. 0.8413 1.9545]
+    """
+
+    def forward(self, inputs):
+        return gelu(inputs)
+
 class ReLU(Module):
     r"""
     Applies the element-wise function:
......
@@ -28,6 +28,8 @@ class Elemwise(Module):
     * "fuse_add_sigmoid": sigmoid(x + y)
     * "fuse_add_tanh": tanh(x + y)
     * "relu": x > 0 ? x : 0
+    * "silu": silu(x)
+    * "gelu": gelu(x)
     * "abs": x > 0 ? x : -x
     * "sigmoid": sigmoid(x)
     * "exp": exp(x)
......
@@ -144,6 +144,13 @@ def test_hswish():
     np.testing.assert_almost_equal(y_np, y_mge, decimal=6)


+def test_silu():
+    x = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
+    y_np = x / (1 + np.exp(-x))
+    y_mge = F.silu(tensor(x)).numpy()
+    np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
+
+
 def test_hsigmoid():
     np.random.seed(42)
     x = np.random.randn(100).astype("float32")
......
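This hunk adds a test only for silu; a matching test_gelu in the same style might look like the sketch below. It is hedged: it assumes scipy is available to build the erf-based reference, which the surrounding test file does not currently import.

```python
def test_gelu():
    from scipy.special import erf  # assumed dependency; reference Phi(x) via erf

    x = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
    y_np = x * 0.5 * (1.0 + erf(x / np.sqrt(2.0)))
    y_mge = F.gelu(tensor(x)).numpy()
    np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
```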
@@ -145,7 +145,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
                  0.f}) /
             6.f),
     };
-    mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER);
+    mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER);
     // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
-    // ERFINV, ERFCINV, NOT, AND, OR, XOR
+    // ERFINV, ERFCINV, NOT, AND, OR, XOR, SILU, SILU_GRAD, GELU, GELU_GRAD
     return map;
......
@@ -613,6 +613,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
             RET(EL2(H_SWISH_GRAD, (i0 + i1), og));
         case Mode::NOT:
             return nullptr;
+        case Mode::SILU:
+            RET(EL2(SILU_GRAD, i0, og));
+        case Mode::GELU:
+            RET(EL2(GELU_GRAD, i0, og));

         // binary
         case Mode::ABS_GRAD:
......
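The rules above wire Mode::SILU / Mode::GELU backward to the new *_GRAD modes, passing the forward input i0 and the incoming gradient og. A minimal end-to-end sketch of what this enables, assuming megengine.autodiff.GradManager from the same codebase and a build that includes this commit:

```python
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.array([-1.0, 0.5, 2.0], dtype=np.float32))
gm = GradManager().attach([x])
with gm:
    y = F.silu(x)
    gm.backward(y.sum())  # dispatches to SILU_GRAD(i0, og) under the hood

# analytic reference: d silu(x)/dx = s(x) * (1 + x * (1 - s(x))), s = sigmoid
s = 1.0 / (1.0 + np.exp(-x.numpy()))
np.testing.assert_allclose(
    x.grad.numpy(), s * (1.0 + x.numpy() * (1.0 - s)), rtol=1e-5)
```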
@@ -131,6 +131,12 @@ namespace {
                     std::numeric_limits<T>::digits));
     }

+    float do_gelu_grad(float x, float y) {
+        float phi = 1.f / sqrtf(2.0 * M_PI) * expf(-0.5f * x * x);
+        float normcdf_v = 0.5f * (1.f + erff(x / sqrtf(2.f)));
+        return y * (normcdf_v + x * phi);
+    }
+
     /* ======================= basic framework ======================= */

     template<typename ctype, bool stable_sign = false>
@@ -563,6 +569,9 @@ namespace {
         }
     };

+    template<> struct CheckerConfig<SILU_GRAD>: public NoGradCheckerConfig {};
+    template<> struct CheckerConfig<GELU_GRAD>: public NoGradCheckerConfig {};
+
     /* ======================= ternary config ======================= */
     template<> struct CheckerConfig<COND_LEQ_MOV>:
             public BinaryInputMinGap<false> {};
......
@@ -64,6 +64,10 @@ DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y))
 DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y))
 DEF_TRAIT(ATAN2, std::atan2(x, y))
 DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y))
+DEF_TRAIT(SILU_GRAD, y * (1 + std::exp(-x) + x * std::exp(-x)) /
+                     (1 + std::exp(-x)) / (1 + std::exp(-x)))
+DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))

 #undef _ALLOW_INT
 #undef _ALLOW_FLOAT
......
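The quotient form in the SILU_GRAD trait and the sigmoid form used in the CUDA/host helper are the same quantity; with :math:`s(x) = 1/(1+e^{-x})`:

```latex
\frac{d}{dx}\,\mathrm{silu}(x)
  = \frac{(1 + e^{-x}) + x\,e^{-x}}{(1 + e^{-x})^{2}}
  = s(x)\,\bigl(1 + x\,(1 - s(x))\bigr)
```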
@@ -56,6 +56,8 @@ DEF_TRAIT(ERFINV, do_erfinv(x))
 DEF_TRAIT(ERFC, std::erfc(x))
 DEF_TRAIT(ERFCINV, do_erfcinv(x))
 DEF_TRAIT(H_SWISH, do_h_swish(x))
+DEF_TRAIT(SILU, x / (1 + std::exp(-x)))
+DEF_TRAIT(GELU, x * (0.5f * (1.f + std::erf(x / std::sqrt(2.f)))))

 #undef _ALLOW_INT
 #undef _ALLOW_FLOAT
......