/**
 * \file dnn/src/cuda/conv_bias/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "megdnn/oprs.h"

#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/cuda/conv_bias/conv_bias_int8.cuh"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cudnn_wrapper.h"

#include <cuda.h>
#include <memory>
#include <unordered_map>

namespace cutlass {
namespace library {

// forward declaration of cutlass library concepts, we hope that algo.h does
// not depend on cutlass headers

class Operation;

}  // namespace library
}  // namespace cutlass

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for conv bias algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
class ConvBiasForwardImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        CUDA_CUDNN_CONVBIAS,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_DEPTHWISE_LARGE_FILTER,
        CUDA_CHANWISE_INT8X8X32,
        CUDA_CUDNN_CONV,
        CUDA_INPLACE_MATMUL,
        CUDA_MATMUL,
        CUDA_MATMUL_INT8X8X32,
        CUDA_BATCHED_MATMUL,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_WMMA_UINT4X4X32,
        CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4,
        CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4,
        CUDA_BFLOAT16,
        CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
        CUDA_FALLBACK_NCHW_INT4,
        CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
        CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
    struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
        const ConvBiasForwardImpl* opr;
        const PreprocessedFilter* preprocessed_filter;

        std::string to_string() const;
        SizeArgs(
                const ConvBiasForwardImpl* opr, const TensorLayout& src,
                const TensorLayout& filter, const TensorLayout& bias,
                const TensorLayout& z, const TensorLayout& dst,
                const PreprocessedFilter* preprocessed_filter = nullptr);
        SizeArgs(
                const ConvBiasForwardImpl* opr, const TensorLayout& src,
                const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
                const TensorLayout& bias, const TensorLayout& z,
                const TensorLayout& dst,
                const PreprocessedFilter* preprocessed_filter = nullptr);

        void init_conv_bias_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv_bias(
                    *src_layout, filter_meta, *dst_layout, *bias_layout, *z_layout,
                    opr->param());
        }

        void init_conv_desc(conv_bias::CUDNNForwardDescs& desc) const {
            desc.set_conv(*src_layout, filter_meta, *dst_layout, opr->param());
        }
    };
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *filter_tensor, *bias_tensor, *z_tensor,
                *dst_tensor;
        Workspace workspace;

        ExecArgs(
                ConvBiasForwardImpl* opr, _megdnn_tensor_in src,
                _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z,
                _megdnn_tensor_out dst, _megdnn_workspace workspace,
                const PreprocessedFilter* preprocessed_filter = nullptr);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;
    virtual size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return 0;
    }
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
        return {};
    }
    virtual void exec_preprocess(const ExecArgs& args) const {
        MEGDNN_MARK_USED_VAR(args);
    }

    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) && is_available_wk(args, limit);
    }

    AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(
                req <= workspace.size,
                "conv bias fwd algo %s: required workspace %zu bytes, got %zu", name(),
                req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
};
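
// Usage sketch (illustrative only, not part of the interface): `algo`, `opr`,
// the layouts/tensors, `limit` and `workspace` below are hypothetical
// stand-ins for what the heuristic/profiling code in opr_impl normally
// supplies.
//
//     AlgoBase::SizeArgs size_args(opr, src, filter, bias, z, dst);
//     if (algo->is_available_wk(size_args, limit)) {
//         size_t wk_size = algo->get_workspace_in_bytes(size_args);
//         // allocate a workspace of at least wk_size bytes, then:
//         AlgoBase::ExecArgs exec_args(
//                 opr, src_tensor, filter_tensor, bias_tensor, z_tensor,
//                 dst_tensor, workspace);
//         algo->check_workspace(size_args, workspace).exec(exec_args);
//     }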

class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
public:
    AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(
                CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:ConvBiasActivation:" + m_attr.name, {});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    param::Convolution get_param_convolution(const SizeArgs& args) const;
    bool is_available(const SizeArgs&) const override;

    const char* name() const override { return m_name.c_str(); }

    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;
};

class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE", {});
        }
        return m_name.c_str();
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE_SMALL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>(
                    "DEPTHWISE_LARGE_FILTER", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CHANNEL_WISE_8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public:
    AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(
                CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                CudnnAlgoPack::conv_fwd_algos().end());
        m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
        m_name = ConvBiasForward::algo_name<DefaultParam>(
                "CUDNN:Convolution:" + m_attr.name, {});
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    const char* name() const override { return m_name.c_str(); }

    cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }

private:
    std::string m_name;
    cudnnConvolutionFwdAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;

    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! compute small matmul in the kernel
class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>(
                    "INPLACE_MATMUL", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    mutable std::string m_name;
};

//! im2col and matmul, with dilation
class ConvBiasForwardImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args, const WorkspaceBundle& bundle);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBias::MatmulParam>("MATMUL", {});
        }
        return m_name.c_str();
    }

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoMatmul8x8x32 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "MATMUL8X8X32", {});
        }
        return m_name.c_str();
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    bool need_src_unroll(const SizeArgs& args) const;
    bool need_filter_reshape(const SizeArgs& args) const;
    template <Param::Format>
    WorkspaceBundle get_bundle(const SizeArgs& args) const;
    template <Param::Format>
    void exec_internal(const ExecArgs& args) const;
    mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoBatchedMatmul final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
                    "BATCHED_MATMUL", {});
        }
        return m_name.c_str();
    }

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

//! implement group conv by another algo
class ConvBiasForwardImpl::AlgoGroupConvGeneral final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasForward::algo_name<DirectParam>("CUDA:GROUP_CONV", {});
        }
        return m_name.c_str();
    }

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    mutable std::string m_name;
};

#if CUDA_VERSION >= 10000
class ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA final : public AlgoBase {
public:
    AlgoQUInt4x4x32WMMA() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "QUINT4x4x32_WMMA"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    bool use_kernel_fhxfw(const SizeArgs& args) const;
    size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
};
#endif

class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final : public AlgoBase {
public:
    AlgoInt8CHWN4DotProdImplicitGemm() = default;
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor,
            const int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param,
            float alpha, float beta, float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode);
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
};

/*********************** Cutlass Algorithms ************************/

/* The inheritance of cutlass algorithm classes:
 *
 * AlgoCutlassConvolutionBase
 * +
 * +--- AlgoInt8NCHW4DotProdImplicitGemm
 * +--- AlgoInt8NCHW32IMMAImplicitGemm
 * +--- AlgoInt8NHWCIMMAImplicitGemm
 * +
 * +--- AlgoInt4NCHW64IMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm
 * +
 * +--- AlgoInt4NHWCIMMAImplicitGemmBase
 * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
 * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
 * +
 * +--- AlgoFloat32NCHWFMAImplicitBatchedGemm
 * +--- AlgoFloat16NCHWHMMAImplicitBatchedGemm
 */

/*
 * The base class for all cutlass algorithm classes
 */
class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase {
public:
    // corresponds to cutlass::conv::Operator. we hope that algo.h does not
    // depend on cutlass headers
    enum class ConvOperator { kFprop, kDgrad, kWgrad };

    // corresponds to cutlass::conv::ConvType. we hope that algo.h does not
    // depend on cutlass headers
    enum class ConvType {
        kConvolution,
        kBatchConvolution,
        kLocal,
        kLocalShare,
        kDepthwiseConvolution,
    };

    // common parameters for operation selection
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int instruction_m;
        int instruction_n;
        int instruction_k;
        int stage;
        int access_size;

        AlgoParam(
                int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_,
                int warp_n_, int warp_k_, int instruction_m_, int instruction_n_,
                int instruction_k_, int stage_, int access_size_ = 0);

        std::string to_string() const;
    };
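
    // Illustrative only (these exact tile shapes are an assumption, not a
    // kernel that is guaranteed to be registered): a hypothetical 128x128x32
    // threadblock tile with 64x64x32 warp tiles, an 8x8x16 mma shape, 2
    // pipeline stages and the default access size would be spelled as
    //     AlgoParam{128, 128, 32, 64, 64, 32, 8, 8, 16, 2};
    // the tile shapes actually used are registered by AlgoPack::fill_*_algos().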

    AlgoCutlassConvolutionBase(AlgoParam algo_param) : m_algo_param{algo_param} {}

    // generate a cutlass::library::ConvolutionKey and find the corresponding
    // operation (cutlass kernel) from the global OperationTable
    const cutlass::library::Operation* get_cutlass_conv_op(
            const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
            bool use_conv_filter_unity_opt, bool without_shared_load) const;

    // execute the cutlass kernel found by get_cutlass_conv_op. we give
    // subclasses full freedom to decide where and how these arguments are
    // extracted
    void execute_cutlass_conv_op(
            const cutlass::library::Operation* op, const void* src, const void* filter,
            const void* bias, const void* z, void* dst, void* workspace, size_t n,
            size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, size_t ho,
            size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, size_t dh, size_t dw,
            const void* alpha, const void* beta, const void* gamma, const void* delta,
            const void* theta, const void* threshold, const void* dst_scale,
            cudaStream_t stream, const void* extra_param = nullptr,
            size_t groups = 1) const;

protected:
    AlgoParam m_algo_param;
};
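
// Typical flow inside a subclass' exec() (sketch only; the pointer, shape and
// scaling variables below are placeholders that each subclass derives from
// ExecArgs in its own way):
//
//     const cutlass::library::Operation* op = get_cutlass_conv_op(
//             args, ConvOperator::kFprop, ConvType::kConvolution,
//             use_conv_filter_unity_opt, without_shared_load);
//     execute_cutlass_conv_op(
//             op, src_ptr, filter_ptr, bias_ptr, z_ptr, dst_ptr, workspace_ptr,
//             n, hi, wi, ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw,
//             &alpha, &beta, &gamma, &delta, &theta, &threshold, &dst_scale,
//             stream);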

class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param),
              m_name{ssprintf(
                      "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                      m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoFallbackNCHWQS8 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return "FALLBACK_CONV_NCHW_QS8"; }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

#if CUDA_VERSION >= 10000
class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final : public AlgoBase {
public:
    enum class MMATileSize : uint32_t { IMMA16x16x16, IMMA32x8x16, IMMA8x32x16 };
    AlgoInt8CHWN4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" + to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    template <typename BiasVisitor>
    static void dispatch_nonlinear_mode(
            const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor,
            int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param,
            float alpha, float beta, float gamma, float scale, cudaStream_t stream,
            param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size);
    static std::string to_string(MMATileSize mma_tile_size);

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final : public AlgoBase {
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8NCHW4IMMAImplicitGemm(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_NCHW4_IMMA_IMPLICIT_GEMM_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter final
        : public AlgoBase {
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmReorderFilter(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth final
        : public AlgoBase {
public:
    using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize;
    AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth(MMATileSize mma_tile_size)
            : m_mma_tile_size{mma_tile_size},
              m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_" +
                     AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_mma_tile_size, ret);
        return ret;
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
    MMATileSize m_mma_tile_size;
    std::string m_name;
};
#endif

#if CUDA_VERSION >= 10020
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    static std::string to_string(AlgoParam algo_param);
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT8_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    static std::string to_string(AlgoParam algo_param);
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT8)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_algo_param, ret);
        return ret;
    }

private:
    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const;

    void reorder_filter(
            const ExecArgs& args, int interleaved, void* reordered_filter) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {}

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::string to_string(AlgoParam algo_param);

protected:
    virtual DTypeEnum src_dtype() const = 0;

    // return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    // return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    void reorder_filter(const ExecArgs& args, void* reordered_filter) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ssprintf(
                        "INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final
        : public AlgoInt4NCHW64IMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NCHW64IMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ssprintf(
                        "UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

    void update_bias(
            const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr,
            void* reduce_workspace) const;
};

class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
        : public AlgoCutlassConvolutionBase {
public:
    AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {}

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return m_name.c_str(); }
    std::string param() const override;

    bool is_available(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::string to_string(AlgoParam algo_param);

protected:
    virtual DTypeEnum src_dtype() const = 0;

    // return filter_ptr, bias_ptr
    virtual std::tuple<void*, void*> prepare_filter_bias(
            const ExecArgs& args) const = 0;

    // return alpha, beta, gamma, delta, theta
    virtual std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const = 0;

    void reorder_filter(
            const ExecArgs& args, int interleaved, void* reordered_filter) const;

    std::string m_name;
};

class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "INT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_INT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;
};

class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
        : public AlgoInt4NHWCIMMAImplicitGemmBase {
public:
    using Base = AlgoInt4NHWCIMMAImplicitGemmBase;
    using AlgoParam = Base::AlgoParam;

    AlgoUInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "UINT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s",
                        to_string(m_algo_param).c_str()),
                ConvBias::DirectParam{});
    }

    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const SizeArgs& args) const override;
    void exec_preprocess(const ExecArgs& args) const override;

    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NHWC_UINT4_INT4)

private:
    DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; }

    std::tuple<void*, void*> prepare_filter_bias(const ExecArgs& args) const override;

    std::tuple<float, float, float, float, float> get_constants(
            const ExecArgs& args) const override;

    void update_bias(
            const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr,
            void* reduce_workspace) const;
};
#endif

class ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s",
                        m_algo_param.to_string().c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0;
    }
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); };
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32);

private:
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
        : public AlgoCutlassConvolutionBase {
public:
    AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
            : AlgoCutlassConvolutionBase(algo_param) {
        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                ssprintf(
                        "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
                        m_algo_param.to_string().c_str()),
                ConvBias::DirectParam{});
    }
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
        return 0;
    }
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); };
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16);

private:
    std::string m_name;
};

class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

    const char* name() const override { return "CONVBIAS_BFLOAT16"; }

    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoBase*> all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos;
    std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
    std::vector<AlgoCUDNNConv> cudnn_convs;
    AlgoFallbackNCHWQS8 fallback_nchw_qs8;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    AlgoDepthwiseLargeFilter depthwise_large_filter;
    AlgoChanwise8x8x32 chanwise8x8x32;
    AlgoInplaceMatmul inplace_matmul;
    AlgoMatmul matmul;
    AlgoMatmul8x8x32 matmul8x8x32;
    AlgoBatchedMatmul batched_matmul;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
#if CUDA_VERSION >= 10000
    AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;
    std::vector<AlgoInt8NCHW4IMMAImplicitGemm> int8_nchw4_imma;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmReorderFilter>
            int8_chwn4_imma_reorder_filter;
    std::vector<AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth> int8_chwn4_imma_unroll_width;
#endif
#if CUDA_VERSION >= 10020
    std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma;
    std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
    std::vector<AlgoInt4Int4NCHW64IMMAImplicitGemm> int4_int4_nchw64_imma;
    std::vector<AlgoUInt4Int4NCHW64IMMAImplicitGemm> uint4_int4_nchw64_imma;
    std::vector<AlgoInt4Int4NHWCIMMAImplicitGemm> int4_int4_nhwc_imma;
    std::vector<AlgoUInt4Int4NHWCIMMAImplicitGemm> uint4_int4_nhwc_imma;
#endif
    std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> f32_implicit_bmm;
    std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> f16_implicit_bmm;
    AlgoGroupConvGeneral group;
    AlgoBFloat16 bfloat16;

    AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);

    AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

private:
#if CUDA_VERSION >= 10000
    void fill_imma_algos();
#endif
    void fill_cudnn_algos();
    void fill_dp4a_algos();
    void fill_dwconv_algos();
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen