opr_impl.h 9.6 KB
Newer Older
1
#pragma once
2
#include <unordered_map>
3
#include "megdnn/opr_param_defs.h"
4 5
#include "megdnn/oprs/base.h"
#include "src/common/algo_base.h"
6
#include "src/common/utils.h"
7
#include "src/naive/matrix_mul/opr_impl.h"
8

9 10
namespace megdnn {

11 12 13 14 15 16
//! Pair of (data-type category, MatrixMul format) describing which problems a
//! matmul algorithm handles. MatrixMulImpl::select_algo_type() takes one as its
//! query, and each algorithm's MatmulDescription carries one in `algo_type`.
//! Both members are declared as 32-bit bit-fields.
struct AlgoTypePack {
    detail::AlgoDataType data_type : 32;
    param::MatrixMul::Format format : 32;
};

namespace fallback {
17 18 19
class MatrixMulImpl : public naive::MatrixMulForwardImpl {
public:
    using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
20
    using AlgoDataType = detail::AlgoDataType;
21 22 23

    bool is_thread_safe() const override { return true; }

M
Megvii Engine Team 已提交
24 25
    size_t get_workspace_in_bytes(
            const TensorLayout&, const TensorLayout&, const TensorLayout&) override;
26

M
Megvii Engine Team 已提交
27 28 29
    void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) override;
30 31 32 33 34 35 36 37

    struct KernSizeParam {
        DType A_type, B_type, C_type;
        size_t M, N, K;
        size_t LDA, LDB, LDC;
        bool trA, trB;
        Param::ComputeMode compute_mode;
        Param::Format format;
38 39
        //! get the data type category of the param for select the algo
        AlgoDataType deduce_algo_data_type() const;
40 41 42
    };

    struct KernParam : public KernSizeParam {
43 44 45 46 47
        RefPtr A_ptr;
        RefPtr B_ptr;
        RefPtr C_ptr;
        void* workspace_ptr = nullptr;
        size_t workspace_size = 0;
48 49 50 51

        template <typename T>
        inline const T* A() const {
            // A_type.assert_is_compatible_ctype<T>();
52
            return static_cast<const T*>(A_ptr.get_ptr());
53 54 55 56 57
        }

        template <typename T>
        inline const T* B() const {
            // B_type.assert_is_compatible_ctype<T>();
58
            return static_cast<const T*>(B_ptr.get_ptr());
59 60 61 62 63
        }

        template <typename T>
        inline T* C() const {
            // C_type.assert_is_compatible_ctype<T>();
64
            return static_cast<T*>(C_ptr.get_ptr());
65 66 67 68 69 70 71 72
        }
        template <typename T>
        inline T* workspace() const {
            return static_cast<T*>(workspace_ptr);
        }
    };

    typedef void (*kern_t)(const KernParam&);
M
Megvii Engine Team 已提交
73 74
    typedef void (*kern_naked_t)(
            const KernParam&, const void* a_panel, const void* b_panel);
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    class AlgoBase : public Algorithm {
    protected:
        virtual ~AlgoBase() = default;

        bool can_be_treated_as_int8x8x32(const KernSizeParam& param) const {
            return param.A_type.enumv() == param.B_type.enumv() &&
                   (param.A_type.enumv() == DTypeEnum::Int8 ||
                    param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
                   (param.C_type.enumv() == DTypeEnum::Int32 ||
                    param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
                   param.compute_mode == Param::ComputeMode::DEFAULT &&
                   param.format == param::MatrixMul::Format::DEFAULT;
        }

        bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const {
            return param.A_type.enumv() == param.B_type.enumv() &&
91 92 93 94
                   (param.A_type.enumv() == DTypeEnum::Int8 ||
                    param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
                   (param.C_type.enumv() == DTypeEnum::Int16 ||
                    param.C_type.enumv() == DTypeEnum::QuantizedS16);
95
        }
96

97
    public:
98
        AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; }
99 100 101 102
        enum class AlgoType : uint32_t {
            //! fallback
            FB_F32K8x12x1 = 1 << 0,
            FB_GEMV,
103
            FB_NAIVE,
104 105
            FB_GI_F32_GEMV_MK4,
            FB_GI_F32_MK4_4x8,
106
            FB_GI_F32_MK4_PACK_4x12,
107
            FB_GI_F32_4x12,
108 109 110 111 112 113 114 115 116 117 118

#if MEGDNN_X86
            //! x86
            X86_F32_BLAS = 1 << 8,
            X86_F32_MKL_PACKA,
            X86_INT8X8X32_AVX2_2X4X16,
            X86_INT8X8X32_AVX2_4X16X2,
            X86_INT8X8X16_AVX2,
            X86_INT8X8X16_SSE,
            X86_INT8X8X32_SSE_4X8X2,
            X86_F32_MK8_8X8,
119
            X86_F32_6x16,
120 121 122 123 124 125 126
            X86_INT8X8X32_VNNI,
            X86_INT8X8X32_MKLDNN,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
            ARM_COMMON_INT8X8X16 = 1 << 8,
            ARM_COMMON_INT8X8X32_GEMV,
            ARM_COMMON_INT8X8X32_GEMV_MK4,
            ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
127 128
            ARM_COMMON_INT8X8X32_GEVM_DOT,
            ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT,
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
            ARM_COMMON_F16_GEMV,
            ARM_COMMON_GEVM,
#if MEGDNN_AARCH64
            AARCH64_F32_K8X12X1 = 1 << 16,
            AARCH64_F32_MK4_K8X12X1,
            AARCH64_F32_K4X16X1,
            AARCH64_F32_MK4_4x16,
            AARCH64_F32_GEMV,
            AARCH64_F16_K8X24X1,
            AARCH64_F16_MK8_8X8,
            AARCH64_INT8X8X32_K8X12X4_DOTPROD,
            AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD,
            AARCH64_INT8X8X32_MK4_4X4X16,
            AARCH64_INT8X8X32_K4X4X16,
            AARCH64_INT8X8X32_K8X8X8,
            AARCH64_INT8X8X16_K8X8X8,
            AARCH64_INT8X8X16_K4X4X16,
            AARCH64_INT8X8X16_MK4_16X12X4,
            AARCH64_INT8X8X16_MK4_K8X8X8,
            AARCH64_INT8X8X16_MK4_4X4X8,
            AARCH64_INT16X16X32_K12X8X1,
            AARCH64_INT16X16X32_MK8_8X8,
            AARCH64_QUINT8_K8X8X4_DOTPROD,
            AARCH64_QUINT8_GEMV_DOTPROD,
            AARCH64_QUINT8_K8X8X8,
154
            AARCH64_INT4X4X16_K8X8X8,
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
#else
            ARMV7_F32 = 1 << 16,
            ARMV7_F32_MK4_PACK_4X12,
            ARMV7_F32_MK4_4x8,
            ARMV7_F16_K4X16X1,
            ARMV7_F16_MK8_4X8,
            ARMV7_INT8_K6X8X4,
            ARMV7_QUINT8_K4X8X4,
            ARMV7_INT8_MK4_8X4X4_DOTPROD,
            ARMV7_F32_GEMV,
            ARMV7_INT8X8X32_K4X2X16,
            ARMV7_INT8X8X32_K4X8X8,
            ARMV7_QUINT8_K4X8X8,
            ARMV7_INT8X8X16_K4X2X16,
            ARMV7_INT8X8X16_K4X8X8,
            ARMV7_INT8X8X16_MK4_K8X8X4,
            ARMV7_INT16X16X32_K12X4X1,
            ARMV7_INT16X16X32_MK8_4X8,
173 174
            ARMV7_INT8X8X32_MK4_4X2X16,
            ARMV7_INT8X8X16_K8X8X4
175 176 177 178
#endif
#endif
        };

179
        enum class AlgoSet : uint32_t {
180 181
            ALGO_TYPE_GEMM = 0,
            ALGO_TYPE_GEMV = 1,
182
            ALGO_TYPE_GEVM = 2,
183 184
        };

185
        enum class PackMode : uint32_t {
186 187 188 189 190 191 192 193 194
            DEFAULT = 0,
            NO_PACK = 1,
            ONLY_PACKA = 2,
        };

        struct InnerBlockSize {
            size_t m, n, k;
        };

195 196 197
        struct MatmulDescription {
            PackMode packmode;
            InnerBlockSize innerblocksize;
198
            AlgoTypePack algo_type;
199 200 201
            size_t packa_type_size;
        };

202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
        virtual bool usable(const KernSizeParam&) const = 0;
        virtual bool preferred(const KernSizeParam&) const { return true; }
        virtual size_t get_workspace(const KernSizeParam&) const = 0;
        virtual kern_t get_kern(const KernSizeParam&) const = 0;
        virtual kern_naked_t get_kern_naked(const KernSizeParam&) const {
            megdnn_assert(0);
        };
        virtual AlgoSet algoset() const { return AlgoSet::ALGO_TYPE_GEMM; }
        virtual PackMode packmode() const { return PackMode::DEFAULT; }
        virtual void pack_A(const KernParam&, void*, size_t, size_t) const {
            megdnn_assert(0);
        };
        virtual void pack_B(const KernParam&, void*, size_t, size_t) const {
            megdnn_assert(0);
        };
        virtual WorkspaceBundle get_bundle(const KernSizeParam&) const {
            megdnn_assert(0);
        };
M
Megvii Engine Team 已提交
220
        virtual InnerBlockSize get_inner_block_size() const { megdnn_assert(0); };
221 222
        bool preferred_attribute(
                const KernSizeParam& param,
M
Megvii Engine Team 已提交
223
                const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
224 225 226
                const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
            return contain_attribute_all(positive_attr) &&
                   !contain_attribute_any(negative_attr) && preferred(param);
227
        };
228
        virtual MatmulDescription matmul_description() const = 0;
229 230

        using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
231 232
    };

233
private:
234 235 236 237 238
    class AlgoF32K8x12x1;        // Fallback F32 Kernel 8x12x1
    class AlgoF32GiGemvMK4;      // fallback F32 gi Gemv NCHW44
    class AlgoF32GiMK4_4x8;      // fallback F32 gi Gemm NCHW44
    class AlgoF32GiMK4Pack4x12;  // fallback F32 gi Gemm pack NCHW44
    class AlgoF32Gi4x12;         // fallback F32 gi Gemm
239
    class AlgoGemv;
240
    class AlgoNaive;
241 242 243
    class AlgoPack;
    //! maintain all the algos of in the opr of fallback
    static const AlgoPack& algo_pack();
244 245
    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;

246
public:
247 248 249
    /**
     * \brief get all the algorithm for the opr.
     */
250
    virtual SmallVector<AlgoBase*> get_all_packed_algo();
251

252 253 254 255 256
    /**
     * \brief select algo according to input algo type
     */
    SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type);

257
protected:
M
Megvii Engine Team 已提交
258 259
    KernSizeParam make_kern_size_param(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C);
260

M
Megvii Engine Team 已提交
261 262 263
    KernParam make_kern_param(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace);
264

M
Megvii Engine Team 已提交
265 266 267
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& A, const TensorLayout& B,
            const TensorLayout& C) override;
268

M
Megvii Engine Team 已提交
269 270 271
    std::vector<Algorithm*> get_all_algorithms_safe(
            const TensorLayout& A, const TensorLayout& B,
            const TensorLayout& C) override;
272

273 274 275 276
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) override;
277 278 279 280 281 282
};

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen