/**
 * \file dnn/src/cuda/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./algos.h"
#include <cuda.h>
#include "src/common/algo_base.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

using namespace megdnn;
using namespace cuda;

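// AlgoPack registers every matrix-mul algorithm available in this build
// (cuBLAS, cuBLASLt, WMMA, the CUTLASS kernels, the cuDNN-backed 1x1-conv
// wrappers and the naive fallback) and records each one in m_all_algos_map,
// keyed by its algorithm descriptor, so it can be recovered from a desc later.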
MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    all_algos.push_back(&cublas);
#if CUDA_VERSION >= 10000
    all_algos.push_back(&wmma_uint4x4x32);
#endif
#if CUDA_VERSION >= 10010
    all_algos.push_back(&cublas_lt);
#endif
#if !MEGDNN_DISABLE_FLOAT16
    all_algos.push_back(&bfloat16);
#endif
#if CUDA_VERSION >= 9020
    fill_cutlass_algos();
    for (auto&& algo : simt_float32) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : simt_float32_split_k) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : simt_float32_gemv_batched_strided) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10020
    for (auto&& algo : tensorop_float16) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : tensorop_float16_split_k) {
        all_algos.push_back(&algo);
    }
#endif
#endif

    all_algos.push_back(&naive);

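    // Wrap each cuDNN convolution-forward algorithm as a 1x1-convolution
    // matmul implementation. Pointers into conv1x1 are taken only after the
    // vector has finished growing, so the addresses stored in all_algos
    // stay valid.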
    std::vector<cudnnConvolutionFwdAlgo_t> cudnn_conv_enum;
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_enum.push_back(algo.first);
    }

    for (auto&& algo : cudnn_conv_enum) {
        conv1x1.push_back(AlgoConv1X1CUDNN(algo));
    }
    for (size_t i = 0; i < conv1x1.size(); ++i) {
        all_algos.push_back(&conv1x1[i]);
    }

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
    using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
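    // Each AlgoParam below most likely follows the CUTLASS tile convention:
    // {threadblock_m, threadblock_n, threadblock_k, warp_m, warp_n, warp_k
    // [, instruction_m, instruction_n, instruction_k]}. See the AlgoParam
    // definition in algos.h for the authoritative field order.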
    simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
    simt_float32.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
    simt_float32.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
    simt_float32.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
    simt_float32.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
    simt_float32.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
    simt_float32.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
    simt_float32.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
    simt_float32.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
    simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
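    // Batched strided GEMV variants; the single argument is a tile width
    // (presumably the threadblock size along N, see algos.h).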
    simt_float32_gemv_batched_strided.emplace_back(128);
    simt_float32_gemv_batched_strided.emplace_back(64);
    simt_float32_gemv_batched_strided.emplace_back(32);
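    // Tensor-op float16 shapes: each cb() entry is read as threadblock MxNxK,
    // warp MxNxK, then the mma instruction shape; 8x8x4 and 16x8x8 match the
    // standard CUTLASS instruction shapes for Volta- and Turing-class tensor
    // cores respectively (assumption, not verified against the kernel list).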
#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);    \
    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);    \
    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);   \
    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);   \
    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...)                                            \
    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \
    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
#if CUDA_VERSION >= 10020
    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
#endif
#undef cb
#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
}
#endif

MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;

MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl)

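// SizeArgs bundles the operator and layouts consumed by each algorithm's
// is_available()/get_workspace_in_bytes() checks; ExecArgs additionally
// carries the concrete tensors and the workspace used by exec().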
MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o,
                                                   const TensorLayout& A,
                                                   const TensorLayout& B,
                                                   const TensorLayout& C)
        : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {}

MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr,
                                                   _megdnn_tensor_in A,
                                                   _megdnn_tensor_in B,
                                                   _megdnn_tensor_out C,
                                                   _megdnn_workspace workspace)
        : SizeArgs(opr, A.layout, B.layout, C.layout),
          tensor_a{A},
          tensor_b{B},
          tensor_c{C},
          workspace{workspace} {}

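// Human-readable summary of the GEMM problem (shapes, transpose flags and
// leading dimensions), typically used when reporting unsupported
// configurations.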
std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& param = opr->param();
    size_t m = layout_a.shape[0], n = layout_b.shape[1],
           k = layout_a.shape[param.transposeA ? 0 : 1];
    MEGDNN_MARK_USED_VAR(m);
    MEGDNN_MARK_USED_VAR(n);
    MEGDNN_MARK_USED_VAR(k);
    return ssprintf(
            "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose "
            "B=%d,ldA=%zu,ldB=%zu,ldC=%zu",
            m, k, k, n, m, n, param.transposeA, param.transposeB,
            layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]);
}

// vim: syntax=cpp.doxygen