// dnn/src/cuda/conv_bias/algo.cpp
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

//! Builds the complete conv-bias algorithm registry for the CUDA backend.
//!
//! NOTE: registration order is significant — algorithm selection walks these
//! vectors front to back, so earlier entries are preferred candidates.
ConvBiasForwardImpl::AlgoPack::AlgoPack() {
    // algorithms that do not depend on cuDNN
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&chanwise_small);
    non_cudnn_algos.push_back(&depthwise_large_filter);

    non_cudnn_algos.push_back(&inplace_matmul);
    non_cudnn_algos.push_back(&matmul);
    non_cudnn_algos.push_back(&matmul8x8x32);
    non_cudnn_algos.push_back(&batched_matmul);
    non_cudnn_algos.push_back(&int1_simple);

#if CUDNN_VERSION >= 8020
    // cuDNN v8 API based implementations
    all_algos.push_back(&cudnn_conv_v8);
    all_algos.push_back(&cudnn_conv_bias_activation_v8);
#endif

    // populate one wrapper per cuDNN forward-conv enum (fills both
    // cudnn_conv_bias_activations and cudnn_convs)
    fill_cudnn_algos();
    for (auto&& algo : cudnn_conv_bias_activations) {
        all_algos.push_back(&algo);
    }

    //! add conv+nonlinear algos
    std::vector<AlgoBase*> conv_algos;
    conv_algos.push_back(&chanwise);
    conv_algos.push_back(&chanwise_small);
    conv_algos.push_back(&depthwise_large_filter);
    conv_algos.push_back(&chanwise8x8x32);
    for (auto&& algo : cudnn_convs) {
        conv_algos.push_back(&algo);
    }
    conv_algos.push_back(&inplace_matmul);
    conv_algos.push_back(&matmul);
    conv_algos.push_back(&matmul8x8x32);
    conv_algos.push_back(&batched_matmul);
    conv_algos.push_back(&group);
    conv_algos.push_back(&int1_simple);

    for (auto&& algo : conv_algos) {
        all_algos.push_back(algo);
    }

    // the bfloat16 algo is additionally tracked in its own list
    all_algos.push_back(&bfloat16);
    bfloat16_algos.push_back(&bfloat16);

    // snapshot the current size: everything appended past this point is a
    // non-cuDNN algo and is mirrored into non_cudnn_algos at the end
    size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
    fill_imma_algos();
    all_algos.push_back(&wmma_quint4x4x32);
    for (auto&& algo : int8_nchw4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_reorder_filter) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_unroll_width) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10020
    for (auto&& algo : int8_nchw32_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
#endif
#endif
    fill_dp4a_algos();
    for (auto&& algo : int8_nchw4_dotprod) {
        all_algos.push_back(&algo);
    }
    // note: fill_dwconv_algos() appends to all_algos itself
    fill_dwconv_algos();
    all_algos.push_back(&int8_chwn4_dotprod);
    all_algos.push_back(&fallback_nchw_qs8);
    // mirror every algo registered after the snapshot into non_cudnn_algos
    for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
        non_cudnn_algos.push_back(all_algos[i]);
    }

    // index every registered algo by its descriptor for fast lookup
    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

//! Global, statically-initialized algorithm registry shared by every
//! ConvBiasForwardImpl instance.
ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack;

//! Expands to the standard helper that resolves an algorithm from its
//! descriptor (backed by m_all_algos_map built in the AlgoPack ctor).
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl)
//! Convenience overload: derives the canonized filter meta from the raw
//! filter layout and delegates to the full constructor below.
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(
                  o, src, filter, o->make_canonized_filter_meta(src.ndim, filter), bias,
                  z, dst, preprocessed_filter) {}

//! Full constructor: the caller supplies an already-canonized filter meta.
//! Layouts are stored by pointer into the base BiasForwardSizeArgs, so the
//! referenced TensorLayout objects must outlive this SizeArgs.
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
        const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter)
        : BiasForwardSizeArgs{concrete_handle(o->handle()),
                              &src,
                              &filter,
                              &bias,
                              &z,
                              filter_meta,
                              &dst,
                              o->param().nonlineMode},
          opr{o},
          preprocessed_filter{preprocessed_filter} {}

//! Execution-time argument pack: extends SizeArgs (built from the tensors'
//! layouts) with pointers to the concrete tensors and the workspace.
//! Tensors are captured by address, so the caller's arguments must stay
//! alive for the lifetime of this object.
ConvBiasForwardImpl::AlgoBase::ExecArgs::ExecArgs(
        ConvBiasForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_out dst,
        _megdnn_workspace workspace, const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(
                  opr, src.layout, filter.layout, bias.layout, z.layout, dst.layout,
                  preprocessed_filter),
          src_tensor{&src},
          filter_tensor{&filter},
          bias_tensor{&bias},
          z_tensor{&z},
          dst_tensor{&dst},
          workspace{workspace} {}

std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& fm = filter_meta;
    MEGDNN_MARK_USED_VAR(fm);
    std::string nonlinear_mode_str;
    switch (nonlinear_mode) {
        case param::ConvBias::NonlineMode::RELU:
            nonlinear_mode_str = "RELU";
            break;
        case param::ConvBias::NonlineMode::SIGMOID:
            nonlinear_mode_str = "SIGMOID";
            break;
        case param::ConvBias::NonlineMode::IDENTITY:
            nonlinear_mode_str = "IDENTITY";
            break;
160 161 162
        case param::ConvBias::NonlineMode::H_SWISH:
            nonlinear_mode_str = "H_SWISH";
            break;
163 164 165
        default:
            megdnn_throw("invalid conv bias nonlinear mode");
    }
M
Megvii Engine Team 已提交
166
    return ssprintf(
167
            "src=%s, filter=%s, bias=%s, z=%s, dst=%s, "
168 169
            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s, "
            "nonlinear_mode=%s",
170 171
            src_layout->to_string().c_str(), filter_layout->to_string().c_str(),
            bias_layout->to_string().c_str(), z_layout->to_string().c_str(),
M
Megvii Engine Team 已提交
172 173 174
            dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], fm.stride[0],
            fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip,
            src_layout->dtype.name(), dst_layout->dtype.name(),
M
Megvii Engine Team 已提交
175
            nonlinear_mode_str.c_str());
176 177
}

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
param::Convolution ConvBiasForwardImpl::AlgoBase::get_param_convolution(
        const SizeArgs& args) const {
    param::Convolution::Mode mode;
    param::Convolution::Sparse sparse = args.filter_meta.group > 1
                                              ? param::Convolution::Sparse::GROUP
                                              : param::Convolution::Sparse::DENSE;
    if (args.filter_meta.should_flip) {
        mode = param::Convolution::Mode::CONVOLUTION;
    } else {
        mode = param::Convolution::Mode::CROSS_CORRELATION;
    }
    return param::Convolution{
            mode,
            args.filter_meta.padding[0],
            args.filter_meta.padding[1],
            args.filter_meta.stride[0],
            args.filter_meta.stride[1],
            args.filter_meta.dilation[1],
            args.filter_meta.dilation[0],
            sparse,
            args.filter_meta.format,
            args.opr->param().compute_mode};
}

202
void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
203 204 205 206
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_bias_activations.push_back(algo.first);
        cudnn_convs.push_back(algo.first);
    }
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
}

#if CUDA_VERSION >= 10000
void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_chwn4_imma_reorder_filter.push_back(
M
Megvii Engine Team 已提交
224
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA16x16x16});
225
    int8_chwn4_imma_reorder_filter.push_back(
M
Megvii Engine Team 已提交
226
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA32x8x16});
227
    int8_chwn4_imma_reorder_filter.push_back(
M
Megvii Engine Team 已提交
228
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA8x32x16});
229
    int8_chwn4_imma_unroll_width.push_back(
M
Megvii Engine Team 已提交
230
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA16x16x16});
231
    int8_chwn4_imma_unroll_width.push_back(
M
Megvii Engine Team 已提交
232
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA32x8x16});
233
    int8_chwn4_imma_unroll_width.push_back(
M
Megvii Engine Team 已提交
234
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA8x32x16});
235 236 237
#if CUDA_VERSION >= 10020
    {
        using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
M
Megvii Engine Team 已提交
238 239 240 241 242 243 244 245 246
        int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1});
247
    }
248 249
    {
        using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam;
M
Megvii Engine Team 已提交
250 251 252
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 16});
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 8});
        int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 4});
253 254
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 16});
M
Megvii Engine Team 已提交
255 256
        int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 8});
        int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 4});
257
    }
258 259
    {
        using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
260
        int4_int4_nchw64_imma.emplace_back(
261
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
262
        int4_int4_nchw64_imma.emplace_back(
263
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
264
        int4_int4_nchw64_imma.emplace_back(
265
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
266
        int4_int4_nchw64_imma.emplace_back(
267
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
268 269 270 271
    }
    {
        using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
        uint4_int4_nchw64_imma.emplace_back(
272
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
273
        uint4_int4_nchw64_imma.emplace_back(
274
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
275
        uint4_int4_nchw64_imma.emplace_back(
276
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
277
        uint4_int4_nchw64_imma.emplace_back(
278
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
279
    }
280 281
    {
        using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
282 283 284 285 286 287
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
288
        int4_int4_nhwc_imma.emplace_back(
289
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
290
        int4_int4_nhwc_imma.emplace_back(
291
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
292
        int4_int4_nhwc_imma.emplace_back(
293
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
294
        int4_int4_nhwc_imma.emplace_back(
295
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
296
        int4_int4_nhwc_imma.emplace_back(
297
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
298
        int4_int4_nhwc_imma.emplace_back(
299
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
300 301 302
    }
    {
        using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
303 304 305 306 307 308
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
309
        uint4_int4_nhwc_imma.emplace_back(
310
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
311
        uint4_int4_nhwc_imma.emplace_back(
312
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
313
        uint4_int4_nhwc_imma.emplace_back(
314
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
315
        uint4_int4_nhwc_imma.emplace_back(
316
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
317
        uint4_int4_nhwc_imma.emplace_back(
318
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
319
        uint4_int4_nhwc_imma.emplace_back(
320
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
321
    }
322
#endif
323 324 325
}
#endif

//! Registers the cutlass implicit batched-matmul depthwise-conv kernels.
//! Unlike the other fill_* helpers, this one also appends its entries to
//! all_algos directly.
//! NOTE(review): AlgoParam fields appear to be cutlass tile/stage settings —
//! confirm against AlgoCutlassConvolutionBase::AlgoParam in algo.h.
void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
    using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam;
    /// preferred algo
    f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2});
    f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2});
    for (auto&& algo : f32_implicit_bmm) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10010
    // fp16 variants need a newer toolkit
    /// preferred algo
    f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
    f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
    for (auto&& algo : f16_implicit_bmm) {
        all_algos.push_back(&algo);
    }
#endif
}

//! Registers pre-tuned int8 NCHW4 dot-product (dp4a) implicit-GEMM kernel
//! configurations. Entries are only added to int8_nchw4_dotprod; the
//! AlgoPack constructor pushes them into all_algos.
void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
    using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1});
    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
}

M
Megvii Engine Team 已提交
367
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum(
368 369 370 371 372
        cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn_convs) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
M
Megvii Engine Team 已提交
373 374
    megdnn_throw(ssprintf(
            "can not find cudnn conv fwd algorithm %d", static_cast<int>(algo)));
375 376
}

M
Megvii Engine Team 已提交
377 378
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::
        cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo) {
379 380 381 382
    for (auto&& i : cudnn_conv_bias_activations) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
M
Megvii Engine Team 已提交
383 384
    megdnn_throw(ssprintf(
            "can not find cudnn conv bias act algorithm %d", static_cast<int>(algo)));
385 386 387
}

// vim: syntax=cpp.doxygen