/**
 * \file dnn/src/cuda/conv_bias/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

//! Registers every CUDA conv-bias algorithm known to this build.
//!
//! Algorithms are appended to all_algos in the order written below.
//! NOTE(review): callers may depend on this order (e.g. for heuristic
//! selection) -- confirm before reordering anything here.
//!
//! Several lists store pointers to elements of member vectors filled by
//! fill_cudnn_algos() / fill_imma_algos() / fill_dp4a_algos(). Those vectors
//! must not grow after their elements' addresses are taken, or the stored
//! pointers would dangle.
ConvBiasForwardImpl::AlgoPack::AlgoPack() {
    // Hand-written (non-cuDNN) algorithms exposed through non_cudnn_algos.
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&chanwise_small);

    non_cudnn_algos.push_back(&inplace_matmul);
    non_cudnn_algos.push_back(&matmul);
    non_cudnn_algos.push_back(&matmul8x8x32);
    non_cudnn_algos.push_back(&batched_matmul);

    // Populate the cuDNN vectors completely before taking element addresses.
    fill_cudnn_algos();
    for (auto&& algo : cudnn_conv_bias_activations) {
        all_algos.push_back(&algo);
    }

    //! add conv+nonlinear algos
    std::vector<AlgoBase*> conv_algos;
    conv_algos.push_back(&chanwise);
    conv_algos.push_back(&chanwise_small);
    conv_algos.push_back(&chanwise8x8x32);
    for (auto&& algo : cudnn_convs) {
        conv_algos.push_back(&algo);
    }
    conv_algos.push_back(&inplace_matmul);
    conv_algos.push_back(&matmul);
    conv_algos.push_back(&matmul8x8x32);
    conv_algos.push_back(&batched_matmul);
    conv_algos.push_back(&group);

    for (auto&& algo : conv_algos) {
        all_algos.push_back(algo);
    }

    all_algos.push_back(&bfloat16);
    bfloat16_algos.push_back(&bfloat16);

    // Every algorithm appended to all_algos from this point on is also
    // copied into non_cudnn_algos by the loop after the dp4a section.
    size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
    // IMMA / WMMA implicit-GEMM algorithms are only built with CUDA >= 10.0.
    fill_imma_algos();
    all_algos.push_back(&wmma_quint4x4x32);
    for (auto&& algo : int8_nchw4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_reorder_filter) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_chwn4_imma_unroll_width) {
        all_algos.push_back(&algo);
    }
#if CUDA_VERSION >= 10020
    // NCHW32 / NHWC / NCHW64 variants additionally require CUDA >= 10.2.
    for (auto&& algo : int8_nchw32_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int8_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nchw64_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : int4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
    for (auto&& algo : uint4_int4_nhwc_imma) {
        all_algos.push_back(&algo);
    }
#endif
#endif
    // DP4A implicit-GEMM algorithms are registered regardless of CUDA version.
    fill_dp4a_algos();
    for (auto&& algo : int8_nchw4_dotprod) {
        all_algos.push_back(&algo);
    }
    all_algos.push_back(&int8_chwn4_dotprod);
    all_algos.push_back(&fallback_nchw_qs8);

    // Mirror everything registered after all_algo_size into non_cudnn_algos.
    for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
        non_cudnn_algos.push_back(all_algos[i]);
    }

    // Index every algorithm by its descriptor. NOTE(review): if
    // m_all_algos_map is a std map type, emplace keeps the FIRST entry when
    // two algorithms share a descriptor -- confirm descriptors are unique.
    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack;

108 109
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl)

M
Megvii Engine Team 已提交
110
//! Builds SizeArgs from raw layouts. The canonized filter meta is derived via
//! o->make_canonized_filter_meta(src.ndim, filter), and construction is then
//! delegated to the overload that takes the filter meta explicitly.
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(o, src, filter,
                   o->make_canonized_filter_meta(src.ndim, filter), bias, z,
                   dst, preprocessed_filter) {}

//! Initializes the BiasForwardSizeArgs base and records the operator and the
//! optional preprocessed filter.
//!
//! The base is given the CUDA handle of `o`, the ADDRESSES of the layouts,
//! the canonized filter meta and o->param().nonlineMode. Because only
//! addresses are stored, the layouts must outlive this object.
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvBiasForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter)
        : BiasForwardSizeArgs{concrete_handle(o->handle()),
                              &src,
                              &filter,
                              &bias,
                              &z,
                              filter_meta,
                              &dst,
                              o->param().nonlineMode},
          opr{o},
          preprocessed_filter{preprocessed_filter} {}

//! Combines the SizeArgs derived from the tensors' layouts with pointers to
//! the tensors themselves and the workspace.
//!
//! Tensors are held by address, so they must outlive this object.
ConvBiasForwardImpl::AlgoBase::ExecArgs::ExecArgs(
        ConvBiasForwardImpl* opr, _megdnn_tensor_in src,
        _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        const PreprocessedFilter* preprocessed_filter)
        : SizeArgs(opr, src.layout, filter.layout, bias.layout, z.layout,
                   dst.layout, preprocessed_filter),
          src_tensor{&src},
          filter_tensor{&filter},
          bias_tensor{&bias},
          z_tensor{&z},
          dst_tensor{&dst},
          workspace{workspace} {}

std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& fm = filter_meta;
    MEGDNN_MARK_USED_VAR(fm);
    std::string nonlinear_mode_str;
    switch (nonlinear_mode) {
        case param::ConvBias::NonlineMode::RELU:
            nonlinear_mode_str = "RELU";
            break;
        case param::ConvBias::NonlineMode::SIGMOID:
            nonlinear_mode_str = "SIGMOID";
            break;
        case param::ConvBias::NonlineMode::IDENTITY:
            nonlinear_mode_str = "IDENTITY";
            break;
163 164 165
        case param::ConvBias::NonlineMode::H_SWISH:
            nonlinear_mode_str = "H_SWISH";
            break;
166 167 168
        default:
            megdnn_throw("invalid conv bias nonlinear mode");
    }
M
Megvii Engine Team 已提交
169
    return ssprintf(
170
            "src=%s, filter=%s, bias=%s, z=%s, dst=%s, "
171 172
            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s, "
            "nonlinear_mode=%s",
173 174 175 176 177
            src_layout->to_string().c_str(), filter_layout->to_string().c_str(),
            bias_layout->to_string().c_str(), z_layout->to_string().c_str(),
            dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1],
            fm.stride[0], fm.stride[1], fm.dilation[0], fm.dilation[1],
            !fm.should_flip, src_layout->dtype.name(), dst_layout->dtype.name(),
M
Megvii Engine Team 已提交
178
            nonlinear_mode_str.c_str());
179 180 181
}

void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
182 183 184 185
    for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
        cudnn_conv_bias_activations.push_back(algo.first);
        cudnn_convs.push_back(algo.first);
    }
186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
}

#if CUDA_VERSION >= 10000
//! Populates the IMMA implicit-GEMM algorithm vectors (CUDA >= 10.0). The
//! NCHW32 / NHWC / NCHW64 variants are additionally gated on CUDA >= 10.2.
//!
//! Insertion order is preserved verbatim, because the constructor registers
//! these algorithms in vector order.
//!
//! NOTE(review): the AlgoParam field order is presumably threadblock {m,n,k},
//! warp {m,n,k}, instruction {m,n,k}, stage count[, access size] -- confirm
//! against each AlgoParam declaration.
void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_chwn4_imma.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA16x16x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA32x8x16});
    int8_nchw4_imma.push_back(
            {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA8x32x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::
                     IMMA16x16x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::
                     IMMA32x8x16});
    int8_chwn4_imma_reorder_filter.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::
                     IMMA8x32x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::
                     IMMA16x16x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::
                     IMMA32x8x16});
    int8_chwn4_imma_unroll_width.push_back(
            {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::
                     IMMA8x32x16});
#if CUDA_VERSION >= 10020
    // int8, NCHW32 layout.
    {
        using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
        int8_nchw32_imma.emplace_back(
                AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(
                AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(
                AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(
                AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2});
        int8_nchw32_imma.emplace_back(
                AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(
                AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1});
        int8_nchw32_imma.emplace_back(
                AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1});
    }
    // int8, NHWC layout (AlgoParam carries one extra trailing field).
    {
        using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam;
        int8_nhwc_imma.emplace_back(
                AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 16});
        int8_nhwc_imma.emplace_back(
                AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 8});
        int8_nhwc_imma.emplace_back(
                AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 4});
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 16});
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 8});
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 4});
    }
    // int4 x int4, NCHW64 layout.
    {
        using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
        int4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
    }
    // uint4 x int4, NCHW64 layout (same configurations as int4 x int4).
    {
        using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
        uint4_int4_nchw64_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
    }
    // int4 x int4, NHWC layout.
    {
        using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
        int4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
    }
    // uint4 x int4, NHWC layout (same configurations as int4 x int4).
    {
        using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 16, 64, 128, 16, 64, 8, 8, 32, 2, 8});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
        uint4_int4_nhwc_imma.emplace_back(
                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
    }
#endif
}
#endif

325 326
void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
    using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1});
    int8_nchw4_dotprod.emplace_back(
            AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
345 346
}

347 348 349 350 351 352 353
ConvBiasForwardImpl::AlgoBase*
ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum(
        cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn_convs) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
M
Megvii Engine Team 已提交
354 355
    megdnn_throw(ssprintf("can not find cudnn conv fwd algorithm %d",
                          static_cast<int>(algo)));
356 357 358 359 360 361 362 363 364
}

//! Returns the registered cuDNN conv+bias+activation algorithm (from
//! cudnn_conv_bias_activations) whose cudnn_enum() equals `algo`.
//! Raises (via megdnn_throw) if no such algorithm is registered.
ConvBiasForwardImpl::AlgoBase*
ConvBiasForwardImpl::AlgoPack::cudnn_conv_bias_act_from_enum(
        cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn_conv_bias_activations) {
        if (i.cudnn_enum() == algo) {
            return &i;
        }
    }
    megdnn_throw(ssprintf("can not find cudnn conv bias act algorithm %d",
                          static_cast<int>(algo)));
}

// vim: syntax=cpp.doxygen