// kern_defs.cuh
#pragma once

#include "src/common/elemwise/erfinv.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>
#include "math.h"

#if MEGDNN_CC_HOST
//! host builds: make unqualified max()/min() resolve to the std versions so the
//! kernel expressions in this header compile unchanged on the host
#include <algorithm>
using std::max;
using std::min;

//! host stand-in for the CUDA rsqrtf() intrinsic; sqrtf keeps the computation
//! in single precision instead of going through the double sqrt overload
#define rsqrtf(x) (1.f / sqrtf(x))
#endif

#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
//! default when the includer did not restrict modes: every mode is enabled and
//! the callback is invoked unconditionally
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1
#endif

#if MEGDNN_CC_HOST && !defined(__host__)
//! plain host compilers do not know the CUDA qualifiers: define them away and
//! record that this header did so, so that they are undefined again at the end
//! of the header
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

//! Numerically stable log(exp(x) + exp(y)).
//!
//! Evaluates max(x, y) + log1p(exp(min(x, y) - max(x, y))). The exponent
//! argument is never positive, so exp() cannot overflow.
//!
//! When both inputs are the same infinity, min - max would be inf - inf = NaN.
//! The difference is therefore forced to 0 whenever the operands compare equal.
//! This makes log_sum_exp(+inf, +inf) == +inf and
//! log_sum_exp(-inf, -inf) == -inf. For equal finite operands a - b is already
//! 0, so their result is unchanged. NaN inputs still propagate to NaN.
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;  // smaller operand
    b = x < y ? y : x;  // larger operand
    T diff = a == b ? T(0) : T(a - b);
    return T(b + log1pf(exp(diff)));
}

//! Cheap rational approximation of tanh: x * (27 + x^2) / (27 + 9 x^2).
//! It is exact at 0 and hits +-1 at x = +-3, but it is not clamped. For large
//! |x| it grows like x / 9 instead of saturating.
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! Fused add + hard-swish: with z = x + y, returns z * clamp(z + 3, 0, 6) / 6.
//! The division by 6 is written as a multiplication by (1.f / 6.f). nvcc is not
//! given --use_fast_math for this code, so a real division would keep the
//! precise-division (--prec-div) path, which causes a performance drop on the
//! Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! Derivative of fast_tanh at x, multiplied by the incoming gradient dx.
//! With d = 3 + x^2 it evaluates ((27 + x^2) - 48 x^2 / d) / (9 d) * dx.
//! Algebraically this is (x^2 - 9)^2 / (9 (3 + x^2)^2) * dx.
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float d = 3.f + sq;
    const float numer = (-48.f * sq) / d + 27.f + sq;
    return numer / (d * 9.f) * dx;
}

//! Gradient of SiLU(x) = x * sigmoid(x), multiplied by dy.
//! With s = sigmoid(x): dy * s * (1 + x * (1 - s)).
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float s = 1.f / (1.f + expf(-x));
    return dy * s * (1.f + x * (1.f - s));
}

//! Standard normal cumulative distribution function, Phi(x).
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    //! Phi(x) = 0.5 * erfc(-x / sqrt(2)), which equals 0.5 * (1 + erf(x / sqrt(2))).
    //! erfc avoids the cancellation in 1 + erf(...). That form loses relative
    //! accuracy in the lower tail and rounds to exactly 0 once x drops below
    //! about -5.5, even though Phi(x) is still a representable float there.
    return 0.5f * erfcf(-x / sqrtf(2.f));
#else
    //! use cuda built-in math
    return ::normcdff(x);
#endif
}

//! grad of gelu: d/dx [x * Phi(x)] = Phi(x) + x * phi(x), multiplied by dy.
//! phi is the standard normal density and Phi the CDF computed by normcdf().
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float cdf = normcdf(x);
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    return dy * (cdf + x * pdf);
}

//! Approximate float equality: true when |a - b| < eps.
//! eps defaults to 1e-6f. It is a float so the comparison stays in single
//! precision; the previous bare 1e-6 literal was a double and promoted the
//! comparison to double, which is slow on consumer GPUs.
__device__ __host__ inline bool feq(float a, float b, float eps = 1e-6f) {
    return fabsf(a - b) < eps;
}

//! powf(x, y) with shortcuts for common exponents.
//! If y matches (per feq) one of 2, 0.5, -0.5, 0, 1, 3, -1 or -2, a cheaper
//! closed form is returned instead of calling powf. Because feq is a tolerance
//! test, exponents that are merely very close to one of these values are
//! snapped to it. Every other y falls through to powf(x, y).
__device__ __host__ inline float dispatch_powf(float x, float y) {
    if (feq(y, 2.f))
        return x * x;
    if (feq(y, 0.5f))
        return sqrtf(x);
    if (feq(y, -0.5f))
        return rsqrtf(x);
    if (feq(y, 0.f))
        return 1.f;
    if (feq(y, 1.f))
        return x;
    if (feq(y, 3.f))
        return x * x * x;
    if (feq(y, -1.f))
        return 1.f / x;
    if (feq(y, -2.f))
        return 1.f / (x * x);
    return powf(x, y);
}

//! Integer division rounded toward negative infinity (floor division).
//! C++ '/' truncates toward zero. That already equals the floor when both
//! operands have the same sign or the division is exact. Otherwise the
//! truncated quotient is one too large and is decremented.
//! As with the built-in operator, y == 0 is undefined behavior.
__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
    const int q = x / y;
    const bool same_sign = (x ^ y) >= 0;
    if (same_sign || x % y == 0) {
        return q;
    }
    return q - 1;
}

#include "src/common/elemwise/each_mode.inl"

//! Elementwise kernel functor for one (platform, mode, ctype) triple.
//! Only the specializations produced by the DEF_KERN* macros below exist. Each
//! one exposes `ctype` and a static `apply(KERN_SIG)`.
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
struct ElemwiseKern;

//! define kernel for a single ctype
//! (KERN_SIG must be #defined at the point of use; it is the apply() parameter list)
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! define kernel for all float types
//! (the float16 / bfloat16 variants go through DNN_INC_FLOAT16, defined elsewhere)
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);

//! define kernel for all int and float ctypes (dt_bool is not included)
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);

/* ================== unary kernels ================== */
#define KERN_SIG ctype x

// int and float
DEF_KERN_ALL(NEGATE, -x);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x * normcdf(x));

// bool only
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

/* ================== binary kernels ================== */
#define KERN_SIG ctype x, ctype y

// int and float
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
#else
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
#endif
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x * y);
DEF_KERN(dt_bool, AND, x && y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
// x * (1 - x) is the sigmoid derivative expressed through the sigmoid output x
DEF_KERN_ALL(SIGMOID_GRAD, x * (ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
#else
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
#endif
// 1 - x * x is the tanh derivative expressed through the tanh output x
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

// FLOOR_DIV rounds the quotient toward negative infinity
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

// MOD uses C/C++ truncated semantics: a non-zero result takes the sign of x.
// For operands of mixed sign it is therefore not the remainder that pairs
// with FLOOR_DIV.
DEF_KERN_INT(MOD, x % y);
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif

// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));

DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));

DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
// derivative of x * clamp(x + 3, 0, 6) / 6 evaluated at x, times the gradient y
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));

DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));

#undef KERN_SIG

/* ================== ternary kernels ================== */
#define KERN_SIG ctype x, ctype y, ctype z

// int and float
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z);

#undef KERN_SIG

// release the ElemwiseKern helper macros; the bool kernels define their own
#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL

/* ================== bool kernels ================== */
//! Elementwise kernel functor whose result type (dtype) differs from its input
//! type (stype). Every specialization defined below returns dt_bool.
template <megcorePlatform_t plat, uint32_t mode, typename stype, typename dtype>
struct ElemwiseBoolKern;

//! define kernel taking _ctype inputs and returning _dtype
#define DEF_KERN(_ctype, _dtype, _mode, _imp)                                      \
    template <megcorePlatform_t plat>                                              \
    struct ElemwiseBoolKern<                                                       \
            plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> {            \
        typedef _ctype ctype;                                                      \
        static __host__ __device__ _dtype apply(KERN_SIG) { return _dtype(_imp); } \
    }

//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp)                              \
    DEF_KERN(dt_float32, dt_bool, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp)             \
    DEF_KERN(dt_int32, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int16, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int8, dt_bool, _mode, _imp);  \
    DEF_KERN(dt_uint8, dt_bool, _mode, _imp);

//! define kernel for all ctypes, including dt_bool inputs
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);  \
    DEF_KERN(dt_bool, dt_bool, _mode, _imp);

// unary predicates; the input is widened to float before the test
#define KERN_SIG ctype x
DEF_KERN_FLOAT(ISNAN, isnan(float(x)));
DEF_KERN_FLOAT(ISINF, isinf(float(x)));
#undef KERN_SIG

// binary comparisons
#define KERN_SIG ctype x, ctype y
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN_ALL(NEQ, x != y);
#undef KERN_SIG

#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL

}  // namespace megdnn
//! Undo the empty __host__ / __device__ stand-ins that this header defines for
//! plain host compilers, so they do not leak into files that include it.
#if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE)
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#undef __host__
#undef __device__
#endif

// vim: ft=cpp syntax=cpp.doxygen