/**
 * \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);

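// Reorder rows [y0, ymax) and depth [k0, kmax) of A into the panel layout the
// 4x16 kernel consumes; the _t variant reads an A stored transposed.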
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
                        int k0, int kmax, bool transpose_A) const {
    if (transpose_A) {
        matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    } else {
        matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    }
}

void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
                        int k0, int kmax, bool transpose_B) const {
    if (transpose_B) {
        matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    } else {
        matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    }
}

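// Walk C in 4-row stripes: full 16-column panels are computed by kern_4x16 and
// any remaining columns are handled 4 at a time by kern_4x4. packA advances
// K*4 floats per stripe, packB K*16 (or K*4) floats per panel.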
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
                      size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
                      const float*, float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);

    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 16;
    const int K16 = K * 16;
    const int K4 = K * 4;

    size_t m = 0;
    for (; m < M; m += A_INTERLEAVE) {
        float* output = C + (m * LDC);

        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC,
                                           is_first_k,
                                           std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K16;
        }

        for (; n < N; n += 4) {
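            // Fewer than 16 columns remain: use the 4x4 kernel with explicit
            // row/column bounds.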
            matmul_general_4x16::kern_4x4(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }

        packA += K4;
    }
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);

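// Pack A for the 8x12 kernel; as above, the _t variant handles a transposed A.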
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
                        int k0, int kmax, bool transpose_A) const {
    if (transpose_A) {
        matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    } else {
        matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0,
                                                 kmax);
    }
}

void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
                        int k0, int kmax, bool transpose_B) const {
    if (transpose_B) {
        matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    } else {
        matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0,
                                                 kmax);
    }
}

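// Shared blocking loop, templated on the kernel implementation (generic,
// Cortex-A53 tuned or Cortex-A55 tuned). Full 8-row stripes go through the
// 8x12 and 8x4 kernels; leftover rows are processed 4 at a time below.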
template <typename gemm_class>
static inline void sgemm_8x12_helper(const float* packA, const float* packB,
                                     size_t M, size_t N, size_t K, float* C,
                                     size_t LDC, bool is_first_k) {
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t A_INTERLEAVE4 = 4;
    constexpr size_t B_INTERLEAVE = 12;
    const int K12 = K * 12;
    const int K8 = K * 8;
    const int K4 = K * 4;

    size_t m = 0;
    for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
        float* output = C + (m * LDC);

        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
            gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
            output += B_INTERLEAVE;
            cur_packB += K12;
        }

        for (; n < N; n += 4) {
            gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                 std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K8;
    }
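    // Leftover rows (fewer than 8 remain): process them in 4-row blocks with
    // kern_4x12 / kern_4x4.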
    for (; m < M; m += A_INTERLEAVE4) {
        float* output = C + (m * LDC);
        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k,
                                  std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K12;
        }

        for (; n < N; n += 4) {
            gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                 std::min<size_t>(M - m, 4),
                                 std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }
        packA += K4;
    }
}

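// Pick a microarchitecture-tuned kernel at runtime via cpuinfo; fall back to
// the generic 8x12 kernel when cpuinfo is disabled or when running in a TEE.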
void sgemm_8x12::kern(const float* packA, const float* packB, size_t M,
                      size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
                      const float*, float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
#if !MGB_ENABLE_CPUINFO
    sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
                                           is_first_k);
#else
    auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
    arch = cpuinfo_uarch_unknown;
#endif
    if (arch == cpuinfo_uarch_cortex_a53) {
        sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else if (arch == cpuinfo_uarch_cortex_a55) {
        sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else {
        sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
                                               is_first_k);
    }
#endif
}

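// MK4 strategy: operands use MegEngine's MK4 blocked layout (C is written in
// packs of PACK_C_SIZE = 4 rows), and transposed A/B inputs are not supported.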
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);

void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
                            int ymax, int k0, int kmax,
                            bool transpose_A) const {
    megdnn_assert(!transpose_A,
                  "mk4 float matmul does not support transpose A");
    matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}

void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
                            int xmax, int k0, int kmax,
                            bool transpose_B) const {
    megdnn_assert(!transpose_B,
                  "mk4 float matmul does not support transpose B");
    matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

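// Same blocking scheme as sgemm_8x12_helper, except that C is stored in MK4
// form: the output pointer is C + (m / PACK_C_SIZE) * LDC and advances
// PACK_C_SIZE floats per output column.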
template <typename gemm_name>
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
                                         size_t M, size_t N, size_t K, float* C,
                                         size_t LDC, bool is_first_k) {
    const int K12 = K * 12;
    const int K8 = K * 8;
    const int K4 = K * 4;
    constexpr size_t PACK_C_SIZE = 4;
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t A_INTERLEAVE4 = 4;
    constexpr size_t B_INTERLEAVE = 12;
    size_t m = 0;
    for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
        float* output = C + (m / PACK_C_SIZE * LDC);

        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
            gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
            output += B_INTERLEAVE * PACK_C_SIZE;
            cur_packB += K12;
        }

        for (; n < N; n += 4) {
            gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
                                std::min<size_t>(N - n, 4));
            output += 4 * PACK_C_SIZE;
            cur_packB += K4;
        }
        packA += K8;
    }
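    // Leftover rows: use the 4-row MK4 kernels.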
    for (; m < M; m += A_INTERLEAVE4) {
        float* output = C + (m / PACK_C_SIZE * LDC);
        size_t n = 0;
        const float* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
            output += B_INTERLEAVE * PACK_C_SIZE;
            cur_packB += K12;
        }
        for (; n < N; n += 4) {
            gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
                                std::min<size_t>(N - n, 4));
            output += 4 * PACK_C_SIZE;
            cur_packB += K4;
        }
        packA += K4;
    }
}
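
// Runtime dispatch mirroring sgemm_8x12::kern: use the A53/A55-tuned MK4
// kernels when cpuinfo identifies those cores, otherwise the generic kernel.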
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
                          size_t N, size_t K, float* C, size_t LDC,
                          bool is_first_k, const float*, float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
#if !MGB_ENABLE_CPUINFO
    sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
                                           is_first_k);
#else
    auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
    arch = cpuinfo_uarch_unknown;
#endif
    if (arch == cpuinfo_uarch_cortex_a53) {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else if (arch == cpuinfo_uarch_cortex_a55) {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
                                               is_first_k);
    }
#endif
}

// vim: syntax=cpp.doxygen