/**
 * \file dnn/src/cuda/conv_bias/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "megdnn/oprs.h"

#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/conv_bias/conv_bias_int8.cuh"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"

#include <cuda.h>
#include <memory>
#include <unordered_map>
namespace cutlass {
namespace library {

// Forward declaration of the cutlass library concepts used below, so that
// algo.h does not have to pull in any cutlass headers.

class Operation;

}  // namespace library
}  // namespace cutlass

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for conv bias algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
class ConvBiasForwardImpl::AlgoBase : public Algorithm {
protected:
    //! algo objects are owned by the algo pack; never deleted via AlgoBase*
    ~AlgoBase() = default;

public:
    //! type tags used for algorithm (de)serialization; the enumerator order
    //! is part of the persisted format, so new entries must only be appended
    enum class AlgoType : uint32_t {
        CUDA_CUDNN_CONVBIAS,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_CHANWISE_INT8X8X32,
        CUDA_CUDNN_CONV,
        CUDA_INPLACE_MATMUL,
        CUDA_MATMUL,
        CUDA_MATMUL_INT8X8X32,
        CUDA_BATCHED_MATMUL,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_WMMA_UINT4X4X32,
        CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4,
        CUDA_BFLOAT16,
        CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
        CUDA_FALLBACK_NCHW_INT4
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }

    //! layout-only arguments; no device memory is referenced here
    struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
        const ConvBiasForwardImpl* opr;
        const PreprocessedFilter* preprocessed_filter;

        std::string to_string() const;
        SizeArgs(const ConvBiasForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& filter, const TensorLayout& bias,
                 const TensorLayout& z, const TensorLayout& dst,
                 const PreprocessedFilter* preprocessed_filter = nullptr);
        SizeArgs(const ConvBiasForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& filter,
                 const CanonizedFilterMeta& filter_meta,
                 const TensorLayout& bias, const TensorLayout& z,
                 const TensorLayout& dst,
                 const PreprocessedFilter* preprocessed_filter = nullptr);

        //! fill cuDNN descriptors for the fused conv+bias(+z) path
        void init_conv_bias_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv_bias(*src_layout, filter_meta, *dst_layout,
                               *bias_layout, *z_layout, opr->param());
        }

        //! fill cuDNN descriptors for the plain convolution path
        void init_conv_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv(*src_layout, filter_meta, *dst_layout, opr->param());
        }
    };

    //! SizeArgs plus the actual device tensors and workspace for execution
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *filter_tensor, *bias_tensor, *z_tensor,
                *dst_tensor;
        Workspace workspace;

        ExecArgs(ConvBiasForwardImpl* opr, _megdnn_tensor_in src,
                 _megdnn_tensor_in filter, _megdnn_tensor_in bias,
                 _megdnn_tensor_in z, _megdnn_tensor_out dst,
                 _megdnn_workspace workspace,
                 const PreprocessedFilter* preprocessed_filter = nullptr);
    };

    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    //! workspace needed by exec_preprocess(); 0 when the algo does not
    //! preprocess the filter
    virtual size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return 0;
    }
    //! layouts of the preprocessed filter tensors; empty when the algo does
    //! not preprocess the filter
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return {};
    }
    //! preprocessing is a no-op unless overridden
    virtual void exec_preprocess(const ExecArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
    }

    //! whether the algo is available and its workspace fits within \p limit
    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    //! availability check that also filters by algo attributes
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) &&
               is_available_wk(args, limit);
    }

    //! assert that the supplied workspace is large enough for this algo
    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "conv bias fwd algo %s: required workspace %zu bytes, got %zu",
                name(), req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
};

class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
public:
    //! wrap one cuDNN forward algo as a fused conv+bias+activation algo;
    //! the enum must be registered in CudnnAlgoPack, which also supplies
    //! its name and attribute flags
    AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                      CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:ConvBiasActivation:" + m_attr.name, {});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    param::Convolution get_param_convolution(const SizeArgs& args) const;
    bool is_available(const SizeArgs&) const override;

    const char* name() const override { return m_name.c_str(); }

    //! attributes are derived from the CudnnAlgoPack entry
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    //! made const for consistency with AlgoCUDNNConv::cudnn_enum();
    //! backward compatible for all callers
    cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;
};

class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name =
                    ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE", {});
        }
        return m_name.c_str();
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>(
                    "CHANNEL_WISE_SMALL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>(
                    "CHANNEL_WISE_8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public:
    //! wrap one cuDNN forward algo as a plain convolution algo;
    //! the enum must be registered in CudnnAlgoPack
    AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                      CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:Convolution:" + m_attr.name, {});
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! attributes are derived from the CudnnAlgoPack entry
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    const char* name() const override { return m_name.c_str(); }

    cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;

    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! compute small matmul in the kernel
class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>(
                    "INPLACE_MATMUL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    mutable std::string m_name;
};

//! im2col and matmul, with dilation
class ConvBiasForwardImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args,
                              const WorkspaceBundle& bundle);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>("MATMUL",
                                                                       {});
        }
        return m_name.c_str();
    }

    //! sub-operators this algo searches/relies on
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoMatmul8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    bool need_src_unroll(const SizeArgs& args) const;
    bool need_filter_reshape(const SizeArgs& args) const;
    template <Param::Format>
    WorkspaceBundle get_bundle(const SizeArgs& args) const;
    template <Param::Format>
    void exec_internal(const ExecArgs& args) const;
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoBatchedMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "BATCHED_MATMUL", {});
        }
        return m_name.c_str();
    }

    //! sub-operators this algo searches/relies on
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

//! implement group conv by another algo
class ConvBiasForwardImpl::AlgoGroupConvGeneral final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! sub-operators this algo searches/relies on
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    //! name is built lazily on first query, hence the mutable cache
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CUDA:GROUP_CONV",
                                                             {});
        }
        return m_name.c_str();
    }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

#if CUDA_VERSION >= 10000
class ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA final : public AlgoBase {
public:
    AlgoQUInt4x4x32WMMA() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "QUINT4x4x32_WMMA"; }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    bool use_kernel_fhxfw(const SizeArgs& args) const;
    size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
};
#endif

class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final
        : public AlgoBase {
public:
    AlgoInt8CHWN4DotProdImplicitGemm() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM";
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    //! static dispatch helper, parameterized on the bias visitor type
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter,
            BiasVisitor bias_visitor, const int8_t* d_z, int8_t* d_dst,
            const convolution::ConvParam& param, float alpha, float beta,
            float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode);
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
};

/*********************** Cutlass Algorithms ************************/

/* The inheritance of cutlass algorithm classes:
 *
 * AlgoCutlassConvolutionBase
 * +
 * +--- AlgoInt8NCHW4DotProdImplicitGemm
 * +--- AlgoInt8NCHW32IMMAImplicitGemm
 * +--- AlgoInt8NHWCIMMAImplicitGemm
 * +
 * +--- AlgoInt4NCHW64IMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm
 * +
 * +--- AlgoInt4NHWCIMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
 * +
 */

/*
 * The base class for all cutlass algorithm classes
 */
class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase {
public:
    // corresponds to cutlass::conv::Operator; duplicated here so that algo.h
    // does not depend on cutlass headers
    enum class ConvOperator { kFprop, kDgrad, kWgrad };

    // corresponds to cutlass::conv::ConvType; duplicated here so that algo.h
    // does not depend on cutlass headers
    enum class ConvType {
        kConvolution,
        kBatchConvolution,
        kLocal,
        kLocalShare
    };

    // common tiling parameters used to select a cutlass operation
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int instruction_m;
        int instruction_n;
        int instruction_k;
        int stage;
        int access_size;

        AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
                  int warp_m_, int warp_n_, int warp_k_, int instruction_m_,
                  int instruction_n_, int instruction_k_, int stage_,
                  int access_size_ = 0);

        std::string to_string() const;
    };

    AlgoCutlassConvolutionBase(AlgoParam algo_param)
            : m_algo_param{algo_param} {}

    // generate a cutlass::library::ConvolutionKey and find the corresponding
    // operation (cutlass kernel) from the global OperationTable
    const cutlass::library::Operation* get_cutlass_conv_op(
            const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
            bool use_conv_filter_unity_opt, bool without_shared_load) const;

    // execute the cutlass kernel found by get_cutlass_conv_op. we give
    // subclasses full freedom to decide where and how these arguments are
    // extracted
    void execute_cutlass_conv_op(const cutlass::library::Operation* op,
                                 const void* src, const void* filter,
                                 const void* bias, const void* z, void* dst,
                                 void* workspace, size_t n, size_t hi,
                                 size_t wi, size_t ci, size_t co, size_t fh,
                                 size_t fw, size_t ho, size_t wo, size_t ph,
                                 size_t pw, size_t sh, size_t sw, size_t dh,
                                 size_t dw, const void* alpha, const void* beta,
                                 const void* gamma, const void* delta,
                                 const void* theta, const void* threshold,
                                 const void* dst_scale, cudaStream_t stream,
                                 const void* extra_param = nullptr) const;

protected:
    AlgoParam m_algo_param;
};

class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param),
              m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    //! this algo overrides the weight-preprocessing hooks of AlgoBase
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)

    //! tiling parameters are part of the serialized identity
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    std::string m_name;
};

642 643 644 645 646 647 648 649 650 651 652 653
class ConvBiasForwardImpl::AlgoFallbackNCHWQS8 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        return "FALLBACK_CONV_NCHW_QS8";
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
654 655 656 657
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;
 
658 659 660 661
private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

#if CUDA_VERSION >= 10000
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final
        : public AlgoBase {
public:
    //! candidate mma tile shapes; also part of the serialized param
    enum class MMATileSize : uint32_t {
        IMMA16x16x16,
        IMMA32x8x16,
        IMMA8x32x16
    };
    AlgoInt8CHWN4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" +
                     to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    //! static dispatch helper, parameterized on the bias visitor type
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter,
            BiasVisitor bias_visitor, int8_t* d_z, int8_t* d_dst,
            const convolution::ConvParam& param, float alpha, float beta,
            float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode,
            MMATileSize mma_tile_size);
    static std::string to_string(MMATileSize mma_tile_size);

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final
        : public AlgoBase {
public:
    //! reuses the tile-size enum (and its to_string) of the CHWN4 variant
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8NCHW4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_NCHW4_IMMA_IMPLICIT_GEMM_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter final
        : public AlgoBase {
public:
    //! reuses the tile-size enum (and its to_string) of the CHWN4 variant
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmReorderFilter(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth final
        : public AlgoBase {
public:
    //! reuses the tile-size enum (and its to_string) of the CHWN4 variant
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(
                             m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
#endif

#if CUDA_VERSION >= 10020
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
797
        : public AlgoCutlassConvolutionBase {
798 799
public:
    AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
800
            : AlgoCutlassConvolutionBase(algo_param) {
801 802 803 804 805 806 807 808 809
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
810 811 812
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
813
    static std::string to_string(AlgoParam algo_param);
814 815 816 817 818
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
819 820 821 822 823 824 825
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }
826

827 828 829 830 831 832 833
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;

    std::string m_name;
};

//! int8 implicit-GEMM conv-bias on the NHWC layout, backed by the cutlass
//! convolution base; supports weight preprocessing via the reorder_filter
//! helper. Guarded by CUDA_VERSION >= 10020 in this file.
class ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    //! human-readable suffix derived from the tile configuration
    static std::string to_string(AlgoParam algo_param);
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8)

    //! serialized representation: the algo (tile) parameters
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    //! returns alpha, beta, gamma, delta, theta (epilogue scaling constants)
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const;

    //! rearrange the filter tensor for the kernel's expected element
    //! interleaving before execution
    void reorder_filter(const ExecArgs& args, int interleaved,
                        void* reordered_filter) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase
876
        : public AlgoCutlassConvolutionBase {
877
public:
878
    AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
879
            : AlgoCutlassConvolutionBase(algo_param) {}
880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::string to_string(AlgoParam algo_param);

protected:
    virtual DTypeEnum src_dtype() const = 0;

    // return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    // return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    void reorder_filter(const ExecArgs& args, void* reordered_filter) const;

    std::string m_name;
};

//! signed int4 src x signed int4 filter conv-bias on the NCHW64 layout;
//! dtype-specific pieces of AlgoInt4NCHW64IMMAImplicitGemmBase.
class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param)
            : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    //! return filter_ptr, bias_ptr (see base class contract)
    std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const override;

    //! return alpha, beta, gamma, delta, theta (see base class contract)
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final
942
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
943
public:
944 945 946
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

947
    AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param)
948
            : Base{algo_param} {
949 950 951 952 953
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
954

955 956 957 958 959 960 961
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

962
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4)
963

964 965 966 967 968 969 970 971 972 973 974 975 976
private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

    std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

    void update_bias(const ExecArgs& args, void* updated_bias,
                     void* reduce_filter_ptr, void* reduce_workspace) const;
};

//! common base for the 4-bit NHWC IMMA implicit-GEMM algos; mirrors the
//! NCHW64 base but reorder_filter additionally takes the interleaving factor.
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {}

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! human-readable suffix derived from the tile configuration
    std::string to_string(AlgoParam algo_param);

protected:
    //! dtype of the src tensor accepted by the concrete algo
    virtual DTypeEnum src_dtype() const = 0;

    // return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    // return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    void reorder_filter(const ExecArgs& args, int interleaved,
                        void* reordered_filter) const;

    std::string m_name;
};

//! signed int4 src x signed int4 filter conv-bias on the NHWC layout;
//! dtype-specific pieces of AlgoInt4NHWCIMMAImplicitGemmBase.
class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("INT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    //! return filter_ptr, bias_ptr (see base class contract)
    std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const override;

    //! return alpha, beta, gamma, delta, theta (see base class contract)
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

//! unsigned (asymmetric) int4 src x signed int4 filter conv-bias on the NHWC
//! layout; also rewrites the bias to absorb the zero-point contribution.
class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoUInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf("UINT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                         to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(
            const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

    //! return filter_ptr, bias_ptr (see base class contract)
    std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const override;

    //! return alpha, beta, gamma, delta, theta (see base class contract)
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

    //! compute the adjusted bias; uses a filter reduction and a scratch
    //! workspace supplied by the caller
    void update_bias(const ExecArgs& args, void* updated_bias,
                     void* reduce_filter_ptr, void* reduce_workspace) const;
};
#endif

//! bfloat16 conv-bias implemented by delegating to sub-operators (see
//! get_subopr_list); the workspace bundle covers the intermediate tensors
//! needed for the delegated computation.
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override { return "CONVBIAS_BFLOAT16"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! container that owns one instance of every conv-bias algorithm and the
//! name/type -> algo mapper; constructed once per handle. The CUDA_VERSION
//! guards must match the ones around the corresponding class declarations.
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoBase*> all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos;
    std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
    std::vector<AlgoCUDNNConv> cudnn_convs;
    AlgoFallbackNCHWQS8 fallback_nchw_qs8;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    AlgoChanwise8x8x32 chanwise8x8x32;
    AlgoInplaceMatmul inplace_matmul;
    AlgoMatmul matmul;
    AlgoMatmul8x8x32 matmul8x8x32;
    AlgoBatchedMatmul batched_matmul;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
#if CUDA_VERSION >= 10000
    AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;
    std::vector<AlgoInt8NCHW4IMMAImplicitGemm> int8_nchw4_imma;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmReorderFilter>
            int8_chwn4_imma_reorder_filter;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth>
            int8_chwn4_imma_unroll_width;
#endif
#if CUDA_VERSION >= 10020
    std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma;
    std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
    std::vector<AlgoInt4Int4NCHW64IMMAImplicitGemm> int4_int4_nchw64_imma;
    std::vector<AlgoUInt4Int4NCHW64IMMAImplicitGemm> uint4_int4_nchw64_imma;
    std::vector<AlgoInt4Int4NHWCIMMAImplicitGemm> int4_int4_nhwc_imma;
    std::vector<AlgoUInt4Int4NHWCIMMAImplicitGemm> uint4_int4_nhwc_imma;
#endif
    AlgoGroupConvGeneral group;
    AlgoBFloat16 bfloat16;

    //! map a cudnn fwd algo enum to the corresponding wrapped algo
    AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);

    AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

private:
#if CUDA_VERSION >= 10000
    void fill_imma_algos();
#endif
    void fill_cudnn_algos();
    void fill_dp4a_algos();
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen