提交 87aedc29 编写于 作者: M Megvii Engine Team

feat(dnn): add elemwise modes

GitOrigin-RevId: cb713ddb24e09b9426eb12c5a9767c308d65fd60
上级 25e89d68
......@@ -14,14 +14,18 @@ MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU', 'SINH', 'COSH',
'ASINH', 'ACOSH', 'ATANH', 'TAN', 'SOFTPLUS', 'RELU6',
'HSIGMOID', 'LOGSIGMOID', 'SQRT', 'SQUARE', 'SIGN'],
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU',
'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD',
'RELU6_GRAD', 'HSIGMOID_GRAD'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'],
}
QINT4_MODES = {
......@@ -29,8 +33,8 @@ QINT4_MODES = {
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'],
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0',
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH',
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH', 'PRELU'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP'],
}
QINT32_MODES = {
......
......@@ -12,23 +12,27 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
}
MODES = {
(1, 'INT'): ['RELU', 'ABS', 'NEGATE'],
(1, 'INT'): ['RELU', 'ABS', 'NEGATE', 'RELU6', 'SQUARE', 'SIGN'],
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ',
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'],
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'],
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH', 'PRELU'],
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'CLIP'],
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
'ERFCINV', 'H_SWISH', 'SILU', 'GELU', 'SINH', 'COSH',
'ASINH', 'ACOSH', 'ATANH', 'TAN', 'SOFTPLUS', 'RELU6',
'HSIGMOID', 'LOGSIGMOID', 'SQRT', 'SQUARE', 'SIGN'],
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU',
'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD',
'RELU6_GRAD', 'HSIGMOID_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
(3, 'BOOL'): []
......
......@@ -421,9 +421,31 @@ pdef('Elemwise').add_enum(
Doc('GELU = 58', 'unary: x Phi(x)'),
Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'),
Doc('NEQ = 61', 'binary: x != y'),
Doc('ISNAN = 62', 'unary: isnan(x)'),
Doc('ISINF = 63', 'unary: isinf(x)'),
Doc('SINH = 61', 'unary: sinh(x)'),
Doc('COSH = 62', 'unary: cosh(x)'),
Doc('ASINH = 63', 'unary: asinh(x)'),
Doc('ACOSH = 64', 'unary: acosh(x)'),
Doc('ATANH = 65', 'unary: atanh(x)'),
Doc('TAN = 66', 'unary: tan(x)'),
Doc('ASINH_GRAD = 67', 'binary: y / sqrt(x^2 + 1)'),
Doc('ACOSH_GRAD = 68', 'binary: y / sqrt(x^2 - 1) (x > 1)'),
Doc('ATANH_GRAD = 69', 'binary: y / (1 - x^2) (|x| < 1)'),
Doc('PRELU = 70', 'binary: x > 0 ? x : x * y'),
Doc('CLIP = 71', 'ternary: x <= y ? y : (x <= z ? x : z)'),
Doc('PRELU_GRAD = 72', 'ternary: x > 0 ? y : y * z'),
Doc('SOFTPLUS = 73', 'unary: log(1 + e^x)'),
Doc('SOFTPLUS_GRAD = 74', 'binary: y * e^x / (1 + e^x)'),
Doc('RELU6 = 75', 'unary: min(max(0, x), 6)'),
Doc('RELU6_GRAD = 76', 'binary: x < 0 ? 0 : (x > 6 ? 0 : y)'),
Doc('HSIGMOID = 77', 'unary: relu6(x + 3) / 6'),
Doc('HSIGMOID_GRAD = 78', 'binary: x < -3 ? 0 : (x > 3 ? 0 : y / 6)'),
Doc('LOGSIGMOID = 79', 'unary: -log(1 + e^(-x))'),
Doc('SQRT = 80', 'unary: x^(1/2)'),
Doc('SQUARE = 81', 'unary: x^2'),
Doc('SIGN = 82', 'unary: sgn(x)'),
Doc('NEQ = 83', 'binary: x != y'),
Doc('ISNAN = 84', 'unary: isnan(x)'),
Doc('ISINF = 85', 'unary: isinf(x)'),
)
pdef('ElemwiseMultiType').add_enum(
......
......@@ -25,12 +25,28 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
......@@ -66,7 +82,14 @@
MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
......@@ -86,15 +109,19 @@
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
......@@ -154,11 +154,18 @@ struct ElemwiseKern;
// int and float
DEF_KERN_ALL(NEGATE, -x);
DEF_KERN_ALL(SQUARE, x* x);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
DEF_KERN_FLOAT(RELU6, x <= 6.f ? ctype(0) : (x <= 6.f ? x : ctype(6)));
DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f));
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
......@@ -186,6 +193,18 @@ DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x* normcdf(x));
DEF_KERN_FLOAT(SINH, sinhf(x));
DEF_KERN_FLOAT(COSH, coshf(x));
DEF_KERN_FLOAT(ASINH, asinhf(x));
DEF_KERN_FLOAT(ACOSH, acoshf(x));
DEF_KERN_FLOAT(ATANH, atanhf(x));
DEF_KERN_FLOAT(TAN, tanf(x));
DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x));
DEF_KERN_FLOAT(
HSIGMOID,
x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f)));
DEF_KERN_FLOAT(SQRT, sqrtf(x));
DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x));
// int only
DEF_KERN(dt_bool, NOT, x ^ 1);
......@@ -240,6 +259,12 @@ DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y));
DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y));
#else
DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y));
#endif
// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
......@@ -259,6 +284,14 @@ DEF_KERN_FLOAT(
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f));
DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f));
DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x));
DEF_KERN_FLOAT(SOFTPLUS_GRAD, y* expf(x) / (1.f + expf(x)));
DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y));
DEF_KERN_FLOAT(
HSIGMOID_GRAD,
x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f)));
#undef KERN_SIG
/* ================== ternary kernels ================== */
......@@ -268,6 +301,8 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));
DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z));
#undef KERN_SIG
......
......@@ -220,6 +220,28 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
CB_MODE(Mode::GELU);
CB_MODE(Mode::GELU_GRAD);
CB_MODE(Mode::COND_LT_MOV);
CB_MODE(Mode::SINH);
CB_MODE(Mode::COSH);
CB_MODE(Mode::ASINH);
CB_MODE(Mode::ACOSH);
CB_MODE(Mode::ATANH);
CB_MODE(Mode::TAN);
CB_MODE(Mode::ASINH_GRAD);
CB_MODE(Mode::ACOSH_GRAD);
CB_MODE(Mode::ATANH_GRAD);
CB_MODE(Mode::PRELU);
CB_MODE(Mode::PRELU_GRAD);
CB_MODE(Mode::CLIP);
CB_MODE(Mode::SOFTPLUS);
CB_MODE(Mode::SOFTPLUS_GRAD);
CB_MODE(Mode::RELU6);
CB_MODE(Mode::RELU6_GRAD);
CB_MODE(Mode::HSIGMOID);
CB_MODE(Mode::HSIGMOID_GRAD);
CB_MODE(Mode::LOGSIGMOID);
CB_MODE(Mode::SQRT);
CB_MODE(Mode::SQUARE);
CB_MODE(Mode::SIGN);
default:
megdnn_assert(
0,
......
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
......@@ -267,7 +267,10 @@ IMPL_MODE_DISPATCHER(2, dt_qint4, dt_qint4);
IMPL_MODE_DISPATCHER(2, dt_quint4, dt_quint4);
#undef FOREACH
#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT
#define FOREACH(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
IMPL_MODE_DISPATCHER(3, dt_qint4, dt_qint4);
IMPL_MODE_DISPATCHER(3, dt_quint4, dt_quint4);
#undef FOREACH
......
......@@ -228,6 +228,7 @@ INST(Mode::SHL);
INST(Mode::SHR);
INST(Mode::FUSE_ADD_RELU);
INST(Mode::RMULH);
INST(Mode::PRELU);
#undef INST
#define INST(mode) \
......@@ -258,6 +259,13 @@ INST(Mode::H_SWISH_GRAD);
INST(Mode::FUSE_ADD_H_SWISH);
INST(Mode::SILU_GRAD);
INST(Mode::GELU_GRAD);
INST(Mode::PRELU);
INST(Mode::ASINH_GRAD);
INST(Mode::ACOSH_GRAD);
INST(Mode::ATANH_GRAD);
INST(Mode::SOFTPLUS_GRAD);
INST(Mode::RELU6_GRAD);
INST(Mode::HSIGMOID_GRAD);
#undef INST
} // namespace fallback
} // namespace megdnn
......
......@@ -77,6 +77,9 @@ using Mode = param_enumv::Elemwise::Mode;
INST(Mode::RELU);
INST(Mode::ABS);
INST(Mode::NEGATE);
INST(Mode::RELU6);
INST(Mode::SQUARE);
INST(Mode::SIGN);
#undef INST
#define INST(mode) \
......@@ -105,6 +108,19 @@ INST(Mode::ERFCINV);
INST(Mode::H_SWISH);
INST(Mode::SILU);
INST(Mode::GELU);
INST(Mode::SINH);
INST(Mode::COSH);
INST(Mode::ASINH);
INST(Mode::ACOSH);
INST(Mode::ATANH);
INST(Mode::TAN);
INST(Mode::SOFTPLUS);
INST(Mode::RELU6);
INST(Mode::HSIGMOID);
INST(Mode::LOGSIGMOID);
INST(Mode::SQRT);
INST(Mode::SQUARE);
INST(Mode::SIGN);
#undef INST
} // namespace fallback
} // namespace megdnn
......
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
......@@ -744,8 +744,8 @@ DEF_TEST(all_modes) {
TensorShapeArray shapes;
UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
tanh_rng_f32{-5.f, 5.f};
abslt1_rng_f32{-0.95f, 0.95f}, uniform_0_2_rng{0.f, 2.f},
tanh_rng_f32{-5.f, 5.f}, lt1_rng_f32{1.f, 10.f};
UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
big_nonzero_rng_f32{100.f, 1000.f};
UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
......@@ -786,12 +786,14 @@ DEF_TEST(all_modes) {
shapes[shapes.size() - 1] = {};
auto do_run = [&](DType dtype, float eps = 1e-3) {
// limit value ranges for some modes
if (mode == Mode::LOG || mode == Mode::LOG1P) {
if (mode == Mode::LOG || mode == Mode::LOG1P || mode == Mode::SQRT) {
checker.set_rng(0, &pos_rng_f32);
} else if (mode == Mode::POW) {
} else if (mode == Mode::POW || mode == Mode::SOFTPLUS_GRAD) {
checker.set_rng(0, &small_pos_rng_f32);
checker.set_rng(1, &small_rng_f32);
} else if (mode == Mode::EXP || mode == Mode::EXPM1) {
} else if (
mode == Mode::EXP || mode == Mode::EXPM1 || mode == Mode::SINH ||
mode == Mode::COSH) {
checker.set_rng(0, &small_rng_f32);
} else if (mode == Mode::FAST_TANH) {
checker.set_rng(0, &tanh_rng_f32);
......@@ -807,6 +809,10 @@ DEF_TEST(all_modes) {
checker.set_rng(1, &default_rng_f32);
} else if (mode == Mode::ERFCINV) {
checker.set_rng(0, &uniform_0_2_rng);
} else if (mode == Mode::ACOSH_GRAD || mode == Mode::ACOSH) {
checker.set_rng(0, &lt1_rng_f32);
} else if (mode == Mode::ATANH_GRAD || mode == Mode::ATANH) {
checker.set_rng(0, &abslt1_rng_f32);
} else if (
mode == Mode::MOD || mode == Mode::TRUE_DIV ||
mode == Mode::FLOOR_DIV) {
......
......@@ -467,12 +467,12 @@ def log1p(x):
def sqrt(x: Tensor) -> Tensor:
r"""Element-wise `sqrt`."""
return x ** 0.5
return _elwise(x, mode=Elemwise.Mode.SQRT)
def square(x: Tensor) -> Tensor:
r"""Element-wise `square`."""
return x ** 2
return _elwise(x, mode=Elemwise.Mode.SQUARE)
def round(x):
......@@ -515,7 +515,7 @@ def sin(x):
def tan(x):
r"""Element-wise `tangent`."""
return sin(x) / cos(x)
return _elwise(x, mode=Elemwise.Mode.TAN)
def acos(x):
......@@ -544,13 +544,12 @@ def atan2(y, x):
def cosh(x):
r"""Element-wise `hyperbolic cosine`."""
return 0.5 * (exp(x) + exp(-x))
return _elwise(x, mode=Elemwise.Mode.COSH)
def sinh(x):
r"""Element-wise `hyperbolic sine`."""
u = expm1(x)
return 0.5 * u / (u + 1) * (u + 2)
return _elwise(x, mode=Elemwise.Mode.SINH)
def tanh(x):
......@@ -560,17 +559,17 @@ def tanh(x):
def asinh(x):
r"""Element-wise `inverse hyperbolic sine`."""
return log(x + (x ** 2 + 1) ** 0.5)
return _elwise(x, mode=Elemwise.Mode.ASINH)
def acosh(x):
r"""Element-wise `inverse hyperbolic cosine`."""
return log(x + (x ** 2 - 1) ** 0.5)
return _elwise(x, mode=Elemwise.Mode.ACOSH)
def atanh(x):
r"""Element-wise `inverse hyperbolic tangent`."""
return log1p(2 * x / (1 - x)) / 2
return _elwise(x, mode=Elemwise.Mode.ATANH)
# bit-twiddling functions
......@@ -680,7 +679,7 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
return minimum(maximum(x, lower), upper)
return _elwise(x, lower, upper, mode=Elemwise.Mode.CLIP)
else:
return maximum(x, lower)
else:
......
......@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import Const, apply
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.tensor.array_method import _matmul
from ..core.tensor.array_method import _elwise, _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
......@@ -86,7 +86,7 @@ def sign(inp: Tensor):
>>> F.sign(x)
Tensor([ 1 -1 0], dtype=int32, device=xpux:0)
"""
return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
return _elwise(inp, mode=builtin.Elemwise.Mode.SIGN)
def sum(
......
......@@ -753,37 +753,9 @@ def sigmoid(x):
return _elwise(x, mode=Elemwise.Mode.SIGMOID)
@lru_cache(maxsize=None)
def _get_hsigmoid_op(dtype=None, device=None):
@subgraph_fn(
"Hsigmoid",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def hsigmoid(inputs, f, c):
(inp,) = inputs[0:1]
inp = f("+", inp, c(3))
max_0 = f("max", inp, c(0))
min_6 = f("min", max_0, c(6))
oup = f("/", min_6, c(6))
(oup_grad,) = yield (oup,)
inp_grad = f("/", oup_grad, c(6))
inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
yield (inp_grad,)
return hsigmoid
def hsigmoid(x):
r"""Element-wise `relu6(x + 3) / 6`."""
hsigmoid = _get_hsigmoid_op(x.dtype, x.device)
(x,) = hsigmoid(x)
return x
# return relu6(x + 3) / 6
return _elwise(x, mode=Elemwise.Mode.HSIGMOID)
def relu(x):
......@@ -791,95 +763,14 @@ def relu(x):
return _elwise(x, mode=Elemwise.Mode.RELU)
@lru_cache(maxsize=None)
def _get_relu6_op(dtype=None, device=None):
@subgraph_fn(
"ReLU6",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def relu6(inputs, f, c):
(inp,) = inputs[0:1]
max_0 = f("max", inp, c(0))
min_6 = f("min", max_0, c(6))
oup = min_6
(oup_grad,) = yield (oup,)
inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
yield (inp_grad,)
return relu6
def relu6(x):
r"""Element-wise `min(max(x, 0), 6)`."""
relu6 = _get_relu6_op(x.dtype, x.device)
(x,) = relu6(x)
return x
@lru_cache(maxsize=None)
def _get_prelu_op(dtype=None, device=None):
@subgraph_fn(
"PReLU",
dtype=dtype,
device=device,
nr_inputs=2,
jit_fusion=True,
custom_grad=True,
)
def prelu(inputs, f, c):
(inp, weight) = inputs[0:2]
max_0 = f("max", inp, c(0))
min_0 = f("min", inp, c(0))
oup = f("fma3", min_0, weight, max_0)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, weight)
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
weight_grad = f("*", oup_grad, min_0)
yield (inp_grad, weight_grad)
return prelu
def prelu(inp: Tensor, weight: Tensor) -> Tensor:
r"""Element-wise PReLU function.
Refer to :class:`~.PReLU` for more information.
"""
prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
(oup,) = prelu(inp, broadcast_to(weight, inp.shape))
return oup
return _elwise(x, mode=Elemwise.Mode.RELU6)
@lru_cache(maxsize=None)
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
@subgraph_fn(
"LeakyReLU",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def leakyReLU(inputs, f, c):
(inp,) = inputs[0:1]
max_0 = f("max", inp, c(0))
min_0 = f("min", inp, c(0))
oup = f("+", max_0, f("*", min_0, c(negative_slope)))
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, c(negative_slope))
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return leakyReLU
def prelu(x, y):
r"""Element-wise `max(x, 0) + y * min(x, 0)`."""
return _elwise(x, y, mode=Elemwise.Mode.PRELU)
def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
......@@ -887,9 +778,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
Refer to :class:`~.LeakyReLU` for more information.
"""
leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
(oup,) = leakyReLU(inp)
return oup
return _elwise(inp, negative_slope, mode=Elemwise.Mode.PRELU)
def silu(x):
......@@ -908,36 +797,6 @@ def gelu(x):
return _elwise(x, mode=Elemwise.Mode.GELU)
@lru_cache(maxsize=None)
def _get_softplus_op(dtype=None, device=None):
@subgraph_fn(
"Softplus",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def softplus(inputs, f, c):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup0 = f("log1p", exp)
oup1 = f("relu", inp)
oup = f("+", oup0, oup1)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return softplus
def softplus(inp: Tensor) -> Tensor:
r"""Applies the element-wise function:
......@@ -960,9 +819,7 @@ def softplus(inp: Tensor) -> Tensor:
>>> y.numpy().round(decimals=4)
array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32)
"""
softplus = _get_softplus_op(inp.dtype, inp.device)
(oup,) = softplus(inp)
return oup
return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS)
def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
......@@ -991,39 +848,6 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return inp - logsumexp(inp, axis, keepdims=True)
@lru_cache(maxsize=None)
def _get_logsigmoid_op(dtype=None, device=None):
@subgraph_fn(
"LogSigmoid",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def logsigmoid(inputs, f, c):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup0 = f("log1p", exp)
oup1 = f("relu", f("-", inp))
oup = f("+", oup0, oup1)
oup = f("-", oup)
(oup_grad,) = yield (oup,)
oup_grad = f("-", oup_grad)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_0 = f("-", inp_grad_0)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return logsigmoid
def logsigmoid(inp: Tensor) -> Tensor:
r"""Applies the element-wise function:
......@@ -1041,9 +865,7 @@ def logsigmoid(inp: Tensor) -> Tensor:
array([-5.0067, -4.0182, -3.0486, -2.1269, -1.3133, -0.6931, -0.3133,
-0.1269, -0.0486, -0.0181], dtype=float32)
"""
logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
(oup,) = logsigmoid(inp)
return oup
return _elwise(inp, mode=Elemwise.Mode.LOGSIGMOID)
def logsumexp(
......
......@@ -122,6 +122,11 @@ ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
Elemwise::Mode::ACOS, Elemwise::Mode::ASIN,
Elemwise::Mode::ATAN2, Elemwise::Mode::COS,
Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP,
Elemwise::Mode::TAN, Elemwise::Mode::ASINH,
Elemwise::Mode::ACOSH, Elemwise::Mode::ATANH,
Elemwise::Mode::SINH, Elemwise::Mode::COSH,
Elemwise::Mode::SOFTPLUS, Elemwise::Mode::HSIGMOID,
Elemwise::Mode::LOGSIGMOID, Elemwise::Mode::SQRT,
};
static std::unordered_set<Elemwise::Mode> cast_case2 = {
......
......@@ -133,7 +133,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) /
6.f),
};
mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
return map;
......
......@@ -543,6 +543,34 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
RET(EL2(SILU_GRAD, i0, og));
case Mode::GELU:
RET(EL2(GELU_GRAD, i0, og));
case Mode::SINH:
RET(EL1(COSH, i0) * og);
case Mode::COSH:
RET(EL1(SINH, i0) * og);
case Mode::ASINH:
RET(EL2(ASINH_GRAD, i0, og));
case Mode::ACOSH:
RET(EL2(ACOSH_GRAD, i0, og));
case Mode::ATANH:
RET(EL2(ATANH_GRAD, i0, og));
case Mode::TAN: {
auto two = i0.make_scalar_dt(2);
RET(og / (EL2(POW, EL1(COS, i0), two)));
}
case Mode::RELU6:
RET(EL2(RELU6_GRAD, i0, og));
case Mode::SOFTPLUS:
RET(EL2(SOFTPLUS_GRAD, i0, og));
case Mode::HSIGMOID:
RET(EL2(HSIGMOID_GRAD, i0, og));
case Mode::LOGSIGMOID:
RET(EL2(SOFTPLUS_GRAD, EL1(NEGATE, i0), og));
case Mode::SQRT:
RET(og / EL1(SQRT, i0) / 2);
case Mode::SQUARE:
RET(og * 2 * i0);
case Mode::SIGN:
RET(i0.make_scalar_dt(0).broadcast(i0.symshape()));
// binary
case Mode::ABS_GRAD:
......@@ -617,6 +645,11 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
case Mode::XOR:
case Mode::AND:
return nullptr;
case Mode::PRELU:
if (wrt_idx == 0) {
RET(EL3(PRELU_GRAD, i0, og, i1));
}
RET(EL2(SWITCH_GT0, -i0, og * i0));
// ternary
case Mode::COND_LEQ_MOV:
......@@ -627,6 +660,15 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
if (wrt_idx <= 1)
return nullptr;
RET(EL3(COND_LT_MOV, i0, i1, og));
case Mode::CLIP:
if (wrt_idx == 0) {
RET(EL3(COND_LEQ_MOV, i1, i0, EL3(COND_LEQ_MOV, i0, i2, og)));
}
if (wrt_idx == 1) {
RET(EL3(COND_LEQ_MOV, i0, i1, og));
}
RET(EL3(COND_LEQ_MOV, i2, i0, og));
// fuse oprs
case Mode::FUSE_MUL_ADD3:
if (wrt_idx < 2) {
......
......@@ -349,6 +349,99 @@ struct CheckerConfig<H_SWISH> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<H_SWISH_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<TAN> : public NoGradCheckerConfig {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-1.2, 1.2);
}
};
template <>
struct CheckerConfig<SINH> : public CheckerConfig<void> {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-5, 5);
}
template <class Opt>
static void update_opt(Opt& opt) {
opt.numdiff_eps = 1e-2;
opt.numdiff_max_err = 0.1;
}
};
template <>
struct CheckerConfig<COSH> : public CheckerConfig<SINH> {};
template <>
struct CheckerConfig<ASINH> : public CheckerConfig<void> {
template <class Opt>
static void update_opt(Opt& opt) {
opt.numdiff_eps = 1e-2;
opt.numdiff_max_err = 0.1;
}
};
template <>
struct CheckerConfig<ACOSH> : public CheckerConfig<ASINH> {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(1.05, 5);
}
};
template <>
struct CheckerConfig<ATANH> : public CheckerConfig<ASINH> {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-0.95, 0.95);
}
};
template <>
struct CheckerConfig<SOFTPLUS> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<LOGSIGMOID> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<SQUARE> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<SQRT> : public CheckerConfig<void> {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(0.05, 5);
}
template <class Opt>
static void update_opt(Opt& opt) {
opt.numdiff_eps = 1e-2;
opt.numdiff_max_err = 0.1;
}
};
template <>
struct CheckerConfig<RELU6> : public CheckerConfig<void> {
template <typename ctype, class Checker>
static void do_update_checker(Checker& checker) {
auto icoord = [](const typename Checker::NumInpArray& inp) {
auto p0 = inp[0]->template ptr<ctype>();
for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
if (std::abs(p0[i]) < 1) {
p0[i] += 2;
} else if (std::abs(p0[i] - 6) < 1) {
p0[i] += 2;
}
}
};
checker.set_input_coordinator(icoord);
}
template <class Checker>
static void update_checker(Checker& checker) {
using ctype = typename Checker::ctype;
return do_update_checker<ctype>(checker);
}
};
template <>
struct CheckerConfig<HSIGMOID> : public CheckerConfig<void> {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-2.95, 2.95);
}
};
template <>
struct CheckerConfig<SIGN> : public NoZeroCheckerConfig<0> {};
/* ======================= binary config ======================= */
template <bool for_mod>
struct BinaryInputMinGap : public CheckerConfig<void> {
......@@ -567,13 +660,85 @@ template <>
struct CheckerConfig<SILU_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<PRELU> : public NoZeroCheckerConfig<0> {};
template <>
struct CheckerConfig<ASINH_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<ACOSH_GRAD> : public NoGradCheckerConfig {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(1.05, 5);
}
};
template <>
struct CheckerConfig<ATANH_GRAD> : public NoGradCheckerConfig {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-0.95, 0.95);
}
};
template <>
struct CheckerConfig<RELU6_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<SOFTPLUS_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<HSIGMOID_GRAD> : public NoGradCheckerConfig {
template <typename ctype>
static InputGenerator get_inp_gen(size_t) {
return get_inp_gen_f32_range<ctype>(-2.95, 2.95);
}
};
/* ======================= ternary config ======================= */
template <>
struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {};
template <>
struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {};
struct CheckerConfig<PRELU_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<CLIP> : public CheckerConfig<void> {
template <typename ctype, class Checker>
static void do_update_checker(Checker& checker) {
auto icoord = [](const typename Checker::NumInpArray& inp) {
auto p0 = inp[0]->template ptr<ctype>(), p1 = inp[1]->template ptr<ctype>(),
p2 = inp[2]->template ptr<ctype>();
for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
if (p1[i] > p2[i]) {
std::swap(p1[i], p2[i]);
}
if (p1[i] + 1 > p2[i]) {
p2[i] = p1[i] + 1;
}
if (std::abs(p1[i] - p0[i]) < 1) {
if (p1[i] < p0[i])
p0[i] += 1;
else
p0[i] -= 1;
}
if (std::abs(p2[i] - p0[i]) < 1) {
if (p2[i] < p0[i])
p0[i] += 1;
else
p0[i] -= 1;
}
}
};
checker.set_input_coordinator(icoord);
}
template <class Checker>
static void update_checker(Checker& checker) {
using ctype = typename Checker::ctype;
return do_update_checker<ctype>(checker);
}
template <class Opt>
static void update_opt(Opt& opt) {
opt.numdiff_eps = 1e-3;
opt.numdiff_max_err = 0.1;
}
};
/* ======================= test runner ======================= */
namespace detail {
template <typename dtype, class Trait>
......
......@@ -41,6 +41,7 @@ DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0)
DEF_TRAIT(TANH_GRAD, (1 - x * x) * y)
DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0))
DEF_TRAIT(PRELU, (x > 0) ? x : (x* y))
#undef _ALLOW_INT
#define _ALLOW_INT false
......@@ -57,6 +58,12 @@ DEF_TRAIT(
SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / (1 + std::exp(-x)) /
(1 + std::exp(-x)))
DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))
DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1))
DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1))
DEF_TRAIT(ATANH_GRAD, y / (1 - x * x))
DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x)))
DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y))
DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f)))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
......@@ -15,6 +15,10 @@
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0)
DEF_TRAIT(COND_LT_MOV, x < y ? z : 0)
DEF_TRAIT(FUSE_MUL_ADD3, x* y + z)
DEF_TRAIT(CLIP, x < y ? y : (x < z ? x : z))
#undef _ALLOW_INT
#define _ALLOW_INT false
DEF_TRAIT(PRELU_GRAD, x > 0 ? y : (y * z))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
......@@ -22,6 +22,9 @@ DEF_TRAIT(NOT, !x)
DEF_TRAIT(ABS, std::abs(x))
DEF_TRAIT(NEGATE, -x)
DEF_TRAIT(RELU, std::max<ctype>(x, 0))
DEF_TRAIT(RELU6, std::min<ctype>(std::max<ctype>(x, 0), 6))
DEF_TRAIT(SQUARE, x* x)
DEF_TRAIT(SIGN, x < 0 ? -1 : (x > 0 ? 1 : 0))
#undef _ALLOW_INT
#define _ALLOW_INT false
......@@ -46,6 +49,16 @@ DEF_TRAIT(ERFCINV, do_erfcinv(x))
DEF_TRAIT(H_SWISH, do_h_swish(x))
DEF_TRAIT(SILU, x / (1 + std::exp(-x)))
DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f)))))
DEF_TRAIT(SINH, std::sinh(x))
DEF_TRAIT(COSH, std::cosh(x))
DEF_TRAIT(ASINH, std::asinh(x))
DEF_TRAIT(ACOSH, std::acosh(x))
DEF_TRAIT(ATANH, std::atanh(x))
DEF_TRAIT(TAN, std::tan(x))
DEF_TRAIT(SOFTPLUS, std::log1p(std::exp(-std::abs(x))) + std::max<ctype>(x, 0))
DEF_TRAIT(HSIGMOID, x <= -3.f ? 0.f : (x >= 3.f ? 1.f : ((x + 3.f) / 6.f)))
DEF_TRAIT(SQRT, std::sqrt(x))
DEF_TRAIT(LOGSIGMOID, -std::log1p(std::exp(-std::abs(x))) - std::max<ctype>(-x, 0))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册