algos.h 7.5 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/matrix_mul/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include "src/cuda/matrix_mul/opr_impl.h"
17 18
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
19

20
#include <unordered_map>
21
#include <cuda.h>
22
#include <memory>
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for matrix mul algos
 *
 * Subclasses implement is_available(), get_workspace_in_bytes() and exec();
 * the non-virtual helpers in this class are built on top of those three.
 */
class MatrixMulForwardImpl::AlgoBase : public Algorithm {
protected:
    //! protected: code outside the class hierarchy cannot destroy an algo
    //! through AlgoBase
    ~AlgoBase() = default;

public:
    //! tag for each concrete algorithm; the derived classes in this header
    //! pass one of these enumerators to MEGDNN_DECL_ALGO_TYPE
    enum class AlgoType : uint32_t {
        CUDA_CUBLAS,
        CUDA_WMMA_UINT4X4X32,
        CUDA_CUBLASLT,
        CUDA_NAIVE,
        CUDA_BFLOAT16,
        CUDA_FLOAT32_SIMT,
    };
    //! lookup table from an algorithm descriptor to the algorithm object
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    //! tags the algorithm with the CUDA handle type
    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }

    //! the operator together with the layouts of A, B and C
    struct SizeArgs {
        MatrixMulForwardImpl* opr;
        TensorLayout layout_a, layout_b, layout_c;

        std::string to_string() const;
        SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A,
                 const TensorLayout& B, const TensorLayout& C);

        /*!
         * \brief dtype/format test for the int8 x int8 -> int32 case
         *
         * \return true iff A and B have the same dtype enum, that dtype is
         *      Int8 or QuantizedS8, C is Int32 or QuantizedS32, and the
         *      operator's format is DEFAULT
         */
        bool can_be_treated_as_int8x8x32() const {
            return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
                   (layout_a.dtype.enumv() == DTypeEnum::Int8 ||
                    layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) &&
                   (layout_c.dtype.enumv() == DTypeEnum::Int32 ||
                    layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) &&
                   opr->param().format == param::MatrixMul::Format::DEFAULT;
        }
    };

    //! SizeArgs extended with the three tensors and the workspace
    struct ExecArgs : public SizeArgs {
        TensorND tensor_a, tensor_b, tensor_c;
        Workspace workspace;

        ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A,
                 _megdnn_tensor_in B, _megdnn_tensor_out C,
                 _megdnn_workspace workspace);
    };

    //! implemented by each algorithm; combined with the workspace limit in
    //! is_available_wk()
    virtual bool is_available(const SizeArgs& args) const = 0;
    //! implemented by each algorithm; the value is compared against the
    //! workspace limit / actual workspace size by the helpers below
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    //! implemented by each algorithm
    virtual void exec(const ExecArgs& args) const = 0;

    //! is_available() and the required workspace does not exceed \p limit
    bool is_available_wk(const SizeArgs& args, size_t limit) const {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    /*!
     * \brief is_available_wk() with an optional reproducibility requirement
     *
     * \param reproducible when true the algorithm must also report
     *      is_reproducible(); when false that check is skipped
     * \param limit workspace limit in bytes; defaults to the largest size_t,
     *      i.e. no effective limit
     */
    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) const {
        return (!reproducible || is_reproducible()) &&
               is_available_wk(args, limit);
    }

    /*!
     * \brief assert (via megdnn_assert) that \p workspace is at least as
     *      large as get_workspace_in_bytes(args)
     *
     * \return *this, so that a further call can be chained
     */
    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "matrix mul fwd algo %s: required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }
};

/*!
 * \brief matrix mul algorithm reported under the name "CUBLAS"
 *
 * Asks for no workspace and reports itself as reproducible; is_available()
 * and exec() are defined out of line.
 */
class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase {
public:
    AlgoCuBlas() = default;
    bool is_available(const SizeArgs& args) const override;
    //! always 0: this algorithm asks for no workspace
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0_z;
    }
    const char* name() const override { return "CUBLAS"; }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
};

#if CUDA_VERSION >= 10000
/*!
 * \brief matrix mul algorithm reported under the name "UINT4x4x32_WMMA"
 *
 * Only declared when CUDA_VERSION >= 10000 (see the enclosing #if). Reports
 * itself as reproducible; availability, workspace size and exec() are defined
 * out of line.
 */
class MatrixMulForwardImpl::AlgoUInt4x4x32WMMA final : public AlgoBase {
public:
    AlgoUInt4x4x32WMMA() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return "UINT4x4x32_WMMA"; }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
};
#endif
#if CUDA_VERSION >= 10010
/*!
 * \brief matrix mul algorithm reported under the name "CUBLAS_LT"
 *
 * Only declared when CUDA_VERSION >= 10010 (see the enclosing #if), the same
 * condition under which this header includes <cublasLt.h>. Reports itself as
 * reproducible; availability, workspace size and exec() are defined out of
 * line.
 */
class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return "CUBLAS_LT"; }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
};
#endif

/*!
 * \brief matrix mul algorithm reported under the name "NAIVE"
 *
 * Asks for no workspace and reports itself as reproducible; is_available()
 * and exec() are defined out of line.
 */
class MatrixMulForwardImpl::AlgoNaive final : public AlgoBase {
public:
    AlgoNaive() = default;
    bool is_available(const SizeArgs& args) const override;
    //! always 0: this algorithm asks for no workspace
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0_z;
    }
    const char* name() const override { return "NAIVE"; }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
};

149 150 151 152 153 154
#if !MEGDNN_DISABLE_FLOAT16
/*!
 * \brief matrix mul algorithm reported under the name "MATMUL_BFLOAT16"
 *
 * Only declared when MEGDNN_DISABLE_FLOAT16 is false (see the enclosing #if).
 * Reports itself as reproducible. It additionally overrides
 * get_subopr_list(); NOTE(review): this suggests the work is delegated to
 * another operator -- confirm against the implementation file.
 */
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override { return "MATMUL_BFLOAT16"; }
    bool is_reproducible() const override { return true; }

private:
    //! builds a WorkspaceBundle from a raw pointer and the size arguments;
    //! defined out of line
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
#endif

169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
/*!
 * \brief matrix mul algorithm parameterised by an AlgoParam tile
 *      configuration
 *
 * Each instance is named "CUTLASS_FLOAT32_SIMT_" followed by
 * AlgoParam::to_string(). Reports itself as reproducible; availability,
 * workspace size and exec() are defined out of line.
 */
class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
public:
    //! six integer tile sizes: threadblock m/n/k and warp m/n/k
    struct AlgoParam {
        int threadblock_m, threadblock_n, threadblock_k;
        int warp_m, warp_n, warp_k;
        //! formats the six fields as "<tb_m>X<tb_n>X<tb_k>_<w_m>X<w_n>X<w_k>";
        //! const because it only reads the fields, so it is callable on a
        //! const AlgoParam as well
        std::string to_string() const {
            return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
                            threadblock_k, warp_m, warp_n, warp_k);
        }
    };

    //! stores \p algo_param and derives the algorithm name from it;
    //! m_algo_param is declared before m_name, so it is already initialised
    //! when the name is formatted
    AlgoFloat32SIMT(AlgoParam algo_param)
            : m_algo_param{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    //! pointer into m_name; valid for as long as this object lives
    const char* name() const override { return m_name.c_str(); }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)

    //! returns m_algo_param written into a string by serialize_write_pod
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    AlgoParam m_algo_param;
    std::string m_name;
};

201 202 203
/*!
 * \brief holds one instance of each matrix mul algorithm declared in this
 *      header, together with the containers used to enumerate them and to
 *      look them up by descriptor
 *
 * Derives from NonCopyableObj. Which members exist depends on CUDA_VERSION
 * and MEGDNN_DISABLE_FLOAT16, mirroring the guards around the corresponding
 * class declarations.
 */
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();
    AlgoCuBlas cublas;
    AlgoNaive naive;
#if CUDA_VERSION >= 10000
    AlgoUInt4x4x32WMMA wmma_uint4x4x32;
#endif
#if CUDA_VERSION >= 10010
    AlgoCuBlasLt cublas_lt;
#endif
#if !MEGDNN_DISABLE_FLOAT16
    AlgoBFloat16 bfloat16;
#endif
    //! NOTE(review): if all_algos / m_all_algos_map keep pointers to elements
    //! of this vector, it must be fully populated before those pointers are
    //! taken (reallocation would invalidate them) -- confirm in the
    //! constructor / fill_cutlass_algos()
    std::vector<AlgoFloat32SIMT> simt_float32;
    std::vector<AlgoBase*> all_algos;

    //! read-only access to the descriptor -> algorithm map
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
    //! defined out of line; presumably populates simt_float32 -- verify
    //! against the implementation file
    void fill_cutlass_algos();
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen