/*
 * X-macro table of elementwise modes with three and four inputs.
 *
 * The includer must define DEF_TRAIT(mode, expr). This file invokes it once
 * per mode, passing the mode name and the mode's scalar expression.
 *
 * Every helper macro defined in this file (_CUR_ARITY, _EXPAND_PARAMS,
 * _ALLOW_BOOL, _ALLOW_FLOAT, _ALLOW_INT) is #undef'd again before the end of
 * the file, and the file has no include guard.
 * NOTE(review): this looks intended for repeated inclusion with different
 * DEF_TRAIT definitions; confirm.
 *
 * DEF_TRAIT itself is not #undef'd here; presumably the includer does that.
 */
#ifndef DEF_TRAIT
#error "DEF_TRAIT must be defined"
#endif

/* ======================= ternary ======================= */
#define _CUR_ARITY 3
M
Megvii Engine Team 已提交
7
#define _EXPAND_PARAMS     \
8 9 10 11
    ctype x = inp[0][idx]; \
    ctype y = inp[1][idx]; \
    ctype z = inp[2][idx]

M
Megvii Engine Team 已提交
12
#define _ALLOW_BOOL  false
13
#define _ALLOW_FLOAT true
M
Megvii Engine Team 已提交
14
#define _ALLOW_INT   true
15
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0)
16
DEF_TRAIT(COND_LT_MOV, x < y ? z : 0)
M
Megvii Engine Team 已提交
17
DEF_TRAIT(FUSE_MUL_ADD3, x* y + z)
18 19 20 21
DEF_TRAIT(CLIP, x < y ? y : (x < z ? x : z))
#undef _ALLOW_INT
#define _ALLOW_INT false
DEF_TRAIT(PRELU_GRAD, x > 0 ? y : (y * z))
22 23 24 25 26 27 28 29
#undef _ALLOW_INT
#undef _ALLOW_FLOAT

#undef _CUR_ARITY
#undef _EXPAND_PARAMS

/* ======================= quaternary ======================= */
#define _CUR_ARITY 4
M
Megvii Engine Team 已提交
30
#define _EXPAND_PARAMS      \
31 32 33 34 35 36
    ctype i0 = inp[0][idx]; \
    ctype i1 = inp[1][idx]; \
    ctype i2 = inp[2][idx]; \
    ctype i3 = inp[3][idx]

#define _ALLOW_FLOAT true
M
Megvii Engine Team 已提交
37 38
#define _ALLOW_INT   true
DEF_TRAIT(FUSE_MUL_ADD4, i0* i1 + i2 * i3)
39 40 41 42 43
#undef _ALLOW_INT
#undef _ALLOW_FLOAT

#undef _CUR_ARITY
#undef _EXPAND_PARAMS
/* _ALLOW_BOOL is shared by both sections above, so it is released last. */
#undef _ALLOW_BOOL

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}