/**
 * \file dnn/src/fallback/conv_bias/im2col/factory.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <unordered_map>
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/conv_bias/opr_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_im2col_factory_make_strategy)

namespace megdnn {
namespace fallback {
namespace im2col {

enum class StrategyType : uint32_t {
    FLOAT = 0,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    FLOAT_FP16 = 1,
#else
#if !MEGDNN_DISABLE_FLOAT16
    FLOAT16_FLOAT16 = 2,
#endif
#endif
    INT8x8x32 = 3,
    INT8x8x16 = 4,
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
    QUINT8x8x32 = 5,
    QUINT8x8x32x8 = 6,
#endif
    QINT8x8x32 = 7,
    QINT8x8x32x8 = 8
};

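//! Key describing one concrete im2col strategy: the dtypes carried in param,
//! the tensor format, the matmul pack mode, the matmul inner block sizes and
//! the filter shape together select a unique strategy instance.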
struct StrategyHashParam {
    bool is_xcorr;
    bool is_square;  //! kernel_h == kernel_w, stride_h == stride_w
    size_t block_m;
    size_t block_n;
    size_t block_k;
    size_t kernel;
    size_t stride;

    fallback::ConvBiasImpl::NCBKernSizeParam param;
    param::ConvBias::Format format;
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
};

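//! Hash functor for StrategyHashParam: every field is offset by base (so a
//! zero value still contributes) and XOR-ed into the result at a different
//! bit offset.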
struct StrategyHashParamHash {
    uint64_t operator()(const StrategyHashParam& sparam) const {
        constexpr uint64_t base = 1;  //! avoid a zero hash key
        uint64_t result =
                static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
        result = result ^
                 ((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
                  << 3);
        result = result ^
                 ((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
                   base)
                  << 6);
        result = result ^
                 ((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
                  << 9);
        result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
        result = result ^
                 ((static_cast<uint64_t>(sparam.packmode) + base) << 15);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
        result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
        result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
        result = result ^
                 ((static_cast<uint64_t>(sparam.is_square) + base) << 35);
        result = result ^
                 ((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
        return result;
    }
};

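//! Equality functor for StrategyHashParam, comparing exactly the fields that
//! feed the hash above.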
struct StrategyHashParamEqual {
    bool operator()(const StrategyHashParam& param1,
                    const StrategyHashParam& param2) const {
        bool flags = true;
        flags = param1.param.src_type == param2.param.src_type && flags;
        flags = param1.param.filter_type == param2.param.filter_type && flags;
        flags = param1.param.bias_type == param2.param.bias_type && flags;
        flags = param1.param.dst_type == param2.param.dst_type && flags;
        flags = param1.format == param2.format && flags;
        flags = param1.packmode == param2.packmode && flags;
        flags = param1.block_m == param2.block_m && flags;
        flags = param1.block_n == param2.block_n && flags;
        flags = param1.block_k == param2.block_k && flags;
        flags = param1.kernel == param2.kernel && flags;
        flags = param1.stride == param2.stride && flags;
        flags = param1.is_square == param2.is_square && flags;
        flags = param1.is_xcorr == param2.is_xcorr && flags;
        return flags;
    }
};

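//! Thread-safe cache of strategy objects: get() builds a StrategyHashParam
//! key from the kernel-size param and the matmul algo, creates the strategy
//! on first use and returns the cached instance afterwards.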
class StrategyDelegationStorage {
    std::mutex m_mtx;
    std::unordered_map<StrategyHashParam, std::unique_ptr<StrategyBase>,
                       StrategyHashParamHash, StrategyHashParamEqual>
            map_strategys;

public:
    ~StrategyDelegationStorage() = default;

    template <typename Strategy>
    Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                  const fallback::ConvBiasImpl::NCBKernSizeParam& param,
                  StrategyType stype);
};

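//! Factory maps a kernel-size param to a StrategyType and hands out cached
//! StrategyBase instances.  A minimal usage sketch (assuming `param` and
//! `matmul_algo` were prepared by the im2col ConvBias algorithm):
//!     StrategyBase* s = Factory::get_im2col_strategy(param, matmul_algo);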
class Factory {
public:
    static StrategyBase* get_im2col_strategy(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
        static StrategyDelegationStorage storage;
        StrategyType strategytype = get_strategy_type(param);
        return storage.get<StrategyBase>(matmul_algo, param, strategytype);
    }

    static StrategyType get_strategy_type(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
#define cb1(_dt, _post_ctype, _strategytype)                   \
    if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
        return _strategytype;                                  \
    }

#define cb2(_i_src_type, _i_bias_type, _i_dst_type, _src_ctype, _bias_ctype, \
            _dst_ctype, _strategytype)                                       \
    if (param.filter_type.enumv() == param.src_type.enumv() &&               \
        param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&          \
        param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {          \
        return _strategytype;                                                \
    }

        cb1(dt_float32, dt_float32, StrategyType::FLOAT);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16);
#else
#if !MEGDNN_DISABLE_FLOAT16
        cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16);
#endif
#endif

        cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
            StrategyType::INT8x8x32);

        cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
            StrategyType::INT8x8x16);

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32);

        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm,
            dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8);
#endif
        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32);

        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8,
            dt_int8, dt_int32, dt_int8, StrategyType::QINT8x8x32x8);
#undef cb1
#undef cb2
        megdnn_throw("unsupported data type for im2col strategy\n");
    }

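//! The cb1/cb2 helpers below are only expanded inside the make_*_strategy
//! functions: cb1 handles dtypes where src/dst share one ctype, cb2 handles
//! quantized/mixed dtypes.  Each expansion returns a std::unique_ptr to the
//! matching Strategy specialization when the dtypes match, and an empty
//! pointer otherwise (or when midout trims the instantiation away).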
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode,  \
            _midout_tag)                                              \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,        \
                 midout_iv(_midout_tag)) {                            \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {    \
            return std::make_unique<                                  \
                    Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
                             _postprocess_mode, PackMode::_packmode,  \
                             FormatMode::_format>>();                 \
        }                                                             \
    }                                                                 \
    MIDOUT_END();                                                     \
    return {};

#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
            _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode,     \
            _midout_tag)                                                \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,          \
                 midout_iv(_midout_tag)) {                              \
        if (param.filter_type.enumv() == param.src_type.enumv() &&      \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            return std::make_unique<Strategy<                           \
                    _src_ctype, _bias_ctype, _dst_ctype,                \
                    DTypeTrait<_i_bias_type>::ctype,                    \
                    DTypeTrait<_i_dst_type>::ctype, _postprocess_mode,  \
                    PackMode::_packmode, FormatMode::_format>>();       \
        }                                                               \
    }                                                                   \
    MIDOUT_END();                                                       \
    return {};

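    //! Build a strategy for matmul kernels using the DEFAULT pack mode.  On
    //! AArch64/ARMv7 a few NCHW44/NCHW44_DOT shapes are routed to fused
    //! im2col+pack strategies; everything else falls back to the generic
    //! Strategy template via cb1/cb2.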
    static std::unique_ptr<StrategyBase> make_default_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        MEGDNN_MARK_USED_VAR(matmul_algo);
        param::ConvBias::Format format = param.filter_meta.format;
        switch (strategytype) {
            case StrategyType::FLOAT:
                if (format == param::ConvBias::Format::NCHW) {
                    cb1(NCHW, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyType::FLOAT"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
                    auto matmul_block = matmul_algo->get_inner_block_size();
                    //! optimize NCHW44 3x3s2 aarch64 8x12x1 and armv7 4x12x1
                    //! im2col+pack fused strategy
                    if ((matmul_block.m == 8 || matmul_block.m == 4) &&
                        matmul_block.n == 12 && matmul_block.k == 1 &&
                        param.filter_meta.spatial[0] == 3 &&
                        param.filter_meta.spatial[1] == 3 &&
                        param.filter_meta.stride[0] == 2 &&
                        param.filter_meta.stride[1] == 2 &&
                        !param.filter_meta.should_flip) {
                        MIDOUT_BEGIN(
                                megdnn_fallback_im2col_factory_make_strategy,
                                midout_iv(
                                        "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) {
                            return std::make_unique<
                                    StrategyFuseXx12x1Nchw44K3x3S2<
                                            float, float,
                                            PostprocessMode::FLOAT>>();
                        }
                        MIDOUT_END();
                        return {};
                    }
#endif

                    cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyTypeNCHW44::FLOAT"_hash);
                } else {
                    megdnn_throw(
                            ssprintf("only NCHW44/NCHW layouts are supported "
                                     "by the im2col algo, but got %d\n",
                                     uint32_t(format)));
                }
                break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "DefaultStrategyType::FLOAT_FP16"_hash);
                break;
#else
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
                cb1(NCHW, DEFAULT, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::FLOAT16_FLOAT16"_hash);
                break;
#endif
#endif
            case StrategyType::INT8x8x32:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else {
                    megdnn_throw(
                            ssprintf("only NCHW44/NCHW/NCHW44_DOT layouts are "
                                     "supported by the im2col algo, "
                                     "but got %d\n",
                                     uint32_t(format)));
                }

                break;

            case StrategyType::INT8x8x16:
                cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::INT8x8x16"_hash);
                break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "DefaultStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
            case StrategyType::QINT8x8x32:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                        PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeNCHW44::QINT8x8x32"_hash);
                } else {
                    megdnn_throw(
                            ssprintf("only NCHW44/NCHW/NCHW44_DOT layouts are "
                                     "supported by the im2col algo, "
                                     "but got %d\n",
                                     uint32_t(format)));
                }
                break;

            case StrategyType::QINT8x8x32x8:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                        PostprocessMode::QUANTIZED,
                        "DefaultStrategyType::QINT8x8x32x8"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    if (format == param::ConvBias::Format::NCHW44) {
                        //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse
#if MEGDNN_AARCH64
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        if (matmul_block.m == 4 && matmul_block.n == 4 &&
                            matmul_block.k == 16 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 1 &&
                            param.filter_meta.stride[1] == 1 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_4x4x16"_hash)) {
                                return std::make_unique<
                                        StrategyFuse4x4x16Nchw44<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
                    } else {
#if MEGDNN_AARCH64
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse
                        if (matmul_block.m == 8 && matmul_block.n == 12 &&
                            matmul_block.k == 4 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 1 &&
                            param.filter_meta.stride[1] == 1 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_8x12x4"_hash)) {
                                return std::make_unique<
                                        StrategyFuse8x12x4Nchw44Dot<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
#if MEGDNN_ARMV7
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        if (matmul_block.m == 8 && matmul_block.n == 4 &&
                            matmul_block.k == 4 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 2 &&
                            param.filter_meta.stride[1] == 2 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_8x4x4_s2"_hash)) {
                                return std::make_unique<
                                        StrategyFuse8x4x4Nchw44DotK3x3S2<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
                    }
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
                        dt_int32, dt_int8, PostprocessMode::QUANTIZED,
                        "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
                } else {
                    megdnn_throw(ssprintf("only NCHW44/NCHW/NCHW44_DOT "
                                          "layouts are supported by the "
                                          "im2col algo, but got %d\n",
                                          uint32_t(format)));
                }
                break;
        }
        megdnn_throw(ssprintf("Unsupported strategy type %u in default mode",
                              uint32_t(strategytype)));
    }

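    //! Build a strategy for matmul kernels using the NO_PACK mode; only
    //! NCHW layouts are handled here.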
    static std::unique_ptr<StrategyBase> make_nopack_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
                cb1(NCHW, NO_PACK, dt_float32, dt_float32,
                    PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
                break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#else
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
                cb1(NCHW, NO_PACK, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
                break;
#endif
#endif
            case StrategyType::INT8x8x16:
                cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::INT8x8x16"_hash);
                break;
            case StrategyType::INT8x8x32:
                cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::INT8x8x32"_hash);
                break;
            default:
                megdnn_throw(
                        ssprintf("Unsupported strategy type %u in no_pack mode",
                                 uint32_t(strategytype)));
                break;
        }
        megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode",
                              uint32_t(strategytype)));
    }

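    //! Build a strategy for matmul kernels using the ONLY_PACKA mode;
    //! currently only dt_float32 with NCHW layout is supported.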
    static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
                cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32,
                    PostprocessMode::FLOAT,
                    "OnlyPackaStrategyType::FLOAT"_hash);
                break;
            default:
                megdnn_throw(ssprintf(
                        "Unsupported strategy type %u in onlypacka mode",
                        uint32_t(strategytype)));
                break;
        }
        megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode",
                              uint32_t(strategytype)));
    }

#undef cb1
#undef cb2

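    //! Dispatch on the matmul pack mode and delegate to the corresponding
    //! make_*_strategy helper.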
    static std::unique_ptr<StrategyBase> make_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::MatrixMulImpl::AlgoBase::PackMode packmode,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType stype) {
        switch (packmode) {
            case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
                return make_default_strategy(matmul_algo, param, stype);
                break;
            case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
                return make_onlypacka_strategy(matmul_algo, param, stype);
                break;
            case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
                return make_nopack_strategy(matmul_algo, param, stype);
                break;
            default:
                megdnn_throw(
                        "unsupported pack mode, only DEFAULT, ONLY_PACKA "
                        "and NO_PACK are supported");
                break;
        }
        megdnn_throw("im2col factory failed to make a strategy");
    }
};

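//! Look up (or lazily create) the strategy matching the given matmul algo and
//! kernel-size param.  The cache key covers everything the factory dispatches
//! on, so two calls with equivalent configurations share one strategy object.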
template <typename Strategy>
Strategy* StrategyDelegationStorage::get(
        fallback::MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::ConvBiasImpl::NCBKernSizeParam& param,
        StrategyType stype) {
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode =
            matmul_algo->packmode();
    //! in NO_PACK mode, block_m/block_n/block_k stay zero
    size_t block_m = 0, block_n = 0, block_k = 0;
    if (packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
        packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
        block_m = matmul_algo->get_inner_block_size().m;
        block_n = matmul_algo->get_inner_block_size().n;
        block_k = matmul_algo->get_inner_block_size().k;
    }
    StrategyHashParam sparam;
    sparam.param = param;
    sparam.format = param.filter_meta.format;
    sparam.packmode = packmode;
    sparam.block_m = block_m;
    sparam.block_n = block_n;
    sparam.block_k = block_k;
    sparam.kernel = param.filter_meta.spatial[0];
    sparam.stride = param.filter_meta.stride[0];
    sparam.is_square =
            param.filter_meta.spatial[0] == param.filter_meta.spatial[1];
    sparam.is_xcorr = param.filter_meta.should_flip;
    MEGDNN_LOCK_GUARD(m_mtx);
    if (map_strategys.find(sparam) == map_strategys.end()) {
        auto strategy =
                Factory::make_strategy(matmul_algo, packmode, param, stype);
        map_strategys[sparam] = std::move(strategy);
    }
    return static_cast<Strategy*>(map_strategys[sparam].get());
}
}  // namespace im2col
}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen