// elemwise_unary_trait_def.inl
//
// X-macro table of unary elementwise modes. The including file must define
// DEF_TRAIT(mode, expr) before including this file; DEF_TRAIT is invoked once
// per mode below with the mode name and an expression in terms of `x`.
// Expressions are deliberately left exactly as written, since DEF_TRAIT may
// stringize them (NOTE(review): confirm against the includers).
#ifndef DEF_TRAIT
#error "DEF_TRAIT must be defined"
#endif

/* ======================= unary ======================= */
M
Megvii Engine Team 已提交
6 7
#define _CUR_ARITY     1
#define _EXPAND_PARAMS ctype x = inp[0][idx]
8

M
Megvii Engine Team 已提交
9
#define _ALLOW_BOOL  true
M
Megvii Engine Team 已提交
10
#define _ALLOW_FLOAT false
M
Megvii Engine Team 已提交
11
#define _ALLOW_INT   false
M
Megvii Engine Team 已提交
12 13 14 15 16 17
DEF_TRAIT(NOT, !x)
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL

#define _ALLOW_BOOL false
18 19 20 21 22 23 24

#define _ALLOW_FLOAT true

#define _ALLOW_INT true
DEF_TRAIT(ABS, std::abs(x))
DEF_TRAIT(NEGATE, -x)
DEF_TRAIT(RELU, std::max<ctype>(x, 0))
25 26 27
DEF_TRAIT(RELU6, std::min<ctype>(std::max<ctype>(x, 0), 6))
DEF_TRAIT(SQUARE, x* x)
DEF_TRAIT(SIGN, x < 0 ? -1 : (x > 0 ? 1 : 0))
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
#undef _ALLOW_INT

// Float-only modes: ints are disabled here, and bool is still disabled and
// float still enabled from the preceding group.
// do_fast_tanh, do_erfinv, do_erfcinv and do_h_swish are not defined in this
// file; they must be visible wherever DEF_TRAIT expands these expressions.
#define _ALLOW_INT false
DEF_TRAIT(ACOS, std::acos(x))
DEF_TRAIT(ASIN, std::asin(x))
DEF_TRAIT(CEIL, std::ceil(x))
DEF_TRAIT(COS, std::cos(x))
DEF_TRAIT(EXP, std::exp(x))
DEF_TRAIT(EXPM1, std::expm1(x))
DEF_TRAIT(FLOOR, std::floor(x))
DEF_TRAIT(LOG, std::log(x))
DEF_TRAIT(LOG1P, std::log1p(x))
DEF_TRAIT(SIGMOID, 1 / (1 + std::exp(-x)))
DEF_TRAIT(SIN, std::sin(x))
DEF_TRAIT(TANH, std::tanh(x))
DEF_TRAIT(FAST_TANH, do_fast_tanh(x))
DEF_TRAIT(ROUND, std::round(x))
DEF_TRAIT(ERF, std::erf(x))
DEF_TRAIT(ERFINV, do_erfinv(x))
DEF_TRAIT(ERFC, std::erfc(x))
DEF_TRAIT(ERFCINV, do_erfcinv(x))
DEF_TRAIT(H_SWISH, do_h_swish(x))
// SILU is x * sigmoid(x), written as a single division.
DEF_TRAIT(SILU, x / (1 + std::exp(-x)))
// Exact (erf-based) GELU: x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
DEF_TRAIT(GELU, x*(0.5f * (1.f + std::erf(x / std::sqrt(2.f)))))
DEF_TRAIT(SINH, std::sinh(x))
DEF_TRAIT(COSH, std::cosh(x))
DEF_TRAIT(ASINH, std::asinh(x))
DEF_TRAIT(ACOSH, std::acosh(x))
DEF_TRAIT(ATANH, std::atanh(x))
DEF_TRAIT(TAN, std::tan(x))
// log(1 + e^x) rewritten as max(x, 0) + log1p(e^-|x|), so exp never overflows.
DEF_TRAIT(SOFTPLUS, std::log1p(std::exp(-std::abs(x))) + std::max<ctype>(x, 0))
// Piecewise-linear sigmoid: clamp((x + 3) / 6, 0, 1).
DEF_TRAIT(HSIGMOID, x <= -3.f ? 0.f : (x >= 3.f ? 1.f : ((x + 3.f) / 6.f)))
DEF_TRAIT(SQRT, std::sqrt(x))
// log(sigmoid(x)) = -softplus(-x), using the same overflow-free rewrite.
DEF_TRAIT(LOGSIGMOID, -std::log1p(std::exp(-std::abs(x))) - std::max<ctype>(-x, 0))
#undef _ALLOW_INT

#undef _ALLOW_FLOAT

#undef _ALLOW_BOOL

// Drop the per-section helpers so they do not leak into the including file.
#undef _CUR_ARITY
#undef _EXPAND_PARAMS

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}