/**
 * \file dnn/src/cuda/conv_bias/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "megdnn/oprs.h"

#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/conv_bias/conv_bias_int8.cuh"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"

#include <cuda.h>
#include <memory>
#include <unordered_map>

namespace cutlass {
namespace library {

// Forward declaration of cutlass library concepts; algo.h deliberately does
// not depend on the cutlass headers, so only an opaque pointer to Operation
// is passed around (see AlgoCutlassConvolutionBase below).

class Operation;

}  // namespace library
}  // namespace cutlass

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for conv bias algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
class ConvBiasForwardImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    //! per-algo tag, used by MEGDNN_DECL_ALGO_TYPE in the concrete classes
    enum class AlgoType : uint32_t {
        CUDA_CUDNN_CONVBIAS,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_CHANWISE_INT8X8X32,
        CUDA_CUDNN_CONV,
        CUDA_INPLACE_MATMUL,
        CUDA_MATMUL,
        CUDA_MATMUL_INT8X8X32,
        CUDA_BATCHED_MATMUL,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_WMMA_UINT4X4X32,
        CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4,
        CUDA_BFLOAT16,
        CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
        CUDA_FALLBACK_NCHW_INT4
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }

    //! arguments used for availability checks and workspace queries
    struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
        const ConvBiasForwardImpl* opr;
        const PreprocessedFilter* preprocessed_filter;

        std::string to_string() const;

        SizeArgs(
                const ConvBiasForwardImpl* opr, const TensorLayout& src,
                const TensorLayout& filter, const TensorLayout& bias,
                const TensorLayout& z, const TensorLayout& dst,
                const PreprocessedFilter* preprocessed_filter = nullptr);
        //! overload that reuses an already-canonized filter meta
        SizeArgs(
                const ConvBiasForwardImpl* opr, const TensorLayout& src,
                const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
                const TensorLayout& bias, const TensorLayout& z,
                const TensorLayout& dst,
                const PreprocessedFilter* preprocessed_filter = nullptr);

        //! fill cuDNN descriptors for the fused conv+bias(+z) path
        void init_conv_bias_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv_bias(
                    *src_layout, filter_meta, *dst_layout, *bias_layout, *z_layout,
                    opr->param());
        }

        //! fill cuDNN descriptors for the plain convolution path
        void init_conv_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv(*src_layout, filter_meta, *dst_layout, opr->param());
        }
    };
    //! SizeArgs plus the actual tensors and workspace for execution
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *filter_tensor, *bias_tensor, *z_tensor,
                *dst_tensor;
        Workspace workspace;

        ExecArgs(
                ConvBiasForwardImpl* opr, _megdnn_tensor_in src,
                _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z,
                _megdnn_tensor_out dst, _megdnn_workspace workspace,
                const PreprocessedFilter* preprocessed_filter = nullptr);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;
    //! workspace needed by exec_preprocess; 0 means no preprocessing
    virtual size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return 0;
    }
    //! layouts of preprocessed filters; empty when no preprocessing is done
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return {};
    }
    virtual void exec_preprocess(const ExecArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
    }

    //! whether the algo is available and its workspace fits in \p limit
    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    //! availability check that also filters by algo attributes
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) && is_available_wk(args, limit);
    }

    //! assert that the provided workspace is large enough; returns *this
    AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "conv bias fwd algo %s: required workspace %zu bytes, got %zu", name(),
                req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
};

//! fused conv+bias+activation forward backed by a cuDNN algorithm
class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
public:
    AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        // the enum must be one known to the algo pack table
        megdnn_assert(
                CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:ConvBiasActivation:" + m_attr.name, {});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    param::Convolution get_param_convolution(const SizeArgs& args) const;
    bool is_available(const SizeArgs&) const override;

    const char* name() const override { return m_name.c_str(); }

    //! attributes derived from the cuDNN algo attribute table
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

    //! serialized param: the wrapped cuDNN enum
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;
};

//! direct channel-wise (depthwise) convolution
class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE", {});
        }
        return m_name.c_str();
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)

private:
    mutable std::string m_name;
};

//! channel-wise convolution specialized for small spatial sizes
class ConvBiasForwardImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE_SMALL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

//! channel-wise convolution with int8 inputs and int32 accumulation
class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE_8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

//! plain cuDNN convolution; bias/activation handled outside the cuDNN call
class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public:
    AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
        // the enum must be one known to the algo pack table
        megdnn_assert(
                CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:Convolution:" + m_attr.name, {});
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! attributes derived from the cuDNN algo attribute table
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    const char* name() const override { return m_name.c_str(); }

    cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

    //! serialized param: the wrapped cuDNN enum
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;

    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! compute small matmul in the kernel
class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>(
                    "INPLACE_MATMUL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

//! im2col and matmul, with dilation
class ConvBiasForwardImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args, const WorkspaceBundle& bundle);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>("MATMUL", {});
        }
        return m_name.c_str();
    }

    //! delegates the matmul to a sub-operator looked up via this list
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

//! int8 matmul-based convolution with int32 accumulation
class ConvBiasForwardImpl::AlgoMatmul8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    bool need_src_unroll(const SizeArgs& args) const;
    bool need_filter_reshape(const SizeArgs& args) const;
    template <Param::Format>
    WorkspaceBundle get_bundle(const SizeArgs& args) const;
    template <Param::Format>
    void exec_internal(const ExecArgs& args) const;
    mutable std::string m_name;
};

//! 1x1 convolution implemented via batched matmul
class ConvBiasForwardImpl::AlgoBatchedMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "BATCHED_MATMUL", {});
        }
        return m_name.c_str();
    }

    //! delegates the matmul to a sub-operator looked up via this list
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

//! implement group conv by another algo
class ConvBiasForwardImpl::AlgoGroupConvGeneral final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! the per-group convolution is delegated to a sub-operator
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    const char* name() const override {
        // name is built lazily, hence the mutable member
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CUDA:GROUP_CONV", {});
        }
        return m_name.c_str();
    }

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

#if CUDA_VERSION >= 10000
//! quint4 x quint4 -> int32 convolution using WMMA
class ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA final : public AlgoBase {
public:
    AlgoQUInt4x4x32WMMA() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "QUINT4x4x32_WMMA"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    //! whether the fh x fw kernel variant is used for this problem size
    bool use_kernel_fhxfw(const SizeArgs& args) const;
    size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
};
#endif

M
Megvii Engine Team 已提交
471
class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final : public AlgoBase {
472 473 474 475 476
public:
    AlgoInt8CHWN4DotProdImplicitGemm() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
M
Megvii Engine Team 已提交
477 478
    const char* name() const override { return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
479 480
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
M
Megvii Engine Team 已提交
481 482 483
            const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor,
            const int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param,
            float alpha, float beta, float gamma, float scale, cudaStream_t stream,
484
            param::ConvBias::NonlineMode nonlinear_mode);
485
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
486 487
};

/*********************** Cutlass Algorithms ************************/

/* The inheritance of cutlass algorithm classes:
 *
 * AlgoCutlassConvolutionBase
 * +
 * +--- AlgoInt8NCHW4DotProdImplicitGemm
 * +--- AlgoInt8NCHW32IMMAImplicitGemm
 * +--- AlgoInt8NHWCIMMAImplicitGemm
 * +
 * +--- AlgoInt4NCHW64IMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm
 * +
 * +--- AlgoInt4NHWCIMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
 * +
 */
/*
 * The base class for all cutlass algorithm classes
 */
class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase {
public:
    // corresponds to cutlass::conv::Operator. we hope that algo.h does not
    // depend on cutlass headers
    enum class ConvOperator { kFprop, kDgrad, kWgrad };

    // corresponds to cutlass::conv::ConvType. we hope that algo.h does not
    // depend on cutlass headers
    enum class ConvType { kConvolution, kBatchConvolution, kLocal, kLocalShare };

    // common parameters for operation selection
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int instruction_m;
        int instruction_n;
        int instruction_k;
        int stage;
        int access_size;

        AlgoParam(
                int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_,
                int warp_n_, int warp_k_, int instruction_m_, int instruction_n_,
                int instruction_k_, int stage_, int access_size_ = 0);

        std::string to_string() const;
    };

    AlgoCutlassConvolutionBase(AlgoParam algo_param) : m_algo_param{algo_param} {}

    // generate a cutlass::library::ConvolutionKey and find the corresponding
    // operation (cutlass kernel) from the global OperationTable
    const cutlass::library::Operation* get_cutlass_conv_op(
            const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
            bool use_conv_filter_unity_opt, bool without_shared_load) const;

    // execute the cutlass kernel found by get_cutlass_conv_op. we give
    // subclasses full freedom to decide where and how these arguments are
    // extracted
    void execute_cutlass_conv_op(
            const cutlass::library::Operation* op, const void* src, const void* filter,
            const void* bias, const void* z, void* dst, void* workspace, size_t n,
            size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, size_t ho,
            size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, size_t dh, size_t dw,
            const void* alpha, const void* beta, const void* gamma, const void* delta,
            const void* theta, const void* threshold, const void* dst_scale,
            cudaStream_t stream, const void* extra_param = nullptr) const;

protected:
    AlgoParam m_algo_param;
};

//! int8 implicit-gemm convolution in NCHW4 layout using dp4a (cutlass)
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param),
              m_name{ssprintf(
                      "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                      m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    //! filter preprocessing (weight reordering) support
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)

    //! serialized param: the tiling configuration
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    std::string m_name;
};

597 598 599 600 601
class ConvBiasForwardImpl::AlgoFallbackNCHWQS8 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
M
Megvii Engine Team 已提交
602 603
    const char* name() const override { return "FALLBACK_CONV_NCHW_QS8"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
604
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
605
    std::vector<SearchItem> get_subopr_list(
M
Megvii Engine Team 已提交
606 607
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

608 609 610 611
private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

#if CUDA_VERSION >= 10000
//! int8 implicit-gemm convolution in CHWN4 layout using IMMA (tensor cores)
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final : public AlgoBase {
public:
    enum class MMATileSize : uint32_t { IMMA16x16x16, IMMA32x8x16, IMMA8x32x16 };
    AlgoInt8CHWN4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" + to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    //! dispatch the kernel for the given nonlinearity and tile size
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor,
            int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param,
            float alpha, float beta, float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size);
    static std::string to_string(MMATileSize mma_tile_size);

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8)

    //! serialized param: the MMA tile size
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

M
Megvii Engine Team 已提交
645
class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final : public AlgoBase {
646 647 648 649 650
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8NCHW4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_NCHW4_IMMA_IMPLICIT_GEMM_" +
M
Megvii Engine Team 已提交
651
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
652 653 654
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
655 656 657 658 659 660 661 662
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
M
Megvii Engine Team 已提交
663
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
664

665
private:
M
Megvii Engine Team 已提交
666
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
667 668 669 670 671 672 673 674 675 676 677
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

//! CHWN4 IMMA variant that reorders the filter for better access patterns
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter final
        : public AlgoBase {
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmReorderFilter(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)

    //! serialized param: the MMA tile size
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

//! CHWN4 IMMA variant that unrolls along the width dimension
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth final
        : public AlgoBase {
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)

    //! serialized param: the MMA tile size
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
#endif

#if CUDA_VERSION >= 10020
//! int8 implicit-gemm convolution in NCHW32 layout using IMMA (cutlass)
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    static std::string to_string(AlgoParam algo_param);
    //! filter preprocessing (weight reordering) support
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8)

    //! serialized param: the tiling configuration
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;

    std::string m_name;
};

//! int8 implicit-gemm algorithm for the NHWC layout, built on the cutlass
//! convolution base; supports weight preprocessing (see exec_preprocess).
class ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT8_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    static std::string to_string(AlgoParam algo_param);
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8)

    //! serialized algo param identifying the tile configuration
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    //! scaling constants passed to the kernel epilogue
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const;

    //! rearrange the filter tensor into the interleaved layout the kernel expects
    void reorder_filter(
            const ExecArgs& args, int interleaved, void* reordered_filter) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase
800
        : public AlgoCutlassConvolutionBase {
801
public:
802
    AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
803
            : AlgoCutlassConvolutionBase(algo_param) {}
804

M
Megvii Engine Team 已提交
805
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::string to_string(AlgoParam algo_param);

protected:
    virtual DTypeEnum src_dtype() const = 0;

    // return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    // return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    void reorder_filter(const ExecArgs& args, void* reordered_filter) const;

    std::string m_name;
};

//! signed-int4 x signed-int4 NCHW64 implicit-gemm algorithm; supports weight
//! preprocessing (see exec_preprocess).
class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final
862
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
863
public:
864 865 866
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

M
Megvii Engine Team 已提交
867
    AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
868
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
M
Megvii Engine Team 已提交
869 870 871
                ssprintf(
                        "UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
872 873
                ConvBias::DirectParam{});
    }
874

875
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
M
Megvii Engine Team 已提交
876
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
877 878 879 880
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

881
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4)
882

883 884 885
private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

M
Megvii Engine Team 已提交
886
    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;
887 888 889 890

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

M
Megvii Engine Team 已提交
891 892 893
    void update_bias(
            const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr,
            void* reduce_workspace) const;
894 895
};

//! common base for the 4-bit NHWC implicit-gemm algorithms; concrete
//! subclasses provide the source dtype, filter/bias preparation and the
//! epilogue constants, while is_available/exec are shared here.
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {}

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::string to_string(AlgoParam algo_param);

protected:
    //! source-tensor dtype of the concrete algorithm (signed vs asymmetric int4)
    virtual DTypeEnum src_dtype() const = 0;

    //! return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    //! return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    //! rearrange the filter tensor into the interleaved layout the kernel expects
    void reorder_filter(
            const ExecArgs& args, int interleaved, void* reordered_filter) const;

    std::string m_name;
};

//! signed-int4 x signed-int4 NHWC implicit-gemm algorithm; supports weight
//! preprocessing (see exec_preprocess).
class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

//! asymmetric-uint4 x signed-int4 NHWC implicit-gemm algorithm; the
//! asymmetric zero point requires folding a filter reduction into the bias
//! (see update_bias).
class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoUInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "UINT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

    //! write the bias adjusted by the reduced filter into updated_bias
    void update_bias(
            const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr,
            void* reduce_workspace) const;
};
#endif

//! bfloat16 conv-bias algorithm implemented by delegating to sub-operators
//! (see get_subopr_list).
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! sub-operators this algo dispatches to during heuristic search
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    const char* name() const override { return "CONVBIAS_BFLOAT16"; }

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)

private:
    //! split raw workspace memory into the sub-buffers used by exec
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! container owning one instance of every conv-bias algorithm; also keeps
//! lookup structures (all_algos_map, cudnn enum translators) used by the
//! operator implementation.
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoBase*> all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos;
    std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
    std::vector<AlgoCUDNNConv> cudnn_convs;
    AlgoFallbackNCHWQS8 fallback_nchw_qs8;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    AlgoChanwise8x8x32 chanwise8x8x32;
    AlgoInplaceMatmul inplace_matmul;
    AlgoMatmul matmul;
    AlgoMatmul8x8x32 matmul8x8x32;
    AlgoBatchedMatmul batched_matmul;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
#if CUDA_VERSION >= 10000
    AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;
    std::vector<AlgoInt8NCHW4IMMAImplicitGemm> int8_nchw4_imma;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmReorderFilter>
            int8_chwn4_imma_reorder_filter;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth> int8_chwn4_imma_unroll_width;
#endif
#if CUDA_VERSION >= 10020
    std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma;
    std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
    std::vector<AlgoInt4Int4NCHW64IMMAImplicitGemm> int4_int4_nchw64_imma;
    std::vector<AlgoUInt4Int4NCHW64IMMAImplicitGemm> uint4_int4_nchw64_imma;
    std::vector<AlgoInt4Int4NHWCIMMAImplicitGemm> int4_int4_nhwc_imma;
    std::vector<AlgoUInt4Int4NHWCIMMAImplicitGemm> uint4_int4_nhwc_imma;
#endif
    AlgoGroupConvGeneral group;
    AlgoBFloat16 bfloat16;

    //! translate a cudnn forward-algo enum to the matching conv-bias-act algo
    AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);

    //! translate a cudnn forward-algo enum to the matching plain-conv algo
    AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

private:
#if CUDA_VERSION >= 10000
    void fill_imma_algos();
#endif
    void fill_cudnn_algos();
    void fill_dp4a_algos();
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen