algos.cpp 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/**
 * \file dnn/src/fallback/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/fallback/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
14
#include "src/fallback/matrix_mul/gemv.h"
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "midout.h"

//! midout tag used by the F32 8x12x1 kernel, its workspace query and its
//! im2col registration macro in this file.
MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
//! midout tag used by every region inside AlgoGemv::get_kern. Despite the
//! "f32" in its name it also wraps the non-float gemv-like kernels there.
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)

using namespace megdnn;
using namespace fallback;

/* ===================== F32 8x12x1 algo ===================== */

namespace {
/*!
 * \brief Kernel entry point of the F32 8x12x1 algorithm.
 *
 * Constructs an sgemm_8x12 strategy for the (M, N, K) problem described by
 * \p kern_param, wraps it in a GemmInterleaved driver and runs it on the
 * float views of A, B and C using the workspace carried by \p kern_param.
 *
 * \param kern_param problem sizes, transpose flags, dtypes, operand
 *      pointers, leading dimensions and the workspace pointer.
 */
void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern, void) {
        using Strategy = matmul::fallback::sgemm_8x12;

        const size_t rows = kern_param.M;
        const size_t cols = kern_param.N;
        const size_t depth = kern_param.K;

        Strategy strategy(rows, cols, depth, kern_param.A_type,
                          kern_param.B_type, kern_param.C_type);
        matmul::GemmInterleaved<Strategy> gemm(
                rows, cols, depth, kern_param.trA, kern_param.trB, strategy);
        gemm.execute(kern_param.A<float>(), kern_param.LDA,
                     kern_param.B<float>(), kern_param.LDB,
                     kern_param.C<float>(), kern_param.LDC,
                     kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

bool MatrixMulImpl::AlgoF32K8x12x1::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.compute_mode ==
                   param::MatrixMul::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.B_type == kern_size_param.A_type &&
           kern_size_param.C_type == kern_size_param.A_type &&
           kern_size_param.A_type == dtype::Float32{};
}

size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
56 57 58 59 60 61 62 63 64 65 66 67 68 69
    MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
                 midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
                                              kern_size_param.B_type,
                                              kern_size_param.C_type);
        return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
                       M, N, K, kern_size_param.trA, kern_size_param.trB,
                       strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    return 0;
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
        const KernSizeParam&) const {
    return f32_8x12x1_kern;
}

// Instantiates the im2col-related GEMM members of AlgoF32K8x12x1 through a
// helper macro defined outside this file, using the sgemm_8x12 strategy with
// float input and float output types and the megdnn_fb_matmul_f32_kern midout
// tag.
// NOTE(review): the literal 5 is presumably a midout region id for the
// generated code - confirm against the macro definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
                                     5, matmul::fallback::sgemm_8x12, float,
                                     float);

/* ===================== gemv algo ===================== */
/*!
 * \brief Whether the gemv algorithm can handle the given problem.
 *
 * Requires non-transposed A and B and the default matrix format, and rejects
 * the one dtype combination Int16 x Int16 -> Int32.
 *
 * \param kern_size_param problem description to test.
 * \return true iff the problem is accepted.
 */
bool MatrixMulImpl::AlgoGemv::usable(
        const KernSizeParam& kern_size_param) const {
    if (kern_size_param.trA || kern_size_param.trB) {
        return false;
    }
    if (kern_size_param.format != param::MatrixMul::Format::DEFAULT) {
        return false;
    }

    const auto src_enum = kern_size_param.A_type.enumv();
    const bool is_int16x16x32 =
            src_enum == kern_size_param.B_type.enumv() &&
            src_enum == DTypeEnum::Int16 &&
            kern_size_param.C_type.enumv() == DTypeEnum::Int32;
    return !is_int16x16x32;
}

/*!
 * \brief Whether the gemv algorithm should be favoured for this problem.
 *
 * \param kern_size_param problem description; M and the category of A_type
 *      are read.
 * \return true iff M is at most 2 and A is not of a floating-point category.
 */
bool MatrixMulImpl::AlgoGemv::preferred(
        const KernSizeParam& kern_size_param) const {
    if (kern_size_param.M > 2) {
        return false;
    }
    const bool is_float_input =
            kern_size_param.A_type.category() == DTypeCategory::FLOAT;
    return !is_float_input;
}

/*!
 * \brief Select the gemm_gemv_like instantiation matching the problem dtypes.
 *
 * The candidates are tried in a fixed order and the first match is returned
 * from inside its own midout region (ids 0..4). If nothing matches,
 * megdnn_assert(0) is hit.
 *
 * \param kern_size_param problem description; the A/B/C dtype enums, the
 *      compute mode and the format are inspected.
 * \return the selected kernel function.
 */
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern(
        const KernSizeParam& kern_size_param) const {
// DISPATCH(A, C, func, _midout_iv): return `func` when A_type and B_type both
// have enum DTypeEnum::A, C_type has enum DTypeEnum::C, and compute mode and
// format are both DEFAULT. The return sits inside a midout region of the
// megdnn_fb_matmul_f32_gemm_gemv_like tag identified by `_midout_iv`.
#define DISPATCH(A, C, func, _midout_iv)                               \
    if (kern_size_param.A_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.B_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.C_type.enumv() == DTypeEnum::C &&              \
        kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && \
        kern_size_param.format == param::MatrixMul::Format::DEFAULT) { \
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like,              \
                     midout_iv(_midout_iv)) {                          \
            return func;                                               \
        }                                                              \
        MIDOUT_END();                                                  \
    }

    // Float32 x Float32 -> Float32
    DISPATCH(Float32, Float32, (gemm_gemv_like<dt_float32, dt_float32>), 0);
    // Float16 x Float16 -> Float16; presumably only compiled when float16
    // support is enabled (MEGDNN_INC_FLOAT16 is defined elsewhere) - confirm.
    MEGDNN_INC_FLOAT16(DISPATCH(Float16, Float16,
                                (gemm_gemv_like<dt_float16, dt_float16>), 1));
    // Int8 x Int8 -> Int16
    DISPATCH(Int8, Int16, (gemm_gemv_like<dt_int8, dt_int16>), 2);
    // Quantized8Asymm x Quantized8Asymm -> QuantizedS32.
    // NOTE(review): the third template argument `true` presumably enables
    // asymmetric (zero-point) handling - confirm against gemv.h.
    DISPATCH(Quantized8Asymm, QuantizedS32,
             (gemm_gemv_like<dt_uint8, dt_int32, true>), 3);
    // Remaining case: anything the helper says can be computed as
    // int8 x int8 -> int32 (helper defined outside this file).
    if (can_be_treated_as_int8x8x32(kern_size_param)) {
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like, midout_iv(4)) {
            return gemm_gemv_like<dt_int8, dt_int32>;
        }
        MIDOUT_END();
    }
#undef DISPATCH
    // No candidate matched.
    // NOTE(review): there is no return after this assertion, so the function
    // relies on megdnn_assert(0) not returning - confirm.
    megdnn_assert(0);
}

// vim: syntax=cpp.doxygen