elemwise_binary_trait_def.inl 2.4 KB
Newer Older
1 2 3 4 5 6
// Include-time contract: whoever includes this file must define DEF_TRAIT
// first; otherwise preprocessing stops here with the error below. Every entry
// in this file has the form DEF_TRAIT(<name>, <expression in x and y>).
#ifndef DEF_TRAIT
#error "DEF_TRAIT must be defined"
#endif

/* ======================= binary ======================= */
// Arity shared by every entry declared in this file (2, i.e. binary).
#define _CUR_ARITY 2
M
Megvii Engine Team 已提交
7
// Binds the two operands to the local names x and y that the DEF_TRAIT
// expressions below refer to. `ctype`, `inp` and `idx` are not defined in this
// file; they must be in scope wherever _EXPAND_PARAMS is expanded. The
// replacement text has no trailing semicolon, so the expansion site supplies it.
#define _EXPAND_PARAMS     \
8 9 10
    ctype x = inp[0][idx]; \
    ctype y = inp[1][idx]

M
Megvii Engine Team 已提交
11
// Group 1: logical operators. Only _ALLOW_BOOL is true here.
// The _ALLOW_* flags are set before each run of DEF_TRAIT entries and are
// presumably read by the includer's DEF_TRAIT definition (not visible in this
// file) to decide which element types an entry accepts — confirm there.
#define _ALLOW_BOOL  true
M
Megvii Engine Team 已提交
12
#define _ALLOW_FLOAT false
M
Megvii Engine Team 已提交
13 14
#define _ALLOW_INT   false
DEF_TRAIT(AND, x&& y)
M
Megvii Engine Team 已提交
15 16 17 18
DEF_TRAIT(OR, x || y)
DEF_TRAIT(XOR, x ^ y)
// Drop the int/float flags so the next group can set them again;
// _ALLOW_BOOL is deliberately left defined (still true).
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
19

M
Megvii Engine Team 已提交
20
// Group 2: comparisons. _ALLOW_BOOL is still true from the previous group, so
// all three flags are true for these entries.
#define _ALLOW_INT   true
21 22 23 24 25
#define _ALLOW_FLOAT true
DEF_TRAIT(EQ, x == y)
DEF_TRAIT(LEQ, x <= y)
DEF_TRAIT(LT, x < y)

M
Megvii Engine Team 已提交
26 27
// End of the entries that have _ALLOW_BOOL set to true.
#undef _ALLOW_BOOL

M
Megvii Engine Team 已提交
28
// Group 3: arithmetic entries with _ALLOW_INT and _ALLOW_FLOAT true and
// _ALLOW_BOOL false.
#define _ALLOW_BOOL  false
29
// _ALLOW_FLOAT and _ALLOW_INT are still defined as true from the comparison
// group (only _ALLOW_BOOL was #undef'd above). The two defines below repeat the
// same replacement text, which the preprocessor accepts as a redefinition.
#define _ALLOW_FLOAT true
M
Megvii Engine Team 已提交
30
#define _ALLOW_INT   true
31 32
// Yields y when x > 0 and -y otherwise (including x == 0).
DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y)
DEF_TRAIT(ADD, x + y)
33
// do_floor_div and do_mod (below) are helpers defined outside this file.
DEF_TRAIT(FLOOR_DIV, do_floor_div(x, y))
34 35 36
DEF_TRAIT(MAX, std::max(x, y))
DEF_TRAIT(MIN, std::min(x, y))
DEF_TRAIT(MOD, do_mod(x, y))
M
Megvii Engine Team 已提交
37 38
DEF_TRAIT(MUL, x* y)
DEF_TRAIT(SIGMOID_GRAD, x*(1 - x) * y)
39 40 41 42 43
DEF_TRAIT(SUB, x - y)
// Passes y through where x > 0, otherwise yields 0.
DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0)
DEF_TRAIT(TANH_GRAD, (1 - x * x) * y)

// max(x + y, 0), with both arguments forced to ctype via the explicit
// template argument.
DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0))
44
// x for x > 0, otherwise x scaled by y.
DEF_TRAIT(PRELU, (x > 0) ? x : (x* y))
45 46 47 48 49
// _ALLOW_INT is dropped here; _ALLOW_FLOAT (true) and _ALLOW_BOOL (false)
// stay defined for whatever follows.
#undef _ALLOW_INT

// Group 4: entries with _ALLOW_INT false. _ALLOW_FLOAT is still true and
// _ALLOW_BOOL still false from the preceding group.
// The do_* functions used below are helpers defined outside this file.
#define _ALLOW_INT false
DEF_TRAIT(POW, std::pow(x, y))
DEF_TRAIT(TRUE_DIV, x / y)
M
Megvii Engine Team 已提交
50
// Division that yields 0 instead of dividing when y == 0.
DEF_TRAIT(SAFE_DIV, y != 0 ? x / y : 0)
51 52 53 54 55 56 57
DEF_TRAIT(LOG_SUM_EXP, do_log_sum_exp(x, y))
DEF_TRAIT(FUSE_ADD_SIGMOID, 1 / (1 + std::exp(-(x + y))))
DEF_TRAIT(FUSE_ADD_TANH, std::tanh(x + y))
DEF_TRAIT(FUSE_ADD_H_SWISH, do_fuse_add_h_swish(x, y))
DEF_TRAIT(FAST_TANH_GRAD, do_fast_tanh_grad(x, y))
DEF_TRAIT(ATAN2, std::atan2(x, y))
DEF_TRAIT(H_SWISH_GRAD, do_h_swish_grad(x, y))
M
Megvii Engine Team 已提交
58 59 60
// y * (1 + e + x * e) / (1 + e)^2 with e = std::exp(-x); as written the
// expression spells out std::exp(-x) four times.
DEF_TRAIT(
        SILU_GRAD, y*(1 + std::exp(-x) + x * std::exp(-x)) / (1 + std::exp(-x)) /
                           (1 + std::exp(-x)))
61
DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))
62 63 64 65 66 67
DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1))
DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1))
DEF_TRAIT(ATANH_GRAD, y / (1 - x * x))
// NOTE(review): if std::exp(x) overflows to infinity for large x, this becomes
// inf / inf. Confirm whether this expression is what is evaluated at runtime;
// if so, y / (1 + std::exp(-x)) is the algebraically equal form to consider.
DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x)))
// y for 0 < x < 6, otherwise 0.
DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y))
// y / 6 for -3 < x < 3, otherwise 0.
DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f)))
68

69 70 71 72
// Close the group: both numeric flags are removed; _ALLOW_BOOL remains defined.
#undef _ALLOW_INT
#undef _ALLOW_FLOAT

// Group 5: entries with _ALLOW_INT true and _ALLOW_FLOAT false. _ALLOW_BOOL is
// still false from the earlier define. do_shl, do_shr and
// do_round_mulh_saturate are helpers defined outside this file.
#define _ALLOW_FLOAT false
M
Megvii Engine Team 已提交
73
#define _ALLOW_INT   true
74 75 76 77 78
DEF_TRAIT(SHL, do_shl(x, y))
DEF_TRAIT(SHR, do_shr(x, y))
DEF_TRAIT(RMULH, do_round_mulh_saturate(x, y))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
M
Megvii Engine Team 已提交
79
// Final cleanup: remove the remaining helper macros this file defined so they
// do not leak into the includer. DEF_TRAIT itself is not undefined here; it
// stays under the includer's control.
#undef _ALLOW_BOOL
80 81 82 83 84

#undef _CUR_ARITY
#undef _EXPAND_PARAMS

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}