/**
 * \file dnn/src/cuda/convolution/backward_data/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for convolution algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;

public:
    enum class AlgoType : uint32_t {
        CUDA_CUDNN,
        CUDA_MATMUL,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_BFLOAT16,
        CUDA_GROUP_CONV_GENERAL,
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
    struct SizeArgs {
        HandleImpl* handle;
        CanonizedFilterMeta filter_meta;
        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
        ConvolutionBackwardDataImpl* opr;

        std::string to_string() const;
        void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
            desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
        }
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const TensorLayout& diff, const TensorLayout& grad);
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const CanonizedFilterMeta& filter_meta,
                 const TensorLayout& diff, const TensorLayout& grad);

        convolution::ForwardSizeArgs as_fwd_args() const {
            return {handle, grad_layout, filter_layout, filter_meta,
                    diff_layout};
        }
    };
    struct ExecArgs : public SizeArgs {
        const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
        Workspace workspace;

        ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
                 _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
            size_t limit = std::numeric_limits<size_t>::max()) {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) &&
               is_available_wk(args, limit);
    }

    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(req <= workspace.size,
                      "conv bwd data algo %s: "
                      "required workspace %zu bytes, got %zu",
                      name(), req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
};
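
/*!
 * Illustrative dispatch sketch (not part of the interface): assuming `pack`
 * is the operator's AlgoPack (declared below), `args` is an ExecArgs built
 * from the current tensors, and `workspace_limit` is a caller-chosen byte
 * budget, the first viable algorithm could be picked roughly like this:
 * \code
 * for (AlgoBase* algo : pack.all_algos) {
 *     if (!algo->is_available_wk(args, workspace_limit))
 *         continue;
 *     algo->check_workspace(args, args.workspace).exec(args);
 *     break;
 * }
 * \endcode
 * A heuristic may additionally filter candidates by attribute via
 * is_available_attribute().
 */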

class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
    cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
    CudnnAlgoPack::Attr m_attr;

public:
    AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) !=
                      CudnnAlgoPack::conv_bwd_data_algos().end());
        m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum);
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_attr.name.c_str(); }
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }
    cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)

    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }
};

//! im2col and matmul, with dilation
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args);

public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override { return "MATMUL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};
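
/*!
 * Shape sketch for the matmul formulation used by AlgoMatmul above (the exact
 * transforms live in the .cpp implementation; shapes are illustrative and
 * assume NCHW layouts): with filter (OC, IC, FH, FW) and diff (N, OC, OH, OW),
 * each sample amounts to
 *     col(IC*FH*FW, OH*OW) = filter^T(IC*FH*FW, OC) * diff(OC, OH*OW)
 * followed by a col2im scatter, honoring stride/padding/dilation, into
 * grad (N, IC, IH, IW).
 */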

class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return "CHANNEL_WISE"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};

class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return "CHANNEL_WISE_SMALL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
};

class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override {
        return "CONVOLUTION_BACKWARD_DATD_BFLOAT16";
    }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
};

//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final
        : public AlgoBase {
    AlgoBase* m_impl;
    std::string m_name;

public:
    AlgoGroupConvGeneral(AlgoBase* impl);

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_name.c_str(); }

    static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
                                 TensorLayout& grad_pg);
    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }
};
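
/*!
 * Rough execution sketch for AlgoGroupConvGeneral above (hedged; the actual
 * loop lives in the .cpp implementation): modify_size_args() narrows the
 * layouts in SizeArgs to a single group (diff_pg / grad_pg receive the
 * per-group layouts), and the wrapped m_impl algorithm is then run once per
 * group with the tensor pointers advanced by the per-group strides.
 */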

class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoBase {
public:
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int stage;
        std::string to_string() {
            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
                            threadblock_n, threadblock_k, warp_m, warp_n,
                            warp_k, stage);
        }
    };
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : m_algo_param{algo_param},
              m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    AlgoParam m_algo_param;
    std::string m_name;
};
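
/*!
 * Naming example for the tile parameters above (illustrative values): an
 * AlgoParam{16, 64, 8, 16, 64, 8, 2} yields the algorithm name
 * "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_16X64X8_16X64X8_2stage".
 */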

class ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm final
        : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM";
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8)
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
};

class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    // defined in cudnn.cpp
    void fill_cudnn_algos();
    // defined in implicit_gemm_int8_nchw4_dp4a.cpp
    void fill_int8_dp4a_algos();

    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    std::vector<AlgoCUDNN> cudnn;
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    std::vector<AlgoGroupConvGeneral> gconv;
    std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
    AlgoBFloat16 bfloat16;
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
    AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;

    std::vector<AlgoBase*>
            //! all algorithms
            all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
            non_cudnn_algos, bfloat16_algos, int8_algos;

    AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
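
/*!
 * Lookup sketch (assuming `pack` is the AlgoPack instance owned by
 * ConvolutionBackwardDataImpl and `desc` is a deserialized AlgorithmDesc):
 * a descriptor is resolved back to a concrete algorithm through the map
 * filled by the constructor, and a raw cuDNN enum through cudnn_from_enum():
 * \code
 * AlgoBase* algo = pack.all_algos_map().at(desc);
 * AlgoCUDNN* cudnn_algo =
 *         pack.cudnn_from_enum(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
 * \endcode
 */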

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen