kern_defs.cuh 8.3 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/elemwise/kern_defs.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

M
Megvii Engine Team 已提交
14
#include "src/common/elemwise/erfinv.h"
15
#include "src/common/elemwise_helper.cuh"
M
Megvii Engine Team 已提交
16
#include "src/common/opr_param_defs_enumv.cuh"
17 18 19 20 21 22 23 24 25 26 27 28
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>

#if MEGDNN_CC_HOST
#include <algorithm>
//! bring the std versions into scope so that the unqualified min()/max()
//! calls used by the kernels in this header resolve when MEGDNN_CC_HOST is set
using std::max;
using std::min;

//! stand-in for rsqrtf() when MEGDNN_CC_HOST is set (presumably because the
//! host math library does not provide it -- TODO confirm); uses the
//! single-precision sqrtf so the whole expression stays in float
#define rsqrtf(x) (1.f / sqrtf(x))
#endif

//! If the includer has not provided its own MEGDNN_ELEMWISE_MODE_ENABLE,
//! install a default that simply forwards the mode to the callback, and set
//! MEGDNN_ELEMWISE_MODE_ENABLE_ALL to 1 to record that the default is in use.
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1
#endif

//! When MEGDNN_CC_HOST is set and __host__ is not already a macro, define
//! __host__ and __device__ as empty so the qualified functions below still
//! parse. MEGDNN_HOST_DEVICE_SELF_DEFINE records that this header created the
//! definitions; the matching block at the end of the file checks it before
//! undefining them again.
#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

M
Megvii Engine Team 已提交
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
//! Computes log(exp(x) + exp(y)) as hi + log1p(exp(lo - hi)), where hi is the
//! larger and lo the smaller of the two operands, so the argument passed to
//! exp() is never positive.
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    const bool x_is_smaller = x < y;
    const T lo = x_is_smaller ? x : y;
    const T hi = x_is_smaller ? y : x;
    return T(hi + log1pf(exp(lo - hi)));
}

//! Rational approximation of tanh: x * (27 + x^2) / (27 + 9 * x^2).
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! Fused add + h_swish: with z = x + y, returns z * clamp(z + 3, 0, 6) / 6.
//!
//! The final scaling multiplies by (1.f / 6.f) instead of dividing by 6.f:
//! nvcc is not invoked with --use_fast_math here, so the division is not
//! relaxed by the --prec_div setting that flag controls, and a real division
//! causes a performance drop on the Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! Gradient of fast_tanh: evaluates the derivative of
//! x * (27 + x^2) / (27 + 9 * x^2) at x and scales it by dx.
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float d = 3.f + sq;
    const float numer = (-48.f * sq) / d + 27.f + sq;
    return numer / (d * 9.f) * dx;
}

//! Gradient of SiLU (x * sigmoid(x)): with s = sigmoid(x), returns
//! dy * s * (1 + x * (1 - s)).
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float sig = 1.0f / (1.0f + expf(-x));
    return dy * sig * (1.0f + x * (1.0f - sig));
}

//! Standard normal cumulative distribution function.
//! With MEGDNN_CC_HOST set it is computed from erff as
//! 0.5 * (1 + erf(x / sqrt(2))); otherwise it forwards to the global
//! ::normcdff.
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    return 0.5f * (1.f + erff(x / sqrtf(2.f)));
#else
    //! use cuda built-in math
    return ::normcdff(x);
#endif
}
88

M
Megvii Engine Team 已提交
89 90 91 92 93 94 95 96
//! Gradient of GELU (x * normcdf(x)): returns dy * (cdf + x * pdf), where
//! cdf = normcdf(x) and pdf = exp(-x^2 / 2) / sqrt(2 * pi).
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    const float cdf = normcdf(x);
    return dy * (cdf + x * pdf);
}
97

98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
//! Approximate float equality: true when |a - b| is below an absolute
//! tolerance of 1e-6. The tolerance is a float literal so the comparison is
//! carried out in single precision instead of being promoted to double.
//! Returns false when the difference is NaN (e.g. either operand is NaN).
__device__ __host__ inline bool feq(float a, float b) {
    return fabsf(a - b) < 1e-6f;
}

//! pow(x, y) with shortcuts for a handful of common exponents: when y matches
//! (per feq) one of 2, 0.5, -0.5, 0, 1, 3, -1 or -2 the result is produced by
//! a cheaper closed form; any other exponent goes to powf. The exponents are
//! tested in the order listed.
__device__ __host__ inline float dispatch_powf(float x, float y) {
    if (feq(y, 2.f)) {
        return x * x;
    }
    if (feq(y, 0.5f)) {
        return sqrtf(x);
    }
    if (feq(y, -0.5f)) {
        return rsqrtf(x);
    }
    if (feq(y, 0.f)) {
        return 1.f;
    }
    if (feq(y, 1.f)) {
        return x;
    }
    if (feq(y, 3.f)) {
        return x * x * x;
    }
    if (feq(y, -1.f)) {
        return 1.f / x;
    }
    if (feq(y, -2.f)) {
        return 1.f / (x * x);
    }
    return powf(x, y);
}

122 123
#include "src/common/elemwise/each_mode.inl"

M
Megvii Engine Team 已提交
124 125
//! Primary template for an element-wise kernel, selected by platform, mode
//! value and data type. It is only declared here; the DEF_KERN* macros in
//! this file generate partial specializations (for every platform) that
//! provide a static apply().
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
struct ElemwiseKern;
126 127

//! define kernel for a single ctype
//!
//! Expands to a partial specialization of ElemwiseKern for
//! param_enumv::Elemwise::Mode::_mode and _ctype that is valid for every
//! platform. The specialization exposes `ctype` and a static apply() whose
//! parameter list is whatever KERN_SIG expands to at the point of use and
//! whose result is ctype(_imp). The expansion has no trailing semicolon; the
//! caller supplies it.
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! define kernel for all float types
//!
//! Instantiates DEF_KERN for dt_float32, and for dt_float16 and dt_bfloat16
//! through DNN_INC_FLOAT16 (defined elsewhere; presumably it drops its
//! argument when half-precision support is disabled -- TODO confirm).
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
140 141

//! define kernel for all int types
//!
//! Instantiates DEF_KERN for dt_int32, dt_int16, dt_int8 and dt_uint8 with
//! the same mode and implementation expression.
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);
147 148 149

//! define kernel for all ctypes
//!
//! Shorthand for DEF_KERN_INT followed by DEF_KERN_FLOAT with the same mode
//! and implementation expression.
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);
152

M
Megvii Engine Team 已提交
153
/* ================== unary kernels ================== */
//! every kernel in this section takes a single operand `x`
#define KERN_SIG ctype x

// int and float
DEF_KERN_ALL(NEGATE, -x);
//! On HIP (HCC, non-NVCC) the float kernels compare against the literal 0.f
//! instead of ctype(0); all other builds use one definition for every type.
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
//! multiplies by (1.f / 6.f) rather than dividing by 6.f, as fuse_add_hswish does
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x * normcdf(x));

// bool only: logical negation expressed as a flip of the low bit
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

M
Megvii Engine Team 已提交
196
/* ================== binary kernels ================== */
//! every kernel in this section takes two operands `x` and `y`
#define KERN_SIG ctype x, ctype y

// int and float
//! On HIP (HCC, non-NVCC) the float kernels compare against the literal 0.f
//! instead of ctype(0); all other builds use one definition for every type.
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
#else
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
#endif
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x * y);
DEF_KERN(dt_bool, AND, x && y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
DEF_KERN_ALL(SIGMOID_GRAD, x * (ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
//! same HIP-specific split as ABS_GRAD above
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
#else
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
#endif
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
//! comparisons: the boolean result is converted to ctype by DEF_KERN
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

//! Integer FLOOR_DIV rounds toward negative infinity, matching the float
//! kernel (floorf(x / y)). Plain integer `x / y` truncates toward zero, so
//! when the operands have opposite signs and the division is inexact the
//! truncated quotient is one too large and has to be decremented.
//! `(x ^ y) < 0` is the opposite-sign test; it is evaluated after integer
//! promotion, so it is never true for dt_uint8 (assuming that is a plain
//! 8-bit unsigned integer type).
DEF_KERN_INT(FLOOR_DIV, ((x ^ y) < 0 && x % y != 0) ? x / y - 1 : x / y);
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

DEF_KERN_INT(MOD, x % y);
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
//! On HIP (HCC, non-NVCC) the float kernels compare against the literal 0.f
//! instead of ctype(0); all other builds use one definition for every type.
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif

// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));

DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));

DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
//! piecewise in x: 0 below -3, y above 3, and (2x + 3) / 6 * y in between
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));

DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
#undef KERN_SIG

M
Megvii Engine Team 已提交
265
/* ================== ternary kernels ================== */
//! every kernel in this section takes three operands `x`, `y` and `z`
#define KERN_SIG ctype x, ctype y, ctype z

// int and float
//! yields z where x <= y holds and zero elsewhere
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z);

#undef KERN_SIG

// NOTE(review): no macro named DEF_KERN_AD is defined in the visible part of
// this file, so this #undef has no effect here; DEF_KERN_ALL, DEF_KERN_INT
// and DEF_KERN_FLOAT stay defined after this point. Possibly a stale or
// misspelled name -- confirm whether other files rely on those macros before
// changing it.
#undef DEF_KERN_AD
#undef DEF_KERN

M
Megvii Engine Team 已提交
277
}  // namespace megdnn
278 279 280 281 282 283 284 285

//! Undo the empty __host__ / __device__ definitions, but only if this header
//! was the one that created them (signalled by
//! MEGDNN_HOST_DEVICE_SELF_DEFINE), so the macros do not leak to includers.
#if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE)
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#undef __host__
#undef __device__
#endif

// vim: ft=cpp syntax=cpp.doxygen