opr_impl.h 17.0 KB
Newer Older
1 2 3 4 5 6 7 8 9
#pragma once

#include "include/megdnn/thin/function.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"
#include "src/naive/conv_bias/opr_impl.h"

10 11
#include <unordered_map>

12 13 14
namespace megdnn {
namespace fallback {

15 16
/*!
 * \brief get the pack_size according to the format
17 18
 * Note  TODO: when format is removed from param,
 *       this may be used like "opr::param::format specify"
19
 * */
20
size_t pack_size(param::ConvBias::Format format);
21

22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
/*!
 * \brief fallback conv bias forward impl
 *
 * Note: this operator class serves for multiple purposes:
 *
 *  1. canonizing conv reprs into NCBKernParam and NCBKernSizeParam, and
 *     subclasses should impl by overriding *_ncb methods
 *  2. providing a default impl for group conv by calling ncb_1g* methods
 *  3. providing a conv impl faster than naive under some cases
 *  4. providing a default impl for choosing heuristic algorithm, by using the
 *     first algo that fits the workspace limit
 */
class ConvBiasImpl : public naive::ConvBiasForwardImpl {
public:
    using naive::ConvBiasForwardImpl::ConvBiasForwardImpl;
    using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
38
    using AlgoDataType = detail::AlgoDataType;
39 40

    //! implemented by exec_with_ncb_kern()
M
Megvii Engine Team 已提交
41 42 43 44
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst, const PreprocessedFilter*,
            _megdnn_workspace workspace) override;
45
    bool is_thread_safe() const override { return true; }
46

M
Megvii Engine Team 已提交
47 48 49 50 51
    void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            _megdnn_tensor_in bias, const TensorLayout& z_layout,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) override;
52 53 54 55 56 57

    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;

M
Megvii Engine Team 已提交
58 59 60 61
    size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;
62

63
    //! implemented by get_workspace_with_ncb()
M
Megvii Engine Team 已提交
64 65 66 67
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            const PreprocessedFilter*) override;
68 69 70 71 72 73

    //! implemented by get_all_algorithms_with_ncb()
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;
74 75 76 77
    std::vector<Algorithm*> get_all_algorithms_safe(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;
78 79

    //! implemented by get_algorithm_heuristic_with_ncb()
80 81
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& filter,
M
Megvii Engine Team 已提交
82 83
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
84
            const AlgoAttribute& negative_attr) override;
85

86 87 88
    //! size param for kernels with non-contiguous batch
    struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
        NCBKernSizeParam() = default;
M
Megvii Engine Team 已提交
89 90 91
        NCBKernSizeParam(
                const ConvolutionImpl::NCBKernSizeParam& param, DType bias_type,
                ptrdiff_t bias_bs, BiasMode bias_mode, Param::NonlineMode nonlineMode)
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
                : ConvolutionImpl::NCBKernSizeParam(param),
                  bias_type{bias_type},
                  bias_bs{bias_bs},
                  bias_mode{bias_mode},
                  nonlineMode{nonlineMode} {}
        DType bias_type;
        //! stride for batch of bias
        ptrdiff_t bias_bs;
        BiasMode bias_mode;
        Param::NonlineMode nonlineMode;
    };

    //! memory param for kernels with non-contiguous batch
    struct NCBKernParam : public NCBKernSizeParam {
        NCBKernParam() = default;
107 108 109 110
        RefPtr src_ptr;
        RefPtr filter_ptr;
        RefPtr bias_ptr;
        RefPtr dst_ptr;
111 112 113 114 115 116
        void* workspace_ptr;
        size_t workspace_size;

        template <typename T>
        const T* src() const {
            src_type.assert_is_compatible_ctype<T>();
117
            return static_cast<const T*>(src_ptr.get_ptr());
118
        }
119 120 121 122 123
        //! when format is nchwxx, multi  channel will pack into one
        //! chnannel_pack_id. pack_channel_size is the number of packed channel
        //! when format is nchwxx and channel wise, multi group will pack into
        //! one group_pack_id. group_pack_size is the number of packed group
        //! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8}
124 125 126 127 128 129 130 131 132 133 134 135
        size_t src_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;

        size_t bias_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;

        size_t dst_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;

136
        template <typename T>
M
Megvii Engine Team 已提交
137 138 139
        const T* src(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
140 141

        template <typename T>
M
Megvii Engine Team 已提交
142 143 144
        const T* bias(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
145

146
        template <typename T>
M
Megvii Engine Team 已提交
147 148
        T* dst(size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
               size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
149 150 151 152

        //! when format is nchwxx and channel wise, multi group will pack into
        //! one group_pack_id. group_pack_size is the number of packed group
        //! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8}
153 154
        size_t filter_offset(size_t group_pack_id, size_t pack_group_size = 1_z) const;

155
        template <typename T>
M
Megvii Engine Team 已提交
156
        const T* filter(size_t group_pack_id, size_t pack_group_size = 1_z) const;
157

158 159 160
        template <typename T>
        const T* filter() const {
            filter_type.assert_is_compatible_ctype<T>();
161
            return static_cast<const T*>(filter_ptr.get_ptr());
162 163 164 165 166
        }

        template <typename T>
        const T* bias() const {
            bias_type.assert_is_compatible_ctype<T>();
167
            return static_cast<const T*>(bias_ptr.get_ptr());
168 169 170 171 172
        }

        template <typename T>
        T* dst() const {
            dst_type.assert_is_compatible_ctype<T>();
173
            return static_cast<T*>(dst_ptr.get_ptr());
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
        }

        template <typename T>
        T* workspace() const {
            return static_cast<T*>(workspace_ptr);
        }
    };
    /*!
     * \brief run-time id of a kernel invocation; this information is used for
     * getting the work data
     */
    struct NCBKernIndex {
        //! Thread id
        size_t thread_id{0};
        //! NOTE(review): presumably the position of this invocation inside
        //! the kernel's CpuNDRange -- confirm against the dispatcher
        CpuNDRange ndrange_id;
    };

    //! move arm_common to fallback
    virtual bool is_matmul_quantized_prefer(
192
            const ConvBiasImpl::NCBKernSizeParam& ncb_param) const {
193 194 195 196
        MEGDNN_MARK_USED_VAR(ncb_param);
        return true;
    };

M
Megvii Engine Team 已提交
197 198
    using ncb_kern_t = thin_function<void(
            const NCBKernParam& param, const NCBKernIndex& ncb_index)>;
199 200 201 202 203 204 205
    //! a kernel callable (ncb_kern_t) bundled with a CpuNDRange
    //! NOTE(review): global_size presumably describes the index space over
    //! which kern is dispatched -- confirm against the executor
    struct NCBKern {
        ncb_kern_t kern;  //!< conv kern parallel ptr
        CpuNDRange global_size;
    };

    class AlgoBase : public Algorithm {
    public:
M
Megvii Engine Team 已提交
206
        //! default-construct the Algorithm base and set m_handle_type to
        //! Handle::HandleType::FALLBACK
        AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; }
207 208 209 210 211 212 213 214 215 216 217

        /*!
         * \brief identifier of every conv bias algorithm
         *
         * Fallback and GI_COMMON entries are numbered consecutively starting
         * at 1 << 0. The architecture-specific entries start at 1 << 8; the
         * X86 and ARM lists are selected by mutually exclusive preprocessor
         * branches, so their values never coexist in one build.
         * Enumerator order determines the numeric values: append new entries
         * at the end of a group instead of inserting in the middle.
         */
        enum class AlgoType : uint32_t {
            //! fallback
            FB_NAIVE = 1 << 0,
            FB_WINOGRAD_F32,
            FB_WINOGRAD_4X4_F32,
            FB_WINOGRAD_QS8,
            FB_WINOGRAD_8X8_QS8,
            FB_CONV1x1,
            FB_CONV1x1_GEMV,
            FB_IM2COL,
            GI_COMMON_WINOGRAD_F23_4X4_FP32,
            GI_COMMON_WINOGRAD_F63_FP32,
            GI_COMMON_WINOGRAD_F43_FP32,
            GI_COMMON_WINOGRAD_F63_4X4_FP32,
            GI_COMMON_WINOGRAD_F43_4X4_FP32,
            GI_COMMON_WINOGRAD_F54_FP32,
            GI_COMMON_WINOGRAD_F45_FP32,
            GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F43_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16,
            GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16,
            GI_COMMON_DIRECT_FP32,
            GI_COMMON_DIRECT_STRD1_FP32,
            GI_COMMON_DIRECT_STRD2_FP32,
            GI_COMMON_DIRECT_NCHW44_FP32,
            GI_COMMON_DIRECT_NCHW_NCHW44_FP32,
            GI_COMMON_DIRECT_NCHW_NCHW44_AGENT_FP32,
            GI_COMMON_CHWNWISE_NCHW44_F32,

#if MEGDNN_X86
            X86_DIRECT = 1 << 8,
            X86_DIRECT_STRD2,
            X86_WINOGRAD_F63_8x8_F32,
            X86_WINOGRAD_F23_8x8_F32,
            X86_MKLDNN,
            X86_CHANWISE_AVX2_STRD1_QINT8,
            X86_CHANWISE_AVX2_STRD2_QINT8,
            X86_DIRECT_AVX2_STRD1_INT8,
            X86_DIRECT_AVX2_STRD2_INT8,
            X86_MKLDNN_QINT8,
            X86_MKLDNN_MATMUL_QINT8,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
            ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8,
            ARM_COMMON_WINOGRAD_F45_FP16,
            ARM_COMMON_WINOGRAD_F63_FP16,
            ARM_COMMON_WINOGRAD_F23_8X8_FP16,
            ARM_COMMON_DIRECT_FP16,
            ARM_COMMON_DIRECT_STRD1_FP16,
            ARM_COMMON_CHWNWISE_NCHW88_F16,
            ARM_COMMON_DIRECT_NCHW88_FP16,
            ARM_COMMON_DIRECT_NCHW_NCHW88_FP16,
            ARM_COMMON_DIRECT_STRD1_S8,
            ARM_COMMON_DIRECT_STRD2_S8,
            ARM_COMMON_DIRECT_NCHW44,
            ARM_COMMON_DIRECT_NCHW_NCHW44_S8,
            ARM_COMMON_CHANWISE_STRD1_NCHW44_S8,
            ARM_COMMON_CHANWISE_STRD2_NCHW44_S8,
            ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8,
            //! LARGE for large filter
            ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8,
            ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8,
            ARM_COMMON_DIRECT_STRD1_DOT_S8,
            ARM_COMMON_DIRECT_STRD2_DOT_S8,
            ARM_COMMON_DIRECT_NCHW44_DOT_S8,
            ARM_COMMON_WINOGRAD_F23_8X8_S8,
            ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32,
            ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8,
            ARM_COMMON_DIRECT_INT8X8X16,
            ARM_COMMON_DIRECT_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_STRD2_INT8X8X16,
            ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16,
            ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_STRD1_QU8,
            ARM_COMMON_DIRECT_STRD2_QU8,
            ARM_COMMON_DIRECT_STRD1_DOT_QU8,
            ARM_COMMON_DIRECT_STRD2_DOT_QU8,
#if MEGDNN_AARCH64
            AARCH64_DIRECT_STRD2_FP16,
            AARCH64_DIRECT_STRD2_FP32,
            AARCH64_MATMUL_S8,
            AARCH64_MATMUL_QU8,
#else
            ARMV7_MATMUL_S8,
            ARMV7_MATMUL_QU8,
#endif  // MEGDNN_AARCH64
#endif  // MEGDNN_X86 / MEGDNN_AARCH64 || MEGDNN_ARMV7
        };

299 300
        virtual ~AlgoBase() = default;
        virtual bool usable(
301
                const NCBKernSizeParam& param,
302
                AlgoSelectionStrategy algo_selection_strategy) const = 0;
303
        virtual size_t get_workspace(const NCBKernSizeParam& param) const = 0;
304 305

        virtual SmallVector<NCBKern> dispatch_kerns(
306
                const NCBKernSizeParam& param) const = 0;
307

308
        //! kernels for filter preprocessing; the default implementation
        //! returns an empty list
        virtual SmallVector<NCBKern> dispatch_preprocess_kerns(
                const NCBKernSizeParam&) const {
            return {};
        }

        //! get the layouts of weight_preprocess dst; the default
        //! implementation returns an empty list
        virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
                const NCBKernSizeParam&) const {
            return {};
        }

        //! get the workspace when weight_prerocess
320
        virtual size_t get_preprocess_workspace(const NCBKernSizeParam&) const {
321 322 323
            return 0_z;
        };

324 325
        //! Temporarily used to identify whether the matmul algorithm is
        //! is_preferred.
M
Megvii Engine Team 已提交
326
        virtual bool is_preferred(const NCBKernSizeParam&) const { return false; }
327

M
Megvii Engine Team 已提交
328 329 330 331 332
        bool usable_attribute(
                const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy,
                const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
                const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) const {
333 334
            return contain_attribute_all(positive_attr) &&
                   !contain_attribute_any(negative_attr) &&
335
                   usable(param, algo_selection_strategy);
336
        }
337 338 339

        //! get the type of the algo
        virtual ConvAlgoTypePack get_algo_type() const = 0;
340
        using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
341 342
    };

343
    using AlgoMapper = AlgoBase::Mapper;
344 345 346
    /**
     * \brief get all the algorithm for the opr.
     */
347
    virtual SmallVector<AlgoBase*> get_all_packed_algo();
348

349 350 351 352 353 354 355 356 357 358 359
    /**
     * \brief select algo according to input algo type
     */
    SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);

    /**
     * \brief suggest algo category according to the param
     */
    virtual SmallVector<AlgoCategory> suggest_algo_category_order(
            const NCBKernSizeParam& param) const;

360
protected:
M
Megvii Engine Team 已提交
361 362
    virtual void exec_with_ncb_kern(
            const NCBKernParam& param, ConvBiasImpl::Algorithm* algo);
363

M
Megvii Engine Team 已提交
364 365
    virtual void exec_preprocess_with_ncb_kern(
            const NCBKernParam& param, Algorithm* algo);
366

367 368 369 370 371
    virtual std::vector<Algorithm*> get_all_algorithms_with_ncb(
            const NCBKernSizeParam& param);

    //! heuristic algorithm selection on the canonized NCBKernSizeParam
    virtual Algorithm* get_algorithm_heuristic_with_ncb(
            const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
            const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr);
373 374 375 376 377 378

    const char* get_algorithm_set_name() const override;

private:
    class AlgoNaive;
    class AlgoIm2col;
379
    class AlgoConv1x1;
380
    class AlgoConv1x1Gemv;
381 382 383 384
    class AlgoWinogradF32;
    class AlgoWinogradF32_4x4;
    class AlgoWinogradQS8;
    class AlgoWinogradQS8_8x8;
385 386 387

    class AlgoFP32WinogradF23_4x4;
    class AlgoFP32WinogradF63;
388
    class AlgoFP32WinogradF43;
389
    class AlgoFP32WinogradF63_4x4;
390
    class AlgoFP32WinogradF43_4x4;
391 392 393
    class AlgoFP32WinogradF54;
    class AlgoFP32WinogradF45;
    class AlgoFP32WinogradF23_4x4_NCHW44;
394
    class AlgoFP32WinogradF43_4x4_NCHW44;
395 396 397
    class AlgoFP32WinogradF63_4x4_NCHW44;
    class AlgoFP32WinogradF73_4x4_NCHW44;

398
    class AlgoFP16WinogradF23_8x8_NCHW88;
399 400
    class AlgoFP16WinogradF43_8x8_NCHW88;

401 402 403 404
    class AlgoF32Direct;
    class AlgoF32DirectStride1;
    class AlgoF32DirectStride2;
    class AlgoF32DirectNCHWNCHW44;
405
    class AlgoF32DirectNCHWNCHW44AGENT;
406 407 408
    class AlgoF32ChannelWiseNCHW44;
    class AlgoF32DirectNCHW44;

409 410 411 412 413 414 415
    class AlgoPack;

    NCBKernSizeParam m_prev_selected_algo_sizep;
    Algorithm* m_prev_selected_algo = nullptr;

    bool is_naive_algo(ConvBiasImpl::Algorithm* algo);

416
    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
417

418 419 420 421 422
    //! get algorithm set by user or by heuristic
    Algorithm* get_algorithm(
            const NCBKernSizeParam& param,
            size_t workspace_size = std::numeric_limits<size_t>::max());

423 424 425 426 427 428
    NCBKernSizeParam make_ncb_kern_size_param(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter);

    //! build an NCBKernParam from the tensors, the workspace and the
    //! optional preprocessed filter
    NCBKernParam make_ncb_kern_param(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_out dst, _megdnn_workspace workspace,
            const PreprocessedFilter* preprocessed_filter);
432 433

    static const AlgoPack& algo_pack();
434 435
};

M
Megvii Engine Team 已提交
436 437
//! true when \p param carries a non-null preprocessed_filter that holds at
//! least one tensor
inline bool is_enable_filter_preprocess(const ConvBiasImpl::NCBKernSizeParam& param) {
    return param.preprocessed_filter && param.preprocessed_filter->tensors.size() >= 1;
}
439 440 441 442 443 444 445 446 447 448 449 450
}  // namespace fallback
}  // namespace megdnn

//! unpack NCBKernSizeParam into local variables (N, IC, IH, IW, ...)
//!
//! The argument is parenthesized at every use, so any expression yielding the
//! param (e.g. `*ptr`) expands correctly. Note that the argument is evaluated
//! once per unpacked field, so it should be free of side effects.
#define UNPACK_CONV_NCB_KERN_SIZES(_p)                                      \
    auto N = (_p).n, IC = (_p).filter_meta.icpg, IH = (_p).isz[0],          \
         IW = (_p).isz[1], OC = (_p).filter_meta.ocpg, OH = (_p).osz[0],    \
         OW = (_p).osz[1], FH = (_p).filter_meta.spatial[0],                \
         FW = (_p).filter_meta.spatial[1], SH = (_p).filter_meta.stride[0], \
         SW = (_p).filter_meta.stride[1], PH = (_p).filter_meta.padding[0], \
         PW = (_p).filter_meta.padding[1]

// vim: syntax=cpp.doxygen