kern_defs.cuh 12.1 KB
Newer Older
1 2
#pragma once

M
Megvii Engine Team 已提交
3
#include "src/common/elemwise/erfinv.h"
4
#include "src/common/elemwise_helper.cuh"
M
Megvii Engine Team 已提交
5
#include "src/common/opr_param_defs_enumv.cuh"
6 7 8 9 10 11 12
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>
13
#include "math.h"
14 15 16 17 18

#if MEGDNN_CC_HOST
//! Host build (MEGDNN_CC_HOST presumably marks compilation by a plain host
//! compiler -- confirm): make the unqualified max/min used by the kernels below
//! resolve to std::max/std::min, and provide a stand-in for the CUDA math
//! function rsqrtf.
#include <algorithm>
using std::max;
using std::min;

#define rsqrtf(x) (1.f / sqrt(x))
#endif

//! Default mode filter: unless the includer already defined
//! MEGDNN_ELEMWISE_MODE_ENABLE, the callback is invoked unconditionally for
//! every mode.
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL         1
#endif

//! On host builds where __host__ is not already known, define the CUDA
//! qualifiers as empty; MEGDNN_HOST_DEVICE_SELF_DEFINE records that this header
//! did so, so they can be undefined again at the end of the header.
#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

M
Megvii Engine Team 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
//! Numerically stable log(exp(x) + exp(y)).
//!
//! The larger argument b is factored out, so only exp(a - b) <= 1 is ever
//! evaluated. When both arguments compare equal the result is b + log(2);
//! that case is returned directly so x == y == +/-inf yields +/-inf instead
//! of the NaN that inf - inf would produce. NaN inputs still propagate.
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;  // smaller argument
    b = x < y ? y : x;  // larger argument
    if (a == b) {
        //! log(exp(b) + exp(b)) = b + log1p(1); avoids inf - inf
        return T(b + log1pf(1.f));
    }
    return T(b + log1pf(exp(a - b)));
}

//! Rational approximation of tanh: x * (27 + x^2) / (27 + 9 x^2).
//! It equals +/-1 at x = +/-3. NOTE(review): there is no clamping, so for
//! larger |x| the value keeps growing (roughly like x / 9) instead of
//! saturating.
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! Fused add + hard-swish: with s = x + y, returns s * clamp(s + 3, 0, 6) / 6.
//!
//! The final scaling multiplies by (1.f / 6.f) rather than dividing by 6.f:
//! nvcc is not given --use_fast_math here, so a real division would be
//! compiled as a precise (--prec-div) division, which was observed to cost
//! performance on the Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! Gradient of fast_tanh at x, multiplied by the upstream gradient dx.
//! Algebraically this is (9 - x^2)^2 / (9 * (3 + x^2)^2), evaluated below in
//! the expanded form ((-48 x^2) / (3 + x^2) + 27 + x^2) / (9 * (3 + x^2)).
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float shifted = 3.f + sq;
    const float numer = (-48.f * sq) / shifted + 27.f + sq;
    return numer / (shifted * 9.f) * dx;
}

//! grad of silu(x) = x * sigmoid(x), multiplied by the upstream gradient dy:
//!     dy * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.f / (1.f + expf(-x));
    return dy * sig * (1.f + x * (1.f - sig));
}

//! Standard normal cumulative distribution function Phi(x).
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    //! Phi(x) = 0.5 * erfc(-x / sqrt(2)). This is algebraically equal to
    //! 0.5 * (1 + erf(x / sqrt(2))), but the erf form cancels catastrophically
    //! in the lower tail (it is exactly 0 in float for x below about -5.5),
    //! while erfc keeps relative precision there.
    return 0.5f * erfcf(-x / sqrtf(2.f));
#else
    //! use the CUDA built-in math function
    return ::normcdff(x);
#endif
}
78

M
Megvii Engine Team 已提交
79 80 81 82 83 84 85 86
//! grad of gelu(x) = x * Phi(x), multiplied by the upstream gradient dy:
//!     dy * (Phi(x) + x * phi(x))
//! where phi is the standard normal density and Phi is normcdf.
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi), the normalizer of the standard normal density
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    const float cdf = normcdf(x);
    return dy * (cdf + x * pdf);
}
87

88 89 90 91 92 93 94 95 96
//! grad of softplus(x) = log1p(exp(-|x|)) + relu(x), multiplied by dy.
__device__ __host__ inline float softplus_grad(float x, float dy) {
    const float e = expf(-fabsf(x));
    //! d/d|x| of log1p(exp(-|x|)), scaled by dy
    const float logg = -dy * e / (1.f + e);
    const bool positive = x > 0.f;
    //! chain rule through |x|: sign +1 for x > 0, -1 otherwise (including 0)
    const float grad_log = positive ? logg : -logg;
    //! relu passes dy only where x > 0
    const float grad_relu = positive ? dy : 0.f;
    return grad_log + grad_relu;
}

97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
//! Approximate float equality: true when |a - b| is below an absolute
//! tolerance of 1e-6 (the comparison itself is done in double).
__device__ __host__ inline bool feq(float a, float b) {
    const float diff = a - b;
    return fabsf(diff) < 1e-6;
}

//! powf with fast paths for a few common exponents.
//!
//! y is compared approximately (feq) against each special exponent in the
//! order below, and the first match wins; any other exponent falls back to
//! powf. NOTE(review): the fast paths can differ from powf on edge inputs,
//! e.g. sqrtf(-0.f) is -0.f whereas powf(-0.f, 0.5f) is +0.f.
__device__ __host__ inline float dispatch_powf(float x, float y) {
    if (feq(y, 2.f)) {
        return x * x;
    }
    if (feq(y, 0.5f)) {
        return sqrtf(x);
    }
    if (feq(y, -0.5f)) {
        return rsqrtf(x);
    }
    if (feq(y, 0.f)) {
        return 1.f;
    }
    if (feq(y, 1.f)) {
        return x;
    }
    if (feq(y, 3.f)) {
        return x * x * x;
    }
    if (feq(y, -1.f)) {
        return 1.f / x;
    }
    if (feq(y, -2.f)) {
        return 1.f / (x * x);
    }
    return powf(x, y);
}

121 122 123 124 125 126 127 128 129
//! Integer division rounding toward negative infinity (Python `//`).
//!
//! Built-in `/` truncates toward zero. When the operands' sign bits differ
//! ((x ^ y) < 0) and the division is inexact, the truncated quotient is one
//! too large and is stepped down. y == 0 and INT_MIN / -1 remain undefined,
//! exactly as with the built-in operator.
__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
    const int q = x / y;
    const bool opposite_signs = (x ^ y) < 0;
    if (opposite_signs && x % y != 0) {
        return q - 1;
    }
    return q;
}

130 131
#include "src/common/elemwise/each_mode.inl"

M
Megvii Engine Team 已提交
132 133
//! Primary template, declared only: a (platform, mode, dtype) combination has
//! an apply() only if one of the DEF_KERN specializations below provides it.
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
struct ElemwiseKern;

//! define kernel for a single ctype: specializes ElemwiseKern for mode _mode
//! and input/output type _ctype. apply() takes KERN_SIG, which each section
//! below #defines before use, and converts the expression _imp to ctype.
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! define kernel for all float types; the float16/bfloat16 variants are
//! wrapped in DNN_INC_FLOAT16 (presumably compiled only when half-precision
//! support is enabled -- defined elsewhere)
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);

//! define kernel for all int and float ctypes
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);
160

M
Megvii Engine Team 已提交
161
/* ================== unary kernels ================== */

#define KERN_SIG ctype x

// int and float
DEF_KERN_ALL(NEGATE, -x);
DEF_KERN_ALL(SQUARE, x* x);
//! HIP (HCC) builds define the float variants with float-literal comparisons
//! instead of ctype(...) constants. NOTE(review): presumably a workaround for
//! the half-precision types on that toolchain -- confirm before unifying.
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
//! relu6(x) = min(max(x, 0), 6): inputs <= 0 map to 0, inputs above 6 map
//! to 6, matching the generic RELU6 in the #else branch
DEF_KERN_FLOAT(RELU6, x <= 0.f ? ctype(0) : (x <= 6.f ? x : ctype(6)));
DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f));
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
//! hard-swish: x * clamp(x + 3, 0, 6) / 6
DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x* normcdf(x));
DEF_KERN_FLOAT(SINH, sinhf(x));
DEF_KERN_FLOAT(COSH, coshf(x));
DEF_KERN_FLOAT(ASINH, asinhf(x));
DEF_KERN_FLOAT(ACOSH, acoshf(x));
DEF_KERN_FLOAT(ATANH, atanhf(x));
DEF_KERN_FLOAT(TAN, tanf(x));
//! softplus(x) = log(1 + e^x), written as log1p(e^-|x|) + max(x, 0) so the
//! exponential argument is never positive
DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x));
//! hard-sigmoid: clamp((x + 3) / 6, 0, 1)
DEF_KERN_FLOAT(
        HSIGMOID,
        x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f)));
DEF_KERN_FLOAT(SQRT, sqrtf(x));
//! log(sigmoid(x)) = -log1p(e^-|x|) + min(x, 0)
DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x));

// bool only
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

M
Megvii Engine Team 已提交
223
/* ================== binary kernels ================== */

#define KERN_SIG ctype x, ctype y

// int and float
//! ABS_GRAD: y * sign(x), with x == 0 taking the negative branch
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
#else
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
#endif
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x* y);
DEF_KERN(dt_bool, AND, x&& y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
//! s * (1 - s) * dy: the sigmoid derivative written in terms of the forward
//! output s (passed as x)
DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
//! SWITCH_GT0: pass y where x > 0, otherwise 0
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
#else
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
#endif
//! (1 - t^2) * dy: the tanh derivative written in terms of the forward
//! output t (passed as x)
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

//! Python-style integer modulo: the result takes the sign of the divisor.
//! The truncated remainder is shifted by y only when it is nonzero and its
//! sign differs from y's. This avoids the `y + x % y` intermediate, which
//! overflows int32 for large operands (e.g. x = 1500000000, y = 2000000000).
//! y == 0 is undefined, as with the built-in operator.
__device__ __host__ inline int dispatch_mod_int(int x, int y) {
    const int rem = x % y;
    return (rem != 0 && ((rem ^ y) < 0)) ? rem + y : rem;
}

DEF_KERN_INT(MOD, dispatch_mod_int(x, y));  // consistent with python modulo
//! float MOD uses C fmodf: the result has the sign of x (the dividend),
//! unlike the Python-style int MOD above
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y));
DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y));
#else
DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y));
#endif

// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));

DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));

DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
//! hard-swish gradient: 0 for x < -3, y for x > 3, (2x + 3) / 6 * y between
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));

DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f));
DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f));
DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x));
DEF_KERN_FLOAT(SOFTPLUS_GRAD, softplus_grad(x, y));
DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y));
DEF_KERN_FLOAT(
        HSIGMOID_GRAD,
        x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f)));

#undef KERN_SIG

M
Megvii Engine Team 已提交
306
/* ================== ternary kernels ================== */

#define KERN_SIG ctype x, ctype y, ctype z

// int and float
//! z where x <= y, else 0
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
//! z where x < y, else 0
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
//! clamp x into [y, z] (y = lower bound, z = upper bound)
DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));
//! y where x >= 0, otherwise y * z (z presumably the PReLU slope -- confirm)
DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z));

#undef KERN_SIG

#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL

/* ================== bool kernels ================== */
//! Kernels whose output dtype may differ from the input ctype (stype).
//! Declared only: each (platform, mode, input type, output type) combination
//! gets an apply() through the DEF_KERN specializations below, all of which
//! use dt_bool as the output type.
template <megcorePlatform_t plat, uint32_t mode, typename stype, typename dtype>
struct ElemwiseBoolKern;

//! define kernel for one (input _ctype, output _dtype) pair; apply() takes
//! KERN_SIG and converts the expression _imp to _dtype
#define DEF_KERN(_ctype, _dtype, _mode, _imp)                                      \
    template <megcorePlatform_t plat>                                              \
    struct ElemwiseBoolKern<                                                       \
            plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> {            \
        typedef _ctype ctype;                                                      \
        static __host__ __device__ _dtype apply(KERN_SIG) { return _dtype(_imp); } \
    }

//! define kernel for all float input types (bool output)
#define DEF_KERN_FLOAT(_mode, _imp)                              \
    DEF_KERN(dt_float32, dt_bool, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);)

//! define kernel for all int input types (bool output)
#define DEF_KERN_INT(_mode, _imp)             \
    DEF_KERN(dt_int32, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int16, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int8, dt_bool, _mode, _imp);  \
    DEF_KERN(dt_uint8, dt_bool, _mode, _imp);

//! define kernel for all int, float and dt_bool input types (bool output)
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);  \
    DEF_KERN(dt_bool, dt_bool, _mode, _imp);
//! unary float predicates; the operand is converted to float first
#define KERN_SIG ctype x
DEF_KERN_FLOAT(ISNAN, isnan(float(x)));
DEF_KERN_FLOAT(ISINF, isinf(float(x)));
#undef KERN_SIG
//! binary comparisons producing dt_bool
#define KERN_SIG ctype x, ctype y
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN_ALL(NEQ, x != y);
#undef KERN_SIG

//! NOTE(review): DEF_KERN_AD is not defined anywhere in this header as shown;
//! this #undef looks like a leftover/typo and is a no-op
#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL
371

M
Megvii Engine Team 已提交
372
}  // namespace megdnn
373 374 375 376 377 378 379 380

//! Remove the empty __host__/__device__ stand-ins if this header defined them
//! for a host build, so they do not leak into files that include it.
#if defined(MEGDNN_HOST_DEVICE_SELF_DEFINE) && MEGDNN_CC_HOST
#undef __device__
#undef __host__
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#endif

// vim: ft=cpp syntax=cpp.doxygen