kern_defs.cuh 11.8 KB
Newer Older
1 2
#pragma once

M
Megvii Engine Team 已提交
3
#include "src/common/elemwise/erfinv.h"
4
#include "src/common/elemwise_helper.cuh"
M
Megvii Engine Team 已提交
5
#include "src/common/opr_param_defs_enumv.cuh"
6 7 8 9 10 11 12
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>
13
#include "math.h"
14 15 16 17 18

#if MEGDNN_CC_HOST
#include <algorithm>
using std::max;
using std::min;
19 20

#define rsqrtf(x) (1.f / sqrt(x))
21 22 23 24
#endif

#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
M
Megvii Engine Team 已提交
25
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL         1
26 27 28 29 30 31 32 33 34 35
#endif

//! When MEGDNN_CC_HOST is set and __host__ is not already a macro, define
//! __host__ / __device__ as empty so the qualified declarations in this header
//! still parse. MEGDNN_HOST_DEVICE_SELF_DEFINE marks that this header
//! introduced them; it is checked again at the end of the file.
#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

M
Megvii Engine Team 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
/*!
 * \brief log(exp(x) + exp(y)) with the larger operand factored out.
 *
 * Computed as hi + log1p(exp(lo - hi)), where lo / hi are the smaller and
 * larger of the two inputs.
 */
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    const bool x_is_smaller = x < y;
    const T lo = x_is_smaller ? x : y;
    const T hi = x_is_smaller ? y : x;
    return T(hi + log1pf(exp(lo - hi)));
}

//! Rational expression x * (27 + x^2) / (27 + 9 * x^2), used by the FAST_TANH
//! kernel in place of tanhf.
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! Fused add + hard-swish: with z = x + y, returns z * clamp(z + 3, 0, 6) / 6.
//! The division by 6 is written as a multiplication by (1.f / 6.f): per the
//! original authors, nvcc is not invoked with --use_fast_math, so a real
//! division would not get the fast-division treatment and causes a
//! performance drop on the Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! Derivative of the fast_tanh expression at x, multiplied by dx:
//! ((-48 x^2) / (3 + x^2) + 27 + x^2) / (9 * (3 + x^2)) * dx.
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float shifted = 3.f + sq;
    const float numer = (-48.f * sq) / shifted + 27.f + sq;
    const float denom = shifted * 9.f;
    return numer / denom * dx;
}

//! Gradient of SiLU: dy * s * (1 + x * (1 - s)), with s = sigmoid(x).
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.0f / (1.0f + expf(-x));
    return dy * sig * (1.0f + x * (1.0f - sig));
}

//! Standard normal cumulative distribution function.
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    //! host path: 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5f * (1.f + erff(x / sqrtf(2.f)));
#else
    //! use the CUDA built-in math function
    return ::normcdff(x);
#endif
}
78

M
Megvii Engine Team 已提交
79 80 81 82 83 84 85 86
//! Gradient of GELU: dy * (normcdf(x) + x * pdf(x)), where pdf is the
//! standard normal density exp(-x^2 / 2) / sqrt(2 * pi).
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float cdf = normcdf(x);
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    return dy * (cdf + x * pdf);
}
87

88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
//! Approximate float equality: true when |a - b| is below 1e-6.
__device__ __host__ inline bool feq(float a, float b) {
    const float gap = fabsf(a - b);
    return gap < 1e-6;
}

//! powf(x, y) with shortcuts for common exponents. The exponent is matched
//! with feq() in the order listed below; the first match wins, and powf is
//! the fallback when nothing matches.
__device__ __host__ inline float dispatch_powf(float x, float y) {
    if (feq(y, 2.f)) {
        return x * x;
    }
    if (feq(y, 0.5f)) {
        return sqrtf(x);
    }
    if (feq(y, -0.5f)) {
        return rsqrtf(x);
    }
    if (feq(y, 0.f)) {
        return 1.f;
    }
    if (feq(y, 1.f)) {
        return x;
    }
    if (feq(y, 3.f)) {
        return x * x * x;
    }
    if (feq(y, -1.f)) {
        return 1.f / x;
    }
    if (feq(y, -2.f)) {
        return 1.f / (x * x);
    }
    return powf(x, y);
}

112 113 114 115 116 117 118 119 120
//! Integer division rounding toward negative infinity. C++ division truncates
//! toward zero, so when the operands have opposite signs (sign bit of x ^ y
//! set) and the division is inexact, the quotient is lowered by one.
__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
    const int truncated = x / y;
    if ((x ^ y) >= 0) {
        return truncated;
    }
    return (x % y) ? truncated - 1 : truncated;
}

121 122
#include "src/common/elemwise/each_mode.inl"

M
Megvii Engine Team 已提交
123 124
//! Element-wise kernel selected by platform, mode and element type. Only
//! declared here; the DEF_KERN macro in this file provides the partial
//! specializations, each exposing a static apply().
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
struct ElemwiseKern;
125 126

//! Define the ElemwiseKern specialization for a single ctype and mode:
//! apply() takes the arguments named by the current KERN_SIG and returns
//! _imp converted to ctype.
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! Define the kernel for all float types: dt_float32 unconditionally, and
//! dt_float16 / dt_bfloat16 through DNN_INC_FLOAT16.
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
139 140

//! Define the kernel for all int types: dt_int32, dt_int16, dt_int8 and
//! dt_uint8.
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);
146 147 148

//! Define the kernel for all int and float ctypes (dt_bool is not included).
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);
151

M
Megvii Engine Team 已提交
152
/* ================== unary kernels ================== */
153 154
#define KERN_SIG ctype x

M
Megvii Engine Team 已提交
155 156
// int and float
//! NEGATE: arithmetic negation of x
DEF_KERN_ALL(NEGATE, -x);
157
//! SQUARE: x multiplied by itself
DEF_KERN_ALL(SQUARE, x* x);
158
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
M
Megvii Engine Team 已提交
159
//! RELU for integer ctypes: 0 when x <= 0, otherwise x
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
160 161
//! RELU6 for integer ctypes: x clamped to [0, 6]
DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
//! SIGN for integer ctypes: -1, 0 or 1 according to the sign of x
DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
M
Megvii Engine Team 已提交
162
//! RELU for float ctypes: 0 when x <= 0, otherwise x (compares against the
//! float literal 0.f rather than ctype(0))
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
163 164
//! RELU6 for float ctypes: x clamped to [0, 6]. The lower-bound test compares
//! against 0.f; comparing against 6.f there returned 0 for every x <= 6.
DEF_KERN_FLOAT(RELU6, x <= 0.f ? ctype(0) : (x <= 6.f ? x : ctype(6)));
//! SIGN for float ctypes: -1, 0 or 1 according to the sign of x
DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f));
165
#else
M
Megvii Engine Team 已提交
166
//! RELU: 0 when x <= 0, otherwise x
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
167 168
//! RELU6: x clamped to [0, 6]
DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6)));
//! SIGN: -1, 0 or 1 according to the sign of x
DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0)));
169
#endif
M
Megvii Engine Team 已提交
170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
//! ABS on integer ctypes: x is converted to int before abs() is applied
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
//! most entries below forward directly to the single-precision math function
//! named in the expression
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
//! logistic sigmoid: 1 / (1 + exp(-x))
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
//! rational expression implemented by fast_tanh()
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
//! hard swish: x * clamp(x + 3, 0, 6) / 6, the division written as a
//! multiplication by (1.f / 6.f)
DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
//! SiLU: x * sigmoid(x), written as x / (1 + exp(-x))
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
//! GELU: x times the standard normal CDF from normcdf()
DEF_KERN_FLOAT(GELU, x* normcdf(x));
196 197 198 199 200 201 202 203 204 205 206 207
DEF_KERN_FLOAT(SINH, sinhf(x));
DEF_KERN_FLOAT(COSH, coshf(x));
DEF_KERN_FLOAT(ASINH, asinhf(x));
DEF_KERN_FLOAT(ACOSH, acoshf(x));
DEF_KERN_FLOAT(ATANH, atanhf(x));
DEF_KERN_FLOAT(TAN, tanf(x));
//! softplus log(1 + exp(x)), written as log1p(exp(-|x|)) + max(x, 0) so that
//! expf only receives a non-positive argument
DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x));
//! hard sigmoid: 0 for x <= -3, 1 for x >= 3, (x + 3) / 6 in between
DEF_KERN_FLOAT(
        HSIGMOID,
        x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f)));
DEF_KERN_FLOAT(SQRT, sqrtf(x));
//! log(sigmoid(x)), written as min(x, 0) - log1p(exp(-|x|))
DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x));
M
Megvii Engine Team 已提交
208 209 210

// bool only
//! logical NOT, implemented as x ^ 1
DEF_KERN(dt_bool, NOT, x ^ 1);
211 212 213

#undef KERN_SIG

M
Megvii Engine Team 已提交
214
/* ================== binary kernels ================== */
215 216
#define KERN_SIG ctype x, ctype y

M
Megvii Engine Team 已提交
217
// int and float
218
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
M
Megvii Engine Team 已提交
219 220
//! ABS_GRAD: y where x > 0, -y elsewhere (including x == 0); the float
//! variant compares against the literal 0.f instead of ctype(0)
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
221
#else
M
Megvii Engine Team 已提交
222
//! ABS_GRAD: y where x > 0, -y elsewhere (including x == 0)
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
223
#endif
M
Megvii Engine Team 已提交
224 225 226 227 228 229 230 231 232 233
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x* y);
//! logical operators are defined for dt_bool only
DEF_KERN(dt_bool, AND, x&& y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
//! integer only; the arithmetic is delegated to round_mulh_saturate()
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
//! x * (1 - x) * y
//! NOTE(review): presumably x is the sigmoid output and y the incoming
//! gradient — confirm against the operator definition
DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
234
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
M
Megvii Engine Team 已提交
235 236
//! SWITCH_GT0: y where x > 0, otherwise 0; the float variant compares against
//! the literal 0.f instead of ctype(0)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
237
#else
M
Megvii Engine Team 已提交
238
//! SWITCH_GT0: y where x > 0, otherwise 0
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
239
#endif
M
Megvii Engine Team 已提交
240 241 242 243 244 245 246 247
//! (1 - x * x) * y
//! NOTE(review): presumably x is the tanh output and y the incoming
//! gradient — confirm against the operator definition
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
//! comparisons; DEF_KERN converts the boolean result back to ctype
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
//! DEF_KERN_ALL here covers int and float ctypes only, so dt_bool is added
//! explicitly
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

248
//! integer floor division, delegated to dispatch_floordiv_int()
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y));
M
Megvii Engine Team 已提交
249
//! float floor division: floorf applied to the quotient x / y
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));
M
Megvii Engine Team 已提交
250 251
//! SAFE_DIV: x / y, or 0 when y is zero
DEF_KERN_INT(SAFE_DIV, y != 0 ? x / y : 0);
DEF_KERN_FLOAT(SAFE_DIV, y != 0.f ? x / y : 0.f);
M
Megvii Engine Team 已提交
252 253 254 255 256 257

//! NOTE(review): the integer form has no check for y == 0 — confirm that
//! callers guarantee a non-zero divisor
DEF_KERN_INT(MOD, x % y);
DEF_KERN_FLOAT(MOD, fmodf(x, y));

//! bit shifts, defined for integer ctypes only
DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
258
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
M
Megvii Engine Team 已提交
259 260
//! FUSE_ADD_RELU: x + y when the sum is positive, otherwise 0; the float
//! variant compares against the literal 0.f instead of ctype(0)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
261
#else
M
Megvii Engine Team 已提交
262
//! FUSE_ADD_RELU: x + y when the sum is positive, otherwise 0
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
263
#endif
264 265 266 267 268 269
//! PRELU: x where x > 0, otherwise x * y. On the HIP (non-NVCC) platform the
//! int and float variants are defined separately, the float one comparing
//! against the literal 0.f.
//! NOTE(review): y is presumably the negative-side slope — confirm
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y));
DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y));
#else
DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y));
#endif
270

M
Megvii Engine Team 已提交
271 272 273 274 275
// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
//! NOTE(review): calls powf directly; dispatch_powf() defined earlier in this
//! file is not used here
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));
276

M
Megvii Engine Team 已提交
277 278
//! fused add + activation: tanh(x + y) and sigmoid(x + y)
DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));
279

M
Megvii Engine Team 已提交
280 281 282 283 284
DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
//! H_SWISH_GRAD: 0 for x < -3, y for x > 3, (2 * x + 3) / 6 * y in between
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));
285

M
Megvii Engine Team 已提交
286 287 288
//! these three forward to the helper functions of the same name defined
//! earlier in this file
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
289 290 291 292 293 294 295 296
//! gradients of the inverse hyperbolic functions, each scaled by y:
//! 1 / sqrt(x^2 + 1), 1 / sqrt(x^2 - 1) and 1 / (1 - x^2)
DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f));
DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f));
DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x));
//! SOFTPLUS_GRAD: y * sigmoid(x). Written as y / (1 + exp(-x)) rather than
//! y * exp(x) / (1 + exp(x)), which evaluates inf / inf (NaN) once expf(x)
//! overflows for large positive x; the form matches the SIGMOID kernel.
DEF_KERN_FLOAT(SOFTPLUS_GRAD, y / (1.f + expf(-x)));
//! RELU6_GRAD: y for 0 < x < 6, otherwise 0
DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y));
//! HSIGMOID_GRAD: y / 6 for -3 < x < 3, otherwise 0
DEF_KERN_FLOAT(
        HSIGMOID_GRAD,
        x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f)));
297 298
#undef KERN_SIG

M
Megvii Engine Team 已提交
299
/* ================== ternary kernels ================== */
300 301
#define KERN_SIG ctype x, ctype y, ctype z

M
Megvii Engine Team 已提交
302 303
// int and float
//! COND_LEQ_MOV: z when x <= y, otherwise 0
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
304
//! COND_LT_MOV: z when x < y, otherwise 0
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
M
Megvii Engine Team 已提交
305
//! FUSE_MUL_ADD3: x * y + z
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
306 307
//! CLIP: x limited to the range [y, z] (y is the lower bound, z the upper)
DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));
//! PRELU_GRAD: y where x >= 0, otherwise y * z
DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z));
308 309 310 311 312

#undef KERN_SIG

#undef DEF_KERN_AD
#undef DEF_KERN
313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL

/* ================== bool kernels ================== */
//! Element-wise kernel parameterized by platform, mode, an input element type
//! (stype) and an output element type (dtype). Only declared here; the
//! DEF_KERN macro that follows provides the partial specializations.
template <megcorePlatform_t plat, uint32_t mode, typename stype, typename dtype>
struct ElemwiseBoolKern;

//! Specialize ElemwiseBoolKern for one (mode, input ctype, output dtype)
//! combination; apply() takes the arguments named by the current KERN_SIG and
//! returns _imp converted to _dtype.
#define DEF_KERN(_ctype, _dtype, _mode, _imp)                           \
    template <megcorePlatform_t plat>                                   \
    struct ElemwiseBoolKern<                                            \
            plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> { \
        using ctype = _ctype;                                           \
        static __host__ __device__ _dtype apply(KERN_SIG) {             \
            return _dtype(_imp);                                        \
        }                                                               \
    }

//! Define the dt_bool-returning kernel for all float input types: dt_float32
//! unconditionally, and dt_float16 / dt_bfloat16 through DNN_INC_FLOAT16.
#define DEF_KERN_FLOAT(_mode, _imp)                              \
    DEF_KERN(dt_float32, dt_bool, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);)

//! Define the dt_bool-returning kernel for all int input types: dt_int32,
//! dt_int16, dt_int8 and dt_uint8.
#define DEF_KERN_INT(_mode, _imp)             \
    DEF_KERN(dt_int32, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int16, dt_bool, _mode, _imp); \
    DEF_KERN(dt_int8, dt_bool, _mode, _imp);  \
    DEF_KERN(dt_uint8, dt_bool, _mode, _imp);

//! Define the dt_bool-returning kernel for all input ctypes: the int types,
//! the float types, and dt_bool itself.
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);  \
    DEF_KERN(dt_bool, dt_bool, _mode, _imp);
//! unary predicates on float inputs; x is converted to float before the
//! classification call
#define KERN_SIG ctype x
DEF_KERN_FLOAT(ISNAN, isnan(float(x)));
DEF_KERN_FLOAT(ISINF, isinf(float(x)));
#undef KERN_SIG
//! binary comparisons over int, float and bool inputs, producing dt_bool
#define KERN_SIG ctype x, ctype y
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN_ALL(NEQ, x != y);
#undef KERN_SIG

#undef DEF_KERN_AD
#undef DEF_KERN
#undef DEF_KERN_FLOAT
#undef DEF_KERN_INT
#undef DEF_KERN_ALL
364

M
Megvii Engine Team 已提交
365
}  // namespace megdnn
366 367 368 369 370 371 372 373

//! Remove the empty __host__ / __device__ macros again, but only when
//! MEGDNN_HOST_DEVICE_SELF_DEFINE shows that this header defined them.
#if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE)
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#undef __host__
#undef __device__
#endif

// vim: ft=cpp syntax=cpp.doxygen