kern_defs.cuh 8.1 KB
Newer Older
1 2
#pragma once

M
Megvii Engine Team 已提交
3
#include "src/common/elemwise/erfinv.h"
4
#include "src/common/elemwise_helper.cuh"
M
Megvii Engine Team 已提交
5
#include "src/common/opr_param_defs_enumv.cuh"
6 7 8 9 10 11 12 13 14 15 16 17
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>

#if MEGDNN_CC_HOST
#include <algorithm>
using std::max;
using std::min;
18 19

//! host build only: express rsqrtf (reciprocal square root) through sqrt
#define rsqrtf(x) (1.f / sqrt(x))
20 21 22 23
#endif

#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
M
Megvii Engine Team 已提交
24
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL         1
25 26 27 28 29 30 31 32 33 34
#endif

#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
/*!
 * \brief compute log(exp(x) + exp(y)) without overflowing the exponential
 *
 * The larger operand is factored out, so the exponential only ever sees a
 * non-positive argument: max + log1p(exp(min - max)).
 *
 * Operands that compare equal are handled up front: for two infinities of
 * the same sign the difference min - max is NaN, which would otherwise turn
 * the result into NaN instead of that infinity.
 *
 * \param x first operand (log domain)
 * \param y second operand (log domain)
 * \return the log of the sum of exponentials, converted to T
 */
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;
    b = x < y ? y : x;
    if (a == b) {
        // exp(a - b) is exactly 1 for equal finite operands, so this is the
        // same value the general formula yields there; it only changes the
        // outcome for equal infinities
        return T(b + log1pf(1.f));
    }
    return T(b + log1pf(exp(a - b)));
}

//! rational approximation of tanh: x * (27 + x^2) / (27 + 9 * x^2)
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! fused add + h_swish: z * clamp(z + 3, 0, 6) / 6 with z = x + y
//!
//! The result is multiplied by (1.f / 6.f) instead of being divided by 6.f:
//! nvcc is not invoked with --use_fast_math, so a real division would not get
//! the fast --prec_div treatment and would cause a performance drop on the
//! Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! gradient of fast_tanh at x, scaled by the incoming gradient dx
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float base = 3.f + sq;
    const float numer = (-48.f * sq) / base + 27.f + sq;
    const float denom = base * 9.f;
    return numer / denom * dx;
}

//! grad of silu: dy * s * (1 + x * (1 - s)), where s = sigmoid(x)
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.f / (1.f + expf(-x));
    const float scale = 1.f + x * (1.f - sig);
    return dy * sig * scale;
}

//! cumulative distribution function of the standard normal distribution
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    //! Phi(x) = erfc(-x / sqrt(2)) / 2; erfc keeps precision in the left
    //! tail, where 1 + erf(x / sqrt(2)) would cancel down to exactly zero
    return 0.5f * erfcf(-x / sqrtf(2.f));
#else
    //! use cuda built-in math
    return ::normcdff(x);
#endif
}
77

M
Megvii Engine Team 已提交
78 79 80 81 82 83 84 85
//! grad of gelu: dy * (Phi(x) + x * phi(x)), where Phi comes from normcdf()
//! and phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float density = inv_sqrt_2pi * expf(-0.5f * x * x);
    const float cdf = normcdf(x);
    return dy * (cdf + x * density);
}
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
//! whether two floats differ by less than 1e-6 in absolute value
__device__ __host__ inline bool feq(float a, float b) {
    // 1e-6f is the largest float below the double value 1e-6, so "<= 1e-6f"
    // accepts exactly the floats that "< 1e-6" accepts, while keeping the
    // comparison in single precision instead of promoting it to double
    return fabsf(a - b) <= 1e-6f;
}

/*!
 * \brief powf with closed-form shortcuts for a few exponents
 *
 * The exponent is matched with feq() against 2, 0.5, -0.5, 0, 1, 3, -1 and
 * -2, in that order; the first match is evaluated directly and every other
 * exponent is forwarded to powf().
 */
__device__ __host__ inline float dispatch_powf(float x, float y) {
    if (feq(y, 2.f)) {
        return x * x;
    }
    if (feq(y, 0.5f)) {
        return sqrtf(x);
    }
    if (feq(y, -0.5f)) {
        return rsqrtf(x);
    }
    if (feq(y, 0.f)) {
        return 1.f;
    }
    if (feq(y, 1.f)) {
        return x;
    }
    if (feq(y, 3.f)) {
        return x * x * x;
    }
    if (feq(y, -1.f)) {
        return 1.f / x;
    }
    if (feq(y, -2.f)) {
        return 1.f / (x * x);
    }
    return powf(x, y);
}

111 112 113 114 115 116 117 118 119
//! integer division rounded towards negative infinity
__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
    const int truncated = x / y;
    // the truncated quotient already is the floor unless the operands have
    // opposite signs and the division leaves a remainder
    const bool signs_differ = (x ^ y) < 0;
    if (signs_differ && x % y != 0) {
        return truncated - 1;
    }
    return truncated;
}

120 121
#include "src/common/elemwise/each_mode.inl"

M
Megvii Engine Team 已提交
122 123
//! elementwise kernel selected by platform, mode and data type; only declared
//! here, the DEF_KERN macro supplies a partial specialization for each
//! (mode, dtype) pair
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
struct ElemwiseKern;
124 125

//! define kernel for a single ctype: partially specializes ElemwiseKern for
//! (_mode, _ctype) on every platform; apply() takes the operands declared by
//! the KERN_SIG macro in effect where DEF_KERN is used and returns _imp
//! converted to ctype
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! define kernel for all float types: dt_float32 always, dt_float16 and
//! dt_bfloat16 through DNN_INC_FLOAT16
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
138 139

//! define kernel for all int types: dt_int32, dt_int16, dt_int8 and dt_uint8
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);
145 146 147

//! define kernel for all ctypes: the int types of DEF_KERN_INT followed by
//! the float types of DEF_KERN_FLOAT
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);
150

M
Megvii Engine Team 已提交
151
/* ================== unary kernels ================== */
//! every kernel in this section takes a single operand x
#define KERN_SIG ctype x

// int and float
DEF_KERN_ALL(NEGATE, -x);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
// HIP (non-NVCC) build: the float kernels compare against a 0.f literal
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x * normcdf(x));

// bool only
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

M
Megvii Engine Team 已提交
194
/* ================== binary kernels ================== */
//! every kernel in this section takes two operands x and y
#define KERN_SIG ctype x, ctype y

// int and float
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
// HIP (non-NVCC) build: the float kernels compare against a 0.f literal
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
#else
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
#endif
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x * y);
DEF_KERN(dt_bool, AND, x && y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
DEF_KERN_ALL(SIGMOID_GRAD, x * (ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
#else
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
#endif
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

DEF_KERN_INT(MOD, x % y);
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif

// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));

DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));

DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
// zero for x < -3, y for x > 3, and (2x + 3) / 6 * y in between
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));

DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
#undef KERN_SIG

M
Megvii Engine Team 已提交
263
/* ================== ternary kernels ================== */
//! every kernel in this section takes three operands x, y and z
#define KERN_SIG ctype x, ctype y, ctype z

// int and float
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z);

#undef KERN_SIG

// drop the kernel-definition helper macros so that none of them leaks into
// the files including this header
#undef DEF_KERN_ALL
#undef DEF_KERN_INT
#undef DEF_KERN_FLOAT
#undef DEF_KERN

M
Megvii Engine Team 已提交
275
}  // namespace megdnn
276 277 278 279 280 281 282 283

#if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE)
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#undef __host__
#undef __device__
#endif

// vim: ft=cpp syntax=cpp.doxygen