/**
 * \file dnn/src/cuda/convolution/backward_data/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for convolution algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        CUDA_CUDNN,
        CUDA_MATMUL,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_BFLOAT16,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
    struct SizeArgs {
        HandleImpl* handle;
        CanonizedFilterMeta filter_meta;
        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
        ConvolutionBackwardDataImpl* opr;

        std::string to_string() const;
        void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
            desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
        }
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const TensorLayout& diff, const TensorLayout& grad);
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const CanonizedFilterMeta& filter_meta,
                 const TensorLayout& diff, const TensorLayout& grad);

        convolution::ForwardSizeArgs as_fwd_args() const {
            return {handle, grad_layout, filter_layout, filter_meta,
                    diff_layout};
        }
    };
    struct ExecArgs : public SizeArgs {
        const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
        Workspace workspace;

        ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
                 _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    bool is_available_reproducible(
            const SizeArgs& args, bool reproducible = true,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return (!reproducible ||
                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
               is_available_wk(args, limit);
    }

    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(req <= workspace.size,
                      "conv bwd data algo %s: "
                      "required workspace %zu bytes, got %zu",
                      name(), req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
};
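
/*!
 * A minimal usage sketch for the helper methods above (illustrative only;
 * `algo`, `size_args`, `exec_args` and `workspace_limit` are hypothetical
 * names, not part of this header):
 *
 * \code
 * // consider an algo only if it is applicable and fits the workspace
 * // budget; optionally require the REPRODUCIBLE attribute
 * if (algo->is_available_reproducible(size_args, true, workspace_limit)) {
 *     // check_workspace() asserts the provided workspace is large enough
 *     // and returns *this, so exec() can be chained directly
 *     algo->check_workspace(exec_args, exec_args.workspace).exec(exec_args);
 * }
 * \endcode
 */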

class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
    cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;

public:
    AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) !=
                      CudnnAlgoPack::conv_bwd_data_algos().end());
        m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum);
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_attr.name.c_str(); }
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }
    cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }
};
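
/*!
 * A sketch of how AlgoCUDNN wraps a raw cuDNN enum (illustrative; the
 * constructor asserts the enum is present in
 * CudnnAlgoPack::conv_bwd_data_algos(), from which name and attributes
 * are taken):
 *
 * \code
 * AlgoCUDNN algo{CUDNN_CONVOLUTION_BWD_DATA_ALGO_1};
 * // reproducibility is looked up from the CudnnAlgoPack entry
 * bool repro = algo.contain_attribute(AlgoAttribute::REPRODUCIBLE);
 * // param() serializes m_cudnn_enum so the algo can later be recovered
 * // from its AlgorithmDesc via AlgoBase::Mapper
 * std::string raw_param = algo.param();
 * \endcode
 */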

//! im2col and matmul, with dilation
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override { return "MATMUL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};
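
/*!
 * How the MATMUL algo above computes the data gradient (a sketch of the
 * standard im2col formulation, per sample and per group):
 *
 *   forward:        y_mat = W_mat * im2col(x)   with W_mat: (OC, IC*FH*FW)
 *   backward data:  col   = W_mat^T * dy_mat    with col:   (IC*FH*FW, OH*OW)
 *                   dx    = col2im(col)
 *
 * i.e. a matmul against the transposed filter matrix followed by a col2im
 * scatter; dilation only changes the im2col/col2im index mapping.
 */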

class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return "CHANNEL_WISE"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};

class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return "CHANNEL_WISE_SMALL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};

class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override {
        return "CONVOLUTION_BACKWARD_DATD_BFLOAT16";
    }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
};

//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final
        : public AlgoBase {
    AlgoBase* m_impl;
    std::string m_name;

public:
    AlgoGroupConvGeneral(AlgoBase* impl);

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }

    static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
                                 TensorLayout& grad_pg);
    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_impl, ret);
        return ret;
    }
};
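
/*!
 * Sketch of the delegation pattern used above (our reading of the
 * interface, not a verbatim description of the implementation):
 * modify_size_args() shrinks the channel dimension of the diff/grad
 * layouts to a single group, and exec() can then invoke m_impl once per
 * group with the tensor pointers offset accordingly, so any single-group
 * algo can serve as the underlying implementation.
 */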

class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoBase {
public:
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int stage;
        std::string to_string() {
            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
                            threadblock_n, threadblock_k, warp_m, warp_n,
                            warp_k, stage);
        }
    };
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : m_algo_param{algo_param},
              m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    AlgoParam m_algo_param;
    std::string m_name;
};
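
/*!
 * Example of the generated algo name (follows directly from the format
 * string in AlgoParam::to_string()): AlgoParam{16, 64, 8, 16, 64, 8, 2}
 * yields "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_16X64X8_16X64X8_2stage".
 */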

class ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm final
        : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM";
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8)
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
};

class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    // defined in cudnn.cpp
    void fill_cudnn_algos();
    // defined in implicit_gemm_int8_nchw4_dp4a.cpp
    void fill_int8_dp4a_algos();

    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoCUDNN> cudnn;
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    std::vector<AlgoGroupConvGeneral> gconv;
    std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
    AlgoBFloat16 bfloat16;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;

    std::vector<AlgoBase*>
            //! all algorithms
            all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos, int8_algos;

    AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
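
/*!
 * Typical lookup flow (a sketch; `pack` stands for the AlgoPack instance
 * owned by ConvolutionBackwardDataImpl and is a hypothetical name here):
 *
 * \code
 * // resolve an enum suggested by the cuDNN heuristics
 * AlgoCUDNN* a = pack.cudnn_from_enum(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
 * // or recover any algo from a serialized AlgorithmDesc
 * AlgoBase* b = pack.all_algos_map().at(desc);
 * \endcode
 */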

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen