/**
 * \file dnn/src/fallback/general_intrinsic/gi_common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

//! Standard C headers used by every GI backend.
#include "math.h"
#include "stdint.h"
#include "string.h"

//! Platform / ISA specific intrinsic headers.
#if defined(_WIN32)
#include <intrin.h>
#include <windows.h>
#else
#if defined(__arm__) || defined(__aarch64__)
#include "src/arm_common/simd_macro/marm_neon.h"
#endif
#if defined(__x86_64__) || defined(__i386__)
#include <cpuid.h>
#include <immintrin.h>
#endif
#endif

//! GI_TARGET_X86 / GI_TARGET_ARM only describe the target architecture; the
//! intrinsic family actually used is selected separately from these macros.
#if defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
#define GI_TARGET_X86
#endif

#if defined(__arm__) || defined(__aarch64__)
#define GI_TARGET_ARM
#endif

//! GI stands for "general intrinsic".
//! _GI_ALIGN_16 prefixes a declaration to align it to 16 bytes.
//! GI_DECLSPEC_ALIGN(variable, alignment) declares `variable` with the given
//! alignment using the compiler-specific syntax (DECLSPEC_ALIGN comes from the
//! Windows SDK headers on _WIN32).
#ifdef _WIN32
#define _GI_ALIGN_16                           __declspec(align(16))
#define GI_DECLSPEC_ALIGN(variable, alignment) DECLSPEC_ALIGN(alignment) variable
#else
#define _GI_ALIGN_16 __attribute__((aligned(16)))
#define GI_DECLSPEC_ALIGN(variable, alignment) \
    variable __attribute__((aligned(alignment)))
#endif

//! Force inlining of the small GI wrapper functions.
#if defined(_MSC_VER)
#define GI_FORCEINLINE __forceinline
#else
#define GI_FORCEINLINE __attribute__((always_inline)) inline
#endif

//! Declaration prefix giving C linkage and, on non-MSVC toolchains, hidden
//! symbol visibility. `extern "C"` is C++ syntax, so this is only usable from
//! C++ translation units.
#if defined(_MSC_VER)
#define GI_INTERNAL_DATA extern "C"
#else
#define GI_INTERNAL_DATA extern "C" __attribute__((visibility("hidden")))
#endif

//! Map the target architecture to the intrinsic families GI may use.
//! On AArch64 both GI_NEON64_INTRINSICS and GI_NEON32_INTRINSICS are defined,
//! so code guarded by either macro is enabled there.
#if defined(GI_TARGET_ARM)
#define GI_NEON_INTRINSICS
#if defined(__aarch64__)
#define GI_NEON64_INTRINSICS
#define GI_NEON32_INTRINSICS
#else
#define GI_NEON32_INTRINSICS
#endif
#elif defined(GI_TARGET_X86)
//! AVX/AVX2/FMA backends are not implemented yet; the intended selection is
//! kept commented out for reference. When it is re-enabled, the SSE checks
//! below have to become #elif branches of the same chain.
//#if defined(__FMA__)
//#define GI_FMA_INTRINSICS
//#define GI_AVX2_INTRINSICS
//#define GI_AVX_INTRINSICS
//#elif defined(__AVX2__)
//#define GI_AVX2_INTRINSICS
//#define GI_AVX_INTRINSICS
//#elif defined(__AVX__)
//#define GI_AVX_INTRINSICS
#if defined(__SSE4_2__)
#define GI_SSE42_INTRINSICS
#define GI_SSE2_INTRINSICS
#elif defined(__SSE2__)
#define GI_SSE2_INTRINSICS
#endif
#endif

//! GI_TEST_NAIVE drops every intrinsic family selected above so the portable
//! C implementation is compiled and exercised (used for CI coverage).
#if defined(GI_TEST_NAIVE)
#undef GI_NEON_INTRINSICS
#undef GI_NEON64_INTRINSICS
#undef GI_NEON32_INTRINSICS
#undef GI_FMA_INTRINSICS
#undef GI_AVX2_INTRINSICS
#undef GI_AVX_INTRINSICS
#undef GI_SSE42_INTRINSICS
#undef GI_SSE2_INTRINSICS
#endif

//! GI vectors have a backend-dependent width: 256 bits when an AVX-family
//! backend is selected, otherwise 128 bits.
#if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \
        defined(GI_FMA_INTRINSICS)
#define GI_SIMD_LEN      256
#define GI_SIMD_LEN_BYTE 32
#else
//! NEON, SSE2/SSE4.2 and the naive C fallback all use 128-bit vectors.
#define GI_SIMD_LEN      128
#define GI_SIMD_LEN_BYTE 16
#endif

//! Abort on a code path that must never execute (e.g. an out-of-range lane).
#define gi_trap() __builtin_trap()

//! Backend identifiers returned by GiGetSimdType(); currently used by CI
//! tests. The explicit values match plain declaration order.
enum GiSimdType {
    GI_UNKNOWN = 0,
    GI_NAIVE = 1,
    GI_AVX = 2,
    GI_SSE42 = 3,
    GI_SSE2 = 4,
    GI_NEON = 5,
};

//! Per-backend vector types. __gi_simd_type names the selected backend as a
//! GiSimdType enumerator.
#if defined(GI_AVX_INTRINSICS) || defined(GI_AVX2_INTRINSICS) || \
        defined(GI_FMA_INTRINSICS)
#define __gi_simd_type GI_AVX
typedef __m256 GI_FLOAT32_t;
typedef __m256i GI_UINT8_t;
typedef __m256i GI_INT8_t;
typedef __m256i GI_INT16_t;
typedef __m256i GI_INT32_t;
typedef __m256i GI_UINT32_t;
//! Every other backend provides GI_INT64_t; keep the AVX type set complete.
typedef __m256i GI_INT64_t;
#elif defined(GI_NEON_INTRINSICS)
#define __gi_simd_type GI_NEON
typedef float32x4_t GI_FLOAT32_t;
typedef uint8x16_t GI_UINT8_t;
typedef int8x16_t GI_INT8_t;
typedef int16x8_t GI_INT16_t;
typedef int32x4_t GI_INT32_t;
typedef uint32x4_t GI_UINT32_t;
//! NEON provides native multi-register tuple types.
typedef float32x4x2_t GI_FLOAT32_V2_t;
typedef float32x4x4_t GI_FLOAT32_V4_t;
typedef int32x4x2_t GI_INT32_V2_t;
typedef int32x4x4_t GI_INT32_V4_t;
typedef int16x8x2_t GI_INT16_V2_t;
typedef int8x16x2_t GI_INT8_V2_t;
typedef int64x2_t GI_INT64_t;
#elif defined(GI_SSE2_INTRINSICS) || defined(GI_SSE42_INTRINSICS)

#if defined(GI_SSE42_INTRINSICS)
#define __gi_simd_type GI_SSE42
#elif defined(GI_SSE2_INTRINSICS)
#define __gi_simd_type GI_SSE2
#else
#define __gi_simd_type GI_UNKNOWN
#error "code issue happened!!"
#endif

typedef __m128 GI_FLOAT32_t;
typedef __m128i GI_UINT8_t;
typedef __m128i GI_INT8_t;
typedef __m128i GI_INT16_t;
typedef __m128i GI_INT32_t;
typedef __m128i GI_UINT32_t;
typedef __m128i GI_INT64_t;

//! Helpers for emulating NEON-style operations on SSE.
//! Lane-select immediate for _mm_insert_ps: (srcField << 6) | (dstField << 4).
#define _INSERTPS_NDX(srcField, dstField) (((srcField) << 6) | ((dstField) << 4))
//! Store the low 64 bits of `inp` into `out`.
#define _M64(out, inp)                    _mm_storel_epi64((__m128i*)&(out), inp)
//! Load 64 bits from `a` into the low half of an __m128i / __m128.
#define _pM128i(a)                        _mm_loadl_epi64((__m128i*)&(a))
#define _pM128(a)                         _mm_castsi128_ps(_pM128i(a))
//! Bit-casts between float and integer SSE registers.
#define _M128i(a)                         _mm_castps_si128(a)
#define _M128(a)                          _mm_castsi128_ps(a)
//! Store the low 64 bits of the float vector `inp` into `out`. On x86-64 `out`
//! must have an m64_i64 member (e.g. __m64_128). Both variants expand to a
//! plain expression so callers supply the terminating semicolon.
#if defined(__x86_64__)
#define _M64f(out, inp) (out).m64_i64[0] = _mm_cvtsi128_si64(_M128i(inp))
#else
#define _M64f(out, inp) _mm_storel_epi64((__m128i*)&(out), _M128i(inp))
#endif
//! Dispatch a runtime LANE (0..15) to NAME(a b, <literal>), so the wrapped
//! operation receives a literal lane index (immediate-operand intrinsics such
//! as _mm_alignr_epi8 require one). `b` must carry its own leading comma
//! (see _SSE_COMMA). An out-of-range LANE traps.
#define _SSE_SWITCH16(NAME, a, b, LANE) \
    switch (LANE) {                     \
        case 0:                         \
            return NAME(a b, 0);        \
        case 1:                         \
            return NAME(a b, 1);        \
        case 2:                         \
            return NAME(a b, 2);        \
        case 3:                         \
            return NAME(a b, 3);        \
        case 4:                         \
            return NAME(a b, 4);        \
        case 5:                         \
            return NAME(a b, 5);        \
        case 6:                         \
            return NAME(a b, 6);        \
        case 7:                         \
            return NAME(a b, 7);        \
        case 8:                         \
            return NAME(a b, 8);        \
        case 9:                         \
            return NAME(a b, 9);        \
        case 10:                        \
            return NAME(a b, 10);       \
        case 11:                        \
            return NAME(a b, 11);       \
        case 12:                        \
            return NAME(a b, 12);       \
        case 13:                        \
            return NAME(a b, 13);       \
        case 14:                        \
            return NAME(a b, 14);       \
        case 15:                        \
            return NAME(a b, 15);       \
        default:                        \
            gi_trap();                  \
            return NAME(a b, 0);        \
    }
//! _mm_alignr_epi8 is an SSSE3 instruction (not SSE3), so the emulation is
//! needed whenever SSSE3 is unavailable, including SSE3-only targets.
#if !defined(__SSSE3__)
//! SSE2 emulation of _mm_alignr_epi8(b, a, imm8): concatenate b:a (b in the
//! high half), shift right by imm8 bytes and keep the low 16 bytes.
//! NOTE(review): _mm_srli_si128/_mm_slli_si128 take immediate operands; this
//! relies on inlining plus constant propagation from _SSE_SWITCH16 and may
//! fail to compile without optimization on GCC — confirm.
GI_FORCEINLINE __m128i _sse2_mm_alignr_epi8(__m128i b, __m128i a, int imm8) {
    int imm2 = sizeof(__m128i) - imm8;
    return _mm_or_si128(_mm_srli_si128(a, imm8), _mm_slli_si128(b, imm2));
}
#endif

#define _SSE_COMMA ,
//! _mm_alignr_epi8(a, b, LANE) with a runtime byte offset LANE in [0, 15];
//! uses the SSE2 emulation when SSSE3 is not available.
GI_FORCEINLINE __m128i _MM_ALIGNR_EPI8(__m128i a, __m128i b, int LANE) {
#if !defined(__SSSE3__)
    _SSE_SWITCH16(_sse2_mm_alignr_epi8, a, _SSE_COMMA b, LANE)
#else
    _SSE_SWITCH16(_mm_alignr_epi8, a, _SSE_COMMA b, LANE)
#endif
}
typedef float float32_t;
typedef double float64_t;
//! A 64-bit value viewable as any lane type; stands in for NEON's 64-bit
//! vector types (e.g. float32x2_t) on SSE.
typedef union __m64_128 {
    uint64_t m64_u64[1];
    int64_t m64_i64[1];
    float64_t m64_d64[1];
    uint32_t m64_u32[2];
    int32_t m64_i32[2];
    float32_t m64_f32[2];
    int16_t m64_i16[4];
    uint16_t m64_u16[4];
    int8_t m64_i8[8];
    uint8_t m64_u8[8];
} __m64_128;
typedef __m64_128 float32x2_t;

//! Store the low 64 bits of `a` into `res64` and return it; the enclosing
//! function must declare a local named res64.
#define return64(a) \
    _M64(res64, a); \
    return res64;
#define return64f(a) \
    _M64f(res64, a); \
    return res64;
//! NEON vextq_s32(a, b, c): c is an int32 lane count, converted to bytes.
#define _sse_vextq_s32(a, b, c)       _MM_ALIGNR_EPI8(b, a, (c) * 4)
#define _sse_vget_lane_f32(vec, lane) (vec).m64_f32[lane]
#else
#define __gi_simd_type GI_NAIVE
//! Naive backend: GCC/Clang vector extensions, 16 bytes per vector.
typedef float GI_FLOAT32_t __attribute__((vector_size(16)));
typedef uint8_t GI_UINT8_t __attribute__((vector_size(16)));
typedef int8_t GI_INT8_t __attribute__((vector_size(16)));
typedef int16_t GI_INT16_t __attribute__((vector_size(16)));
typedef int32_t GI_INT32_t __attribute__((vector_size(16)));
typedef uint32_t GI_UINT32_t __attribute__((vector_size(16)));
typedef int64_t GI_INT64_t __attribute__((vector_size(16)));
//! On ARM float32x2_t presumably comes from the NEON header included for ARM
//! builds — confirm.
#if !defined(__arm__) && !defined(__aarch64__)
typedef float float32x2_t __attribute__((vector_size(8)));
#endif
typedef float float32_t;
#endif

//! Vector-extension types with the naive lane layout, available regardless of
//! the selected backend. Some GI APIs do not implement every GiSimdType (for
//! example GiAbsInt32 has no SSE2 path). On SSE the GI_*_t types are
//! __m128 / __m128i, whose element type (e.g. long long) differs from the
//! logical lane type, so indexing them does not follow the naive layout.
//! These *_NAIVE_t types are intended for such element-wise fallbacks.
typedef float GI_FLOAT32_NAIVE_t __attribute__((vector_size(16)));
typedef uint8_t GI_UINT8_NAIVE_t __attribute__((vector_size(16)));
typedef int8_t GI_INT8_NAIVE_t __attribute__((vector_size(16)));
typedef int16_t GI_INT16_NAIVE_t __attribute__((vector_size(16)));
typedef int32_t GI_INT32_NAIVE_t __attribute__((vector_size(16)));
typedef uint32_t GI_UINT32_NAIVE_t __attribute__((vector_size(16)));
typedef int64_t GI_INT64_NAIVE_t __attribute__((vector_size(16)));
typedef float float32x2_NAIVE_t __attribute__((vector_size(8)));

//! Multi-register tuples of the naive vector types.
typedef struct {
    GI_INT32_NAIVE_t val[2];
} GI_INT32_V2_NAIVE_t;

typedef struct {
    GI_INT32_NAIVE_t val[4];
} GI_INT32_V4_NAIVE_t;

typedef struct {
    GI_FLOAT32_NAIVE_t val[2];
} GI_FLOAT32_V2_NAIVE_t;

typedef struct {
    GI_FLOAT32_NAIVE_t val[4];
} GI_FLOAT32_V4_NAIVE_t;

typedef struct {
    GI_INT16_NAIVE_t val[2];
} GI_INT16_V2_NAIVE_t;

typedef struct {
    GI_INT8_NAIVE_t val[2];
} GI_INT8_V2_NAIVE_t;

//! Scalar max/min. The whole expansion is parenthesized so the result composes
//! safely inside larger expressions (e.g. 2 * Max(a, b)). Each argument may
//! still be evaluated twice, so do not pass expressions with side effects.
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))

//! NEON multiply-accumulate helpers: v_fma_*(c, b, a) computes c + b * a
//! (per lane, with a scalar, or with a selected lane of a).
//! The fused vfmaq_* forms are used on AArch64 when the compiler reports FMA
//! support; otherwise the vmlaq_* forms are used.
//! NOTE(review): the two forms may round differently (fused vs. separate
//! multiply and add).
#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS)
#define v_fma_ps_f32(c, b, a)         vfmaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a)          vfmaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane))
#else
#define v_fma_ps_f32(c, b, a)         vmlaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a)          vmlaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane))
#endif
#endif

324
#if !defined(GI_NEON_INTRINSICS)
325
typedef struct {
326 327
    GI_INT32_t val[2];
} GI_INT32_V2_t;
328 329

typedef struct {
330 331
    GI_INT32_t val[4];
} GI_INT32_V4_t;
332 333

typedef struct {
334 335
    GI_FLOAT32_t val[2];
} GI_FLOAT32_V2_t;
336 337

typedef struct {
338 339 340 341 342 343 344 345 346 347
    GI_FLOAT32_t val[4];
} GI_FLOAT32_V4_t;

typedef struct {
    GI_INT16_t val[2];
} GI_INT16_V2_t;

typedef struct {
    GI_INT8_t val[2];
} GI_INT8_V2_t;
348
#endif
349 350

//! Lane-wise bitwise AND of two int32 vectors.
GI_FORCEINLINE
GI_INT32_t GiAndInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS)
    return vandq_s32(Vector1, Vector2);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_and_si128(Vector1, Vector2);
#else
    return Vector1 & Vector2;
#endif
}

//! Lane-wise bitwise OR of two int32 vectors.
GI_FORCEINLINE
GI_INT32_t GiOrInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS)
    return vorrq_s32(Vector1, Vector2);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_or_si128(Vector1, Vector2);
#else
    return Vector1 | Vector2;
#endif
}

//! Lane-wise (~VectorNot) & Vector. The operand that gets inverted comes first,
//! matching the argument order of _mm_andnot_si128.
GI_FORCEINLINE
GI_INT32_t GiAndNotInt32(GI_INT32_t VectorNot, GI_INT32_t Vector) {
#if defined(GI_NEON_INTRINSICS)
    return vandq_s32(vmvnq_s32(VectorNot), Vector);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_andnot_si128(VectorNot, Vector);
#else
    return (~VectorNot) & Vector;
#endif
}

//! Lane-wise bitwise XOR of two int32 vectors.
GI_FORCEINLINE
GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS)
    return veorq_s32(Vector1, Vector2);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_xor_si128(Vector1, Vector2);
#else
    return Vector1 ^ Vector2;
#endif
}

394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438
GI_FORCEINLINE
GI_FLOAT32_t GiBroadcastFloat32(float Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_f32(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_ps(Value);
#else
    GI_FLOAT32_t ret;
    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
        ret[i] = Value;
    }
    return ret;
#endif
}

//! Splat: build an int32 vector whose lanes all hold Value.
GI_FORCEINLINE
GI_INT32_t GiBroadcastInt32(int32_t Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_s32(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_epi32(Value);
#else
    //! Naive path: write Value into each lane in turn.
    const size_t lane_count = GI_SIMD_LEN_BYTE / sizeof(int32_t);
    GI_INT32_t splat;
    size_t lane = 0;
    while (lane < lane_count) {
        splat[lane] = Value;
        ++lane;
    }
    return splat;
#endif
}

//! Splat: build an int8 vector whose lanes all hold Value.
GI_FORCEINLINE
GI_INT8_t GiBroadcastInt8(int8_t Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_s8(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_epi8(Value);
#else
    //! Naive path: write Value into each byte lane in turn.
    const size_t lane_count = GI_SIMD_LEN_BYTE / sizeof(int8_t);
    GI_INT8_t splat;
    size_t lane = 0;
    while (lane < lane_count) {
        splat[lane] = Value;
        ++lane;
    }
    return splat;
#endif
}

439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466
GI_FORCEINLINE
GiSimdType GiGetSimdType() {
    //! override by special macro to insure ci have test naive and sse2
    //! now we do not imp GI_AVX to now and x64 ci device will test GI_SSE42
    //! now arm ci device will test GI_NEON
    //! insure test GI_SSE2 by command:
    //! --copt -march=core2 --copt -mno-sse4.2
    //! --copt -mno-sse3 --copt -DGI_TEST_SSE2
    //! insure test GI_NAIVE by command:
    //! --copt -DGI_TEST_SSE2
    //! DNN code at least need sse2 at x86
    //! so we can not test GI_NAIVE by
    //! --copt -march=core2 --copt -mno-sse4.2
    //! --copt -mno-sse3 --copt -mno-sse2
    //! --copt -DGI_TEST_NAIVE
    //! about CMake, can override build flags to CMAKE_CXX_FLAGS/CMAKE_C_FLAGS by
    //! EXTRA_CMAKE_ARGS when use scripts/cmake-build/*.sh
#if defined(GI_TEST_NAIVE)
#undef __gi_simd_type
#define __gi_simd_type GI_NAIVE
#elif defined(GI_TEST_SSE2)
#undef __gi_simd_type
#define __gi_simd_type GI_SSE2
#endif

    return __gi_simd_type;
}

467 468 469 470 471 472 473 474
__attribute__((unused)) const GI_INT8_t vzero_int8 = GiBroadcastInt8(0);
__attribute__((unused)) const GI_INT32_t vzero = GiBroadcastInt32(0);
__attribute__((unused)) const GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
__attribute__((unused)) const GI_FLOAT32_t vfhalf = GiBroadcastFloat32(0.5f);
__attribute__((unused)) const GI_FLOAT32_t vfneg_half = GiBroadcastFloat32(-0.5f);
__attribute__((unused)) const GI_FLOAT32_t vfmin_int8 = GiBroadcastFloat32(-128.0f);
__attribute__((unused)) const GI_FLOAT32_t vfmax_int8 = GiBroadcastFloat32(127.0f);

475
// vim: syntax=cpp.doxygen