/**
 * \file dnn/src/common/elemwise/kern_defs.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

M
Megvii Engine Team 已提交
14
#include "src/common/elemwise/erfinv.h"
15
#include "src/common/elemwise_helper.cuh"
M
Megvii Engine Team 已提交
16
#include "src/common/opr_param_defs_enumv.cuh"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
#include "src/common/utils.cuh"

#include "megcore_cdefs.h"
#include "megdnn/dtype.h"

#include <cmath>
#include <cstdlib>

#if MEGDNN_CC_HOST
//! The kernels below call min/max unqualified (e.g. the h-swish helpers); on a
//! host build they are taken from <algorithm>.
#include <algorithm>
using std::max;
using std::min;
#endif

//! Default mode-enable hook: MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) simply
//! invokes `_cb(_mode)` and MEGDNN_ELEMWISE_MODE_ENABLE_ALL is 1. An includer
//! may predefine MEGDNN_ELEMWISE_MODE_ENABLE to override this (presumably to
//! restrict which modes are instantiated — TODO confirm with users).
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1
#endif

//! A plain host compiler does not know the CUDA __host__/__device__
//! qualifiers: define them as empty and set MEGDNN_HOST_DEVICE_SELF_DEFINE so
//! the end of this header can undefine them again.
#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#endif

namespace megdnn {

M
Megvii Engine Team 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
//! Numerically stable log(exp(x) + exp(y)).
//!
//! The larger argument `b` is factored out so exp() only ever sees a
//! non-positive value: log(e^x + e^y) = b + log1p(exp(a - b)), a = min, b = max.
//!
//! If both arguments are equal, `a - b` is 0 for finite inputs but
//! `inf - inf = NaN` when both are -inf or both are +inf. That case is handled
//! explicitly and returns b + log1p(1) = b + log(2), which equals the general
//! formula for finite inputs and gives -inf / +inf for infinite ones.
//! NaN inputs still propagate as NaN.
template <typename T>
__device__ __host__ inline T log_sum_exp(T x, T y) {
    T a, b;
    a = x < y ? x : y;  //!< smaller argument
    b = x < y ? y : x;  //!< larger argument
    if (a == b) {
        //! exp(a - b) would be exp(0) == 1 here; avoid evaluating inf - inf
        return T(b + log1pf(1.f));
    }
    return T(b + log1pf(exp(a - b)));
}

//! Cheap rational approximation of tanh:
//!     x * (27 + x^2) / (27 + 9 * x^2)
//! It equals +-1 at x = +-3 and grows like x / 9 for large |x| (it does not
//! saturate), so it is only meaningful for moderate inputs.
__device__ __host__ inline float fast_tanh(float x) {
    const float numer = x * (27.f + x * x);
    const float denom = 27.f + 9.f * x * x;
    return numer / denom;
}

//! Fused h_swish(x + y), where h_swish(z) = z * clamp(z + 3, 0, 6) / 6.
//!
//! The division by 6 is written as a multiplication by (1.f / 6.f): nvcc is
//! not given --use_fast_math here (which would relax --prec_div), and a
//! precise division caused a performance drop on the Turing architecture.
__device__ __host__ inline float fuse_add_hswish(float x, float y) {
    const float sum = x + y;
    const float gate = min(max(sum + 3.f, 0.f), 6.f);
    return sum * gate * (1.f / 6.f);
}

//! Derivative of the fast_tanh rational approximation at `x`, scaled by the
//! incoming gradient `dx`. With d = 3 + x^2 it evaluates
//!     (27 + x^2 - 48 * x^2 / d) / (9 * d) * dx,
//! which simplifies to (81 - 18 x^2 + x^4) / (9 d^2) * dx, i.e. the exact
//! derivative of x * (27 + x^2) / (27 + 9 x^2).
__device__ __host__ inline float fast_tanh_grad(float x, float dx) {
    const float sq = x * x;
    const float d = 3.f + sq;
    const float numer = (-48.f * sq) / d + 27.f + sq;
    const float denom = d * 9.f;
    return numer / denom * dx;
}

//! Gradient of SiLU, x * sigmoid(x), with respect to `x`, scaled by `dy`:
//!     dy * s * (1 + x * (1 - s)),   s = sigmoid(x) = 1 / (1 + exp(-x))
__device__ __host__ inline float silu_grad(float x, float dy) {
    const float s = 1.f / (1.f + expf(-x));
    return dy * s * (1.f + x * (1.f - s));
}

//! Standard normal cumulative distribution function Phi(x).
//!
//! Host path: Phi(x) = 0.5 * erfc(-x / sqrt(2)). This is algebraically equal
//! to 0.5 * (1 + erf(x / sqrt(2))), but the erf form cancels catastrophically
//! for negative x. Below roughly x = -5.5, erf(x / sqrt(2)) already rounds to
//! exactly -1 in float and the result collapses to 0, while erfc keeps full
//! relative precision there.
//! Device path: uses CUDA's built-in normcdff.
__device__ __host__ inline float normcdf(float x) {
#if MEGDNN_CC_HOST
    return 0.5f * erfcf(-x / sqrtf(2.f));
#else
    //! use the CUDA built-in math function
    return ::normcdff(x);
#endif
}
86

M
Megvii Engine Team 已提交
87 88 89 90 91 92 93 94
//! Gradient of GELU, x * Phi(x), with respect to `x`, scaled by `dy`:
//!     d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
//! where phi(x) = exp(-x^2 / 2) / sqrt(2 * pi) is the standard normal pdf and
//! Phi is computed by normcdf().
__device__ __host__ inline float gelu_grad(float x, float dy) {
    //! 1 / sqrt(2 * pi)
    const float inv_sqrt_2pi = 0.3989422804014327f;
    const float cdf = normcdf(x);
    const float pdf = inv_sqrt_2pi * expf(-0.5f * x * x);
    return dy * (cdf + x * pdf);
}
95

96 97
#include "src/common/elemwise/each_mode.inl"

M
Megvii Engine Team 已提交
98 99
//! Element-wise kernel functor keyed by platform, Elemwise mode and ctype.
//! The primary template is only declared; each supported combination comes
//! from a partial specialization generated with the DEF_KERN family of macros.
template <megcorePlatform_t Plat, uint32_t Mode, typename CType>
struct ElemwiseKern;
100 101

//! Define the ElemwiseKern specialization for a single (mode, ctype) pair.
//!
//! \param _ctype concrete scalar type, e.g. dt_float32
//! \param _mode  enumerator name in param_enumv::Elemwise::Mode
//! \param _imp   expression producing the result. It is placed inside
//!               `apply(KERN_SIG)`, so it may use the argument names declared
//!               by whichever KERN_SIG macro is defined where DEF_KERN is
//!               expanded, as well as the `ctype` typedef.
//!
//! The value of `_imp` is converted to `ctype` before it is returned. The
//! expansion has no trailing semicolon; the caller supplies it.
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

//! Define kernels for all float types: dt_float32 always, plus dt_float16 and
//! dt_bfloat16 when DNN_INC_FLOAT16 keeps its argument (presumably tied to a
//! half-precision build option — TODO confirm).
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
114 115

//! Define kernels for all int types: dt_int32, dt_int16, dt_int8 and dt_uint8
//! (dt_bool is not included).
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);
121 122 123

//! Define kernels for every int and float ctype (DEF_KERN_INT + DEF_KERN_FLOAT).
//! dt_bool kernels are not covered and are defined individually via DEF_KERN.
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);
126

M
Megvii Engine Team 已提交
127
/* ================== unary kernels ================== */
#define KERN_SIG ctype x

// int and float
DEF_KERN_ALL(NEGATE, -x);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
//! HIP (non-NVCC) builds split RELU into int and float variants; the float
//! variant compares against the literal 0.f instead of ctype(0)
DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x);
DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x);
#else
DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x);
#endif
DEF_KERN_INT(ABS, abs(int(x)));
// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x);
DEF_KERN_FLOAT(ABS, fabsf(x));

// float only
DEF_KERN_FLOAT(ACOS, acosf(x));
DEF_KERN_FLOAT(ASIN, asinf(x));
DEF_KERN_FLOAT(CEIL, ceilf(x));
DEF_KERN_FLOAT(COS, cosf(x));
DEF_KERN_FLOAT(EXP, expf(x));
DEF_KERN_FLOAT(EXPM1, expm1f(x));
DEF_KERN_FLOAT(FLOOR, floorf(x));
DEF_KERN_FLOAT(LOG, logf(x));
DEF_KERN_FLOAT(LOG1P, log1pf(x));
DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f));
DEF_KERN_FLOAT(SIN, sinf(x));
DEF_KERN_FLOAT(TANH, tanhf(x));
DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x));
DEF_KERN_FLOAT(ROUND, roundf(x));
DEF_KERN_FLOAT(ERF, erff(x));
DEF_KERN_FLOAT(ERFINV, erfinvf(x));
DEF_KERN_FLOAT(ERFC, erfcf(x));
DEF_KERN_FLOAT(ERFCINV, erfcinvf(x));
DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f));
DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f));
DEF_KERN_FLOAT(GELU, x * normcdf(x));

// bool only
DEF_KERN(dt_bool, NOT, x ^ 1);

#undef KERN_SIG

M
Megvii Engine Team 已提交
170
/* ================== binary kernels ================== */

//! Integer floor division: rounds the quotient toward negative infinity,
//! matching the float kernel floorf(x / y). Built-in integer `/` truncates
//! toward zero, which differs whenever the division is inexact and the
//! operands have opposite signs (e.g. -7 / 2 == -3 but floor(-3.5) == -4).
//! As with `/`, y == 0 is undefined behaviour.
template <typename T>
__device__ __host__ inline T floordiv_int(T x, T y) {
    T quot = x / y;
    bool inexact = (x % y) != 0;
    bool opposite_sign = (x < T(0)) != (y < T(0));
    return (inexact && opposite_sign) ? T(quot - 1) : quot;
}

#define KERN_SIG ctype x, ctype y

// int and float
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y);
DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y);
#else
DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y);
#endif
DEF_KERN_ALL(ADD, x + y);
DEF_KERN_ALL(MAX, x > y ? x : y);
DEF_KERN_ALL(MIN, x < y ? x : y);
DEF_KERN_ALL(MUL, x * y);
DEF_KERN(dt_bool, AND, x && y);
DEF_KERN(dt_bool, OR, x || y);
DEF_KERN(dt_bool, XOR, x ^ y);
DEF_KERN_INT(RMULH, round_mulh_saturate(x, y));
DEF_KERN_ALL(SIGMOID_GRAD, x * (ctype(1) - x) * y);
DEF_KERN_ALL(SUB, x - y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0));
#else
DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0));
#endif
DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y);
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);

//! floor semantics for both int and float
DEF_KERN_INT(FLOOR_DIV, floordiv_int(x, y));
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));

//! C remainder semantics for both int (%) and float (fmodf): the result takes
//! the sign of x
DEF_KERN_INT(MOD, x % y);
DEF_KERN_FLOAT(MOD, fmodf(x, y));

DEF_KERN_INT(SHL, x << y);
DEF_KERN_INT(SHR, x >> y);
#if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__)
DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y));
#else
DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y));
#endif

// float only
DEF_KERN_FLOAT(TRUE_DIV, x / y);
DEF_KERN_FLOAT(POW, powf(x, y));
DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y));
DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y));

DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y));
DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f));

DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
DEF_KERN_FLOAT(
        H_SWISH_GRAD,
        x < -3.f ? (ctype)0.f
                 : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));

DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y));
DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));

#undef KERN_SIG

M
Megvii Engine Team 已提交
239
/* ================== ternary kernels ================== */
#define KERN_SIG ctype x, ctype y, ctype z

// int and float
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z);

#undef KERN_SIG

//! The DEF_KERN* helpers are local to this header; undefine all of them so
//! they do not leak into translation units that include it.
#undef DEF_KERN_ALL
#undef DEF_KERN_INT
#undef DEF_KERN_FLOAT
#undef DEF_KERN

M
Megvii Engine Team 已提交
251
}  // namespace megdnn
252 253 254 255 256 257 258 259

//! Undo the empty __host__/__device__ definitions, but only if this header
//! introduced them (MEGDNN_HOST_DEVICE_SELF_DEFINE is set alongside them).
#if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE)
#undef MEGDNN_HOST_DEVICE_SELF_DEFINE
#undef __host__
#undef __device__
#endif

// vim: ft=cpp syntax=cpp.doxygen