// helper.h
#pragma once
#include "src/common/unroll_macro.h"
#include "src/fallback/general_intrinsic/gi_float.h"

// Short aliases for the GI (general intrinsic) float32 arithmetic helpers
// from gi_float.h. Each operation has a plain form and a "V2" form; judging
// by the naming, the V2 forms take GI_FLOAT32_V2_t operands -- confirm in
// gi_float.h.

// Plain forms.
#define ADDF  GiAddFloat32
#define SUBF  GiSubtractFloat32
#define MULF  GiMultiplyFloat32
#define MULSF GiMultiplyScalerFloat32

// V2 forms.
#define ADDFV2  GiAddFloat32V2
#define SUBFV2  GiSubtractFloat32V2
#define MULFV2  GiMultiplyFloat32V2
#define MULSFV2 GiMultiplyScalerFloat32V2

namespace megdnn {
namespace fallback {
// Transpose one 4x4 block of floats: after the call, dst[c * ldb + r] holds
// src[r * lda + c] for r, c in [0, 4) (assuming GiZipqFloat32 interleaves its
// two inputs like NEON vzipq_f32).
//
//   src  first source row; row r is the 4 contiguous floats at src + r * lda
//   dst  first destination row; row c is written to dst + c * ldb
//   lda  source row stride, in floats
//   ldb  destination row stride, in floats
//
// All four rows are loaded before the first store, so an in-place transpose
// (src == dst with lda == ldb) is safe.
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
    // a0 = {row0, row1}, a1 = {row2, row3}.
    GI_FLOAT32_V2_t a0, a1;
    GiSetSubVectorFloat32V2(a0, 0, GiLoadFloat32(src + 0 * lda));
    GiSetSubVectorFloat32V2(a0, 1, GiLoadFloat32(src + 1 * lda));
    GiSetSubVectorFloat32V2(a1, 0, GiLoadFloat32(src + 2 * lda));
    GiSetSubVectorFloat32V2(a1, 1, GiLoadFloat32(src + 3 * lda));
    // First interleave pass: pair row0 with row2 and row1 with row3.
    GI_FLOAT32_V2_t b0 = GiZipqFloat32(
            GiGetSubVectorFloat32V2(a0, 0), GiGetSubVectorFloat32V2(a1, 0));
    GI_FLOAT32_V2_t b1 = GiZipqFloat32(
            GiGetSubVectorFloat32V2(a0, 1), GiGetSubVectorFloat32V2(a1, 1));
    // Second interleave pass: c0 = {col0, col1}, c1 = {col2, col3}.
    GI_FLOAT32_V2_t c0 = GiZipqFloat32(
            GiGetSubVectorFloat32V2(b0, 0), GiGetSubVectorFloat32V2(b1, 0));
    GI_FLOAT32_V2_t c1 = GiZipqFloat32(
            GiGetSubVectorFloat32V2(b0, 1), GiGetSubVectorFloat32V2(b1, 1));
    // Column c of the source becomes row c of the destination.
    GiStoreFloat32(dst + 0 * ldb, GiGetSubVectorFloat32V2(c0, 0));
    GiStoreFloat32(dst + 1 * ldb, GiGetSubVectorFloat32V2(c0, 1));
    GiStoreFloat32(dst + 2 * ldb, GiGetSubVectorFloat32V2(c1, 0));
    GiStoreFloat32(dst + 3 * ldb, GiGetSubVectorFloat32V2(c1, 1));
}
}  // namespace fallback
}  // namespace megdnn

// Accumulating 4x4 multiply on row vectors. For each i in 0..3, sum<i> is
// updated with b0..b3, each scaled by one lane (0..3) of a<i>.
//
// Assuming GiMlaq{Low,High}LaneFloat32(acc, v, w, k) computes acc + v * w[k]
// (low half of w for k = 0, 1; high half for k = 2, 3 -- confirm in
// gi_float.h), this gives
//   sum<i> += a<i>[0]*b0 + a<i>[1]*b1 + a<i>[2]*b2 + a<i>[3]*b3,
// i.e. row i of A*B, where A has rows a0..a3 and B has rows b0..b3.
//
// sum, a and b are token prefixes pasted with 0..3, so the caller needs
// variables <sum>0..3, <a>0..3 and <b>0..3 in scope.
// NOTE: this expands to 16 bare statements (no do/while wrapper). Do not use
// it as the unbraced body of an if/for/while.
#define MATRIX_MUL4x4(sum, a, b)                           \
    sum##0 = GiMlaqLowLaneFloat32(sum##0, b##0, a##0, 0);  \
    sum##0 = GiMlaqLowLaneFloat32(sum##0, b##1, a##0, 1);  \
    sum##0 = GiMlaqHighLaneFloat32(sum##0, b##2, a##0, 2); \
    sum##0 = GiMlaqHighLaneFloat32(sum##0, b##3, a##0, 3); \
    sum##1 = GiMlaqLowLaneFloat32(sum##1, b##0, a##1, 0);  \
    sum##1 = GiMlaqLowLaneFloat32(sum##1, b##1, a##1, 1);  \
    sum##1 = GiMlaqHighLaneFloat32(sum##1, b##2, a##1, 2); \
    sum##1 = GiMlaqHighLaneFloat32(sum##1, b##3, a##1, 3); \
    sum##2 = GiMlaqLowLaneFloat32(sum##2, b##0, a##2, 0);  \
    sum##2 = GiMlaqLowLaneFloat32(sum##2, b##1, a##2, 1);  \
    sum##2 = GiMlaqHighLaneFloat32(sum##2, b##2, a##2, 2); \
    sum##2 = GiMlaqHighLaneFloat32(sum##2, b##3, a##2, 3); \
    sum##3 = GiMlaqLowLaneFloat32(sum##3, b##0, a##3, 0);  \
    sum##3 = GiMlaqLowLaneFloat32(sum##3, b##1, a##3, 1);  \
    sum##3 = GiMlaqHighLaneFloat32(sum##3, b##2, a##3, 2); \
    sum##3 = GiMlaqHighLaneFloat32(sum##3, b##3, a##3, 3);

// Token-paste a and idx: CONCAT(x, 0) -> x0. Because ## is applied directly
// to the parameters, neither argument is macro-expanded before pasting.
#define CONCAT(a, idx) a##idx

#if MEGDNN_AARCH64
// Transpose an 8x8 float block.
// Row i of the input is CONCAT(a, i): .val[0] holds columns 0-3 and .val[1]
// holds columns 4-7. After expansion, CONCAT(ret, k) (k = 0..7, declared by
// the caller) holds column k in the same layout: .val[0] = rows 0-3 and
// .val[1] = rows 4-7.
// Assumes GiZipqFloat32 / GiZip1qS64 / GiZip2qS64 behave like NEON
// vzipq_f32 / vzip1q_s64 / vzip2q_s64.
#define TRANSPOSE_8x8(a, ret)                                              \
    do {                                                                   \
        auto b0 = GiZipqFloat32(CONCAT(a, 0).val[0], CONCAT(a, 1).val[0]); \
        auto b1 = GiZipqFloat32(CONCAT(a, 0).val[1], CONCAT(a, 1).val[1]); \
        auto b2 = GiZipqFloat32(CONCAT(a, 2).val[0], CONCAT(a, 3).val[0]); \
        auto b3 = GiZipqFloat32(CONCAT(a, 2).val[1], CONCAT(a, 3).val[1]); \
        auto b4 = GiZipqFloat32(CONCAT(a, 4).val[0], CONCAT(a, 5).val[0]); \
        auto b5 = GiZipqFloat32(CONCAT(a, 4).val[1], CONCAT(a, 5).val[1]); \
        auto b6 = GiZipqFloat32(CONCAT(a, 6).val[0], CONCAT(a, 7).val[0]); \
        auto b7 = GiZipqFloat32(CONCAT(a, 6).val[1], CONCAT(a, 7).val[1]); \
        CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b0.val[0]),                     \
                GiReinterpretqFloat32ToS64(b2.val[0])));                   \
        CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b4.val[0]),                     \
                GiReinterpretqFloat32ToS64(b6.val[0])));                   \
        CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b0.val[0]),                     \
                GiReinterpretqFloat32ToS64(b2.val[0])));                   \
        CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b4.val[0]),                     \
                GiReinterpretqFloat32ToS64(b6.val[0])));                   \
        CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b0.val[1]),                     \
                GiReinterpretqFloat32ToS64(b2.val[1])));                   \
        CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b4.val[1]),                     \
                GiReinterpretqFloat32ToS64(b6.val[1])));                   \
        CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b0.val[1]),                     \
                GiReinterpretqFloat32ToS64(b2.val[1])));                   \
        CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b4.val[1]),                     \
                GiReinterpretqFloat32ToS64(b6.val[1])));                   \
        CONCAT(ret, 4).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b1.val[0]),                     \
                GiReinterpretqFloat32ToS64(b3.val[0])));                   \
        CONCAT(ret, 4).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b5.val[0]),                     \
                GiReinterpretqFloat32ToS64(b7.val[0])));                   \
        CONCAT(ret, 5).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b1.val[0]),                     \
                GiReinterpretqFloat32ToS64(b3.val[0])));                   \
        CONCAT(ret, 5).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b5.val[0]),                     \
                GiReinterpretqFloat32ToS64(b7.val[0])));                   \
        CONCAT(ret, 6).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b1.val[1]),                     \
                GiReinterpretqFloat32ToS64(b3.val[1])));                   \
        CONCAT(ret, 6).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64(     \
                GiReinterpretqFloat32ToS64(b5.val[1]),                     \
                GiReinterpretqFloat32ToS64(b7.val[1])));                   \
        CONCAT(ret, 7).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b1.val[1]),                     \
                GiReinterpretqFloat32ToS64(b3.val[1])));                   \
        CONCAT(ret, 7).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64(     \
                GiReinterpretqFloat32ToS64(b5.val[1]),                     \
                GiReinterpretqFloat32ToS64(b7.val[1])));                   \
    } while (0);

// Transpose a 6x6 float block held as twelve 4-lane vectors.
//
// Input: row r (r = 0..5) is CONCAT(a, r0), holding columns 0-3, plus lanes
// 0-1 of CONCAT(a, r1), holding columns 4-5. Lanes 2-3 of the a<r>1 vectors
// are never read.
//
// Output: column c of the input is written as CONCAT(ret, c0) (rows 0-3) plus
// lanes 0-1 of CONCAT(ret, c1) (rows 4-5). Lanes 2-3 of the ret<c>1 vectors
// hold leftover shuffle values, not zeros.
//
// Assumes GiZipqFloat32 / GiZip1qS64 / GiZip2qS64 follow NEON vzipq_f32 /
// vzip1q_s64 / vzip2q_s64 semantics.
// The expansion ends in "while (0);", so it cannot be the unbraced if-branch
// that precedes an else.
#define TRANSPOSE_6x6(a, ret)                                    \
    do {                                                         \
        auto b0 = GiZipqFloat32(CONCAT(a, 00), CONCAT(a, 10));   \
        auto b1 = GiZipqFloat32(CONCAT(a, 01), CONCAT(a, 11));   \
        auto b2 = GiZipqFloat32(CONCAT(a, 20), CONCAT(a, 30));   \
        auto b3 = GiZipqFloat32(CONCAT(a, 21), CONCAT(a, 31));   \
        auto b4 = GiZipqFloat32(CONCAT(a, 40), CONCAT(a, 50));   \
        auto b5 = GiZipqFloat32(CONCAT(a, 41), CONCAT(a, 51));   \
        CONCAT(ret, 00) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(b0.val[0]),           \
                GiReinterpretqFloat32ToS64(b2.val[0])));         \
        CONCAT(ret, 01) = b4.val[0];                             \
        CONCAT(ret, 10) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b0.val[0]),           \
                GiReinterpretqFloat32ToS64(b2.val[0])));         \
        CONCAT(ret, 11) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b4.val[0]),           \
                GiReinterpretqFloat32ToS64(b5.val[0])));         \
        CONCAT(ret, 20) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(b0.val[1]),           \
                GiReinterpretqFloat32ToS64(b2.val[1])));         \
        CONCAT(ret, 21) = b4.val[1];                             \
        CONCAT(ret, 30) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b0.val[1]),           \
                GiReinterpretqFloat32ToS64(b2.val[1])));         \
        CONCAT(ret, 31) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b4.val[1]),           \
                GiReinterpretqFloat32ToS64(b5.val[1])));         \
        CONCAT(ret, 40) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(b1.val[0]),           \
                GiReinterpretqFloat32ToS64(b3.val[0])));         \
        CONCAT(ret, 41) = b5.val[0];                             \
        CONCAT(ret, 50) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b1.val[0]),           \
                GiReinterpretqFloat32ToS64(b3.val[0])));         \
        CONCAT(ret, 51) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(b5.val[0]),           \
                GiReinterpretqFloat32ToS64(b4.val[0])));         \
    } while (0);

// Transpose the first three columns of an 8x4 float block.
// Input rows are CONCAT(a, 0) .. CONCAT(a, 7), four lanes each. Column 3 of
// the input is not used. CONCAT(ret, k) (k = 0..2, declared by the caller)
// receives column k as a two-vector value: .val[0] = rows 0-3 and
// .val[1] = rows 4-7.
// Assumes GiZipqFloat32 / GiZip1qS64 / GiZip2qS64 follow NEON vzipq_f32 /
// vzip1q_s64 / vzip2q_s64 semantics.
// The do/while scope keeps the zipped temporaries local, so the macro can be
// expanded more than once (or next to TRANSPOSE_8x4) in the same scope.
#define TRANSPOSE_8x3(a, ret)                                          \
    do {                                                               \
        auto zip01 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1));        \
        auto zip23 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3));        \
        auto zip45 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5));        \
        auto zip67 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7));        \
        CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[0]),              \
                GiReinterpretqFloat32ToS64(zip23.val[0])));            \
        CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[0]),              \
                GiReinterpretqFloat32ToS64(zip67.val[0])));            \
        CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[0]),              \
                GiReinterpretqFloat32ToS64(zip23.val[0])));            \
        CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[0]),              \
                GiReinterpretqFloat32ToS64(zip67.val[0])));            \
        CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[1]),              \
                GiReinterpretqFloat32ToS64(zip23.val[1])));            \
        CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[1]),              \
                GiReinterpretqFloat32ToS64(zip67.val[1])));            \
    } while (0);

// Transpose an 8x4 float block into four 8-lane columns.
// Input rows are CONCAT(a, 0) .. CONCAT(a, 7), four lanes each.
// CONCAT(ret, k) (k = 0..3, declared by the caller) receives column k as a
// two-vector value: .val[0] = rows 0-3 and .val[1] = rows 4-7.
// Assumes GiZipqFloat32 / GiZip1qS64 / GiZip2qS64 follow NEON vzipq_f32 /
// vzip1q_s64 / vzip2q_s64 semantics.
// The do/while scope keeps the zipped temporaries local, so the macro can be
// expanded more than once (or next to TRANSPOSE_8x3) in the same scope.
#define TRANSPOSE_8x4(a, ret)                                          \
    do {                                                               \
        auto zip01 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1));        \
        auto zip23 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3));        \
        auto zip45 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5));        \
        auto zip67 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7));        \
        CONCAT(ret, 0).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[0]),              \
                GiReinterpretqFloat32ToS64(zip23.val[0])));            \
        CONCAT(ret, 0).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[0]),              \
                GiReinterpretqFloat32ToS64(zip67.val[0])));            \
        CONCAT(ret, 1).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[0]),              \
                GiReinterpretqFloat32ToS64(zip23.val[0])));            \
        CONCAT(ret, 1).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[0]),              \
                GiReinterpretqFloat32ToS64(zip67.val[0])));            \
        CONCAT(ret, 2).val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[1]),              \
                GiReinterpretqFloat32ToS64(zip23.val[1])));            \
        CONCAT(ret, 2).val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[1]),              \
                GiReinterpretqFloat32ToS64(zip67.val[1])));            \
        CONCAT(ret, 3).val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip01.val[1]),              \
                GiReinterpretqFloat32ToS64(zip23.val[1])));            \
        CONCAT(ret, 3).val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
                GiReinterpretqFloat32ToS64(zip45.val[1]),              \
                GiReinterpretqFloat32ToS64(zip67.val[1])));            \
    } while (0);

#else
// Transpose an 8x4 float block into four 8-lane columns. This is the
// non-AArch64 variant, written with GI sub-vector accessors instead of
// direct .val[] access.
// Input rows are CONCAT(a, 0) .. CONCAT(a, 7), four lanes each.
// CONCAT(ret, k) (k = 0..3, declared by the caller) receives column k:
// sub-vector 0 = rows 0-3 and sub-vector 1 = rows 4-7.
// Assumes GiZipqFloat32 interleaves like NEON vzipq_f32 and that
// GiCombineFloat32(lo, hi) concatenates two 2-lane halves.
// The do/while scope keeps the zipped temporaries local, so the macro can be
// expanded more than once in the same scope.
#define TRANSPOSE_8x4(a, ret)                                          \
    do {                                                               \
        auto zip01 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1));        \
        auto zip23 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3));        \
        auto zip45 = GiZipqFloat32(CONCAT(a, 4), CONCAT(a, 5));        \
        auto zip67 = GiZipqFloat32(CONCAT(a, 6), CONCAT(a, 7));        \
        GiSetSubVectorFloat32V2(CONCAT(ret, 0), 0, GiCombineFloat32(   \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip01, 0)),    \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip23, 0))));  \
        GiSetSubVectorFloat32V2(CONCAT(ret, 1), 0, GiCombineFloat32(   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip01, 0)),   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip23, 0)))); \
        GiSetSubVectorFloat32V2(CONCAT(ret, 2), 0, GiCombineFloat32(   \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip01, 1)),    \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip23, 1))));  \
        GiSetSubVectorFloat32V2(CONCAT(ret, 3), 0, GiCombineFloat32(   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip01, 1)),   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip23, 1)))); \
        GiSetSubVectorFloat32V2(CONCAT(ret, 0), 1, GiCombineFloat32(   \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip45, 0)),    \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip67, 0))));  \
        GiSetSubVectorFloat32V2(CONCAT(ret, 1), 1, GiCombineFloat32(   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip45, 0)),   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip67, 0)))); \
        GiSetSubVectorFloat32V2(CONCAT(ret, 2), 1, GiCombineFloat32(   \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip45, 1)),    \
                GiGetLowFloat32(GiGetSubVectorFloat32V2(zip67, 1))));  \
        GiSetSubVectorFloat32V2(CONCAT(ret, 3), 1, GiCombineFloat32(   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip45, 1)),   \
                GiGetHighFloat32(GiGetSubVectorFloat32V2(zip67, 1)))); \
    } while (0);

#endif
// vim: syntax=cpp.doxygen