/**
 * \file dnn/src/fallback/conv_bias/im2col/factory.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <unordered_map>
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/conv_bias/opr_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_im2col_factory_make_strategy)

namespace megdnn {
namespace fallback {
namespace im2col {

enum class StrategyType : uint32_t {
    FLOAT = 0,
26 27 28
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    FLOAT_FP16 = 1,
#else
29 30
#if !MEGDNN_DISABLE_FLOAT16
    FLOAT16_FLOAT16 = 2,
31
#endif
32 33 34
#endif
    INT8x8x32 = 3,
    INT8x8x16 = 4,
35 36 37 38
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
    QUINT8x8x32 = 5,
    QUINT8x8x32x8 = 6,
#endif
39 40 41 42 43
    QINT8x8x32 = 7,
    QINT8x8x32x8 = 8
};

struct StrategyHashParam {
44 45
    bool is_xcorr;
    bool is_square;  //! kernel_h == kernel_w, stride_h = stride_w
46 47 48
    size_t block_m;
    size_t block_n;
    size_t block_k;
49 50 51 52 53 54
    size_t kernel;
    size_t stride;

    fallback::ConvBiasImpl::NCBKernSizeParam param;
    param::ConvBias::Format format;
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
55 56 57
};

struct StrategyHashParamHash {
58 59 60 61
    uint64_t operator()(const StrategyHashParam& sparam) const {
        constexpr uint64_t base = 1;  //! avoid hashkey is zero
        uint64_t result =
                static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
62
        result = result ^
63
                 ((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
64 65
                  << 3);
        result = result ^
66
                 ((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
67 68 69
                   base)
                  << 6);
        result = result ^
70
                 ((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
71
                  << 9);
72
        result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
73
        result = result ^
74 75 76 77 78 79 80 81 82
                 ((static_cast<uint64_t>(sparam.packmode) + base) << 15);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
        result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
        result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
83
        result = result ^
84
                 ((static_cast<uint64_t>(sparam.is_square) + base) << 35);
85
        result = result ^
86
                 ((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
87 88 89 90 91
        return result;
    };
};

struct StrategyHashParamEqual {
92 93
    bool operator()(const StrategyHashParam& param1,
                    const StrategyHashParam& param2) const {
94 95 96 97 98 99 100 101 102 103
        bool flags = true;
        flags = param1.param.src_type == param2.param.src_type && flags;
        flags = param1.param.filter_type == param2.param.filter_type && flags;
        flags = param1.param.bias_type == param2.param.bias_type && flags;
        flags = param1.param.dst_type == param2.param.dst_type && flags;
        flags = param1.format == param2.format && flags;
        flags = param1.packmode == param2.packmode && flags;
        flags = param1.block_m == param2.block_m && flags;
        flags = param1.block_n == param2.block_n && flags;
        flags = param1.block_k == param2.block_k && flags;
104 105 106 107
        flags = param1.kernel == param2.kernel && flags;
        flags = param1.stride == param2.stride && flags;
        flags = param1.is_square == param2.is_square && flags;
        flags = param1.is_xcorr == param2.is_xcorr && flags;
108 109 110 111 112 113 114 115 116 117 118 119 120 121
        return flags;
    };
};

class StrategyDelegationStorage {
    std::mutex m_mtx;
    std::unordered_map<StrategyHashParam, std::unique_ptr<StrategyBase>,
                       StrategyHashParamHash, StrategyHashParamEqual>
            map_strategys;

public:
    ~StrategyDelegationStorage() = default;

    template <typename Strategy>
122
    Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
123 124 125 126 127 128 129 130
                  const fallback::ConvBiasImpl::NCBKernSizeParam& param,
                  StrategyType stype);
};

class Factory {
public:
    static StrategyBase* get_im2col_strategy(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
131
            fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
132 133
        static StrategyDelegationStorage storage;
        StrategyType strategytype = get_strategy_type(param);
134
        return storage.get<StrategyBase>(matmul_algo, param, strategytype);
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
    }

    static StrategyType get_strategy_type(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
#define cb1(_dt, _post_ctype, _strategytype)                   \
    if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
        return _strategytype;                                  \
    }

#define cb2(_i_src_type, _i_bias_type, _i_dst_type, _src_ctype, _bias_ctype, \
            _dst_ctype, _strategytype)                                       \
    if (param.filter_type.enumv() == param.src_type.enumv() &&               \
        param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&          \
        param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {          \
        return _strategytype;                                                \
    }

        cb1(dt_float32, dt_float32, StrategyType::FLOAT);
153 154 155
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16);
#else
156 157
#if !MEGDNN_DISABLE_FLOAT16
        cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16);
158
#endif
159 160 161 162 163 164 165 166
#endif

        cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
            StrategyType::INT8x8x32);

        cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
            StrategyType::INT8x8x16);

167 168 169 170 171 172 173
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32);

        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm,
            dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8);
#endif
174 175 176 177 178 179 180 181 182 183
        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32);

        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8,
            dt_int8, dt_int32, dt_int8, StrategyType::QINT8x8x32x8);
#undef cb1
#undef cb2
        megdnn_throw("not support datatype in im2col strategy\n");
    }

184 185 186 187 188 189 190 191 192 193 194 195
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode,  \
            _midout_tag)                                              \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,        \
                 midout_iv(_midout_tag)) {                            \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {    \
            return std::make_unique<                                  \
                    Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
                             _postprocess_mode, PackMode::_packmode,  \
                             FormatMode::_format>>();                 \
        }                                                             \
    }                                                                 \
    MIDOUT_END();                                                     \
196 197
    return {};

198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
            _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode,     \
            _midout_tag)                                                \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,          \
                 midout_iv(_midout_tag)) {                              \
        if (param.filter_type.enumv() == param.src_type.enumv() &&      \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            return std::make_unique<Strategy<                           \
                    _src_ctype, _bias_ctype, _dst_ctype,                \
                    DTypeTrait<_i_bias_type>::ctype,                    \
                    DTypeTrait<_i_dst_type>::ctype, _postprocess_mode,  \
                    PackMode::_packmode, FormatMode::_format>>();       \
        }                                                               \
    }                                                                   \
    MIDOUT_END();                                                       \
214 215 216 217 218
    return {};

    static std::unique_ptr<StrategyBase> make_default_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
219
            StrategyType strategytype) {
220
        MEGDNN_MARK_USED_VAR(matmul_algo);
221
        param::ConvBias::Format format = param.filter_meta.format;
222 223
        switch (strategytype) {
            case StrategyType::FLOAT:
224 225 226 227 228 229 230 231 232 233 234
                if (format == param::ConvBias::Format::NCHW) {
                    cb1(NCHW, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyType::FLOAT"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyTypeNCHW44::FLOAT"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }
235
                break;
236 237 238 239 240 241
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "DefaultStrategyType::FLOAT_FP16"_hash);
                break;
#else
242 243
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
244
                cb1(NCHW, DEFAULT, dt_float16, dt_float16,
245 246 247
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::FLOAT16_FLOAT16"_hash);
                break;
248
#endif
249 250
#endif
            case StrategyType::INT8x8x32:
251 252 253 254 255 256 257 258 259 260 261 262
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }

263 264 265
                break;

            case StrategyType::INT8x8x16:
266 267
                cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
268 269
                    "DefaultStrategyType::INT8x8x16"_hash);
                break;
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "DefaultStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
285
            case StrategyType::QINT8x8x32:
286 287 288 289 290 291 292 293 294 295 296 297 298
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                        PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }
299 300 301
                break;

            case StrategyType::QINT8x8x32x8:
302 303 304 305 306 307 308 309 310 311 312 313 314
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                        PostprocessMode::QUANTIZED,
                        "DefaultStrategyType::QINT8x8x32x8"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
                        dt_int32, dt_int8, PostprocessMode::QUANTIZED,
                        "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }
315 316 317 318 319 320 321 322
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

    static std::unique_ptr<StrategyBase> make_nopack_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
323
            StrategyType strategytype) {
324 325 326
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
327 328
                cb1(NCHW, NO_PACK, dt_float32, dt_float32,
                    PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
329
                break;
330 331 332 333 334 335
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "NoPackStrategyType::FLOAT_FP16"_hash);
                break;
#else
336 337
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
338 339
                cb1(NCHW, NO_PACK, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
340 341
                    "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
                break;
342
#endif
343 344
#endif
            case StrategyType::INT8x8x32:
345 346
                cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
347 348 349 350
                    "NoPackStrategyType::INT8x8x32"_hash);
                break;

            case StrategyType::INT8x8x16:
351 352
                cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
353 354 355
                    "NoPackStrategyType::INT8x8x16"_hash);
                break;

356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "NoPackStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
371
            case StrategyType::QINT8x8x32:
372
                cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
373 374 375 376 377 378
                    dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::QINT8x8x32"_hash);
                break;

            case StrategyType::QINT8x8x32x8:
379
                cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
380 381 382 383 384 385 386 387 388 389 390
                    dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                    PostprocessMode::QUANTIZED,
                    "NoPackStrategyType::QINT8x8x32x8"_hash);
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

    static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
391
            StrategyType strategytype) {
392 393 394
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
395 396
                cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32,
                    PostprocessMode::FLOAT,
397 398
                    "OnlyPackaStrategyType::FLOAT"_hash);
                break;
399 400 401 402 403 404 405
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, ONLY_PACKA, dt_float16, __fp16,
                    PostprocessMode::FLOAT,
                    "OnlyPackaStrategyType::FLOAT_FP16"_hash);
                break;
#else
406 407
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
408
                cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16,
409 410 411
                    PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash);
                break;
412
#endif
413 414
#endif
            case StrategyType::INT8x8x32:
415 416
                cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
417 418 419 420
                    "OnlyPackaStrategyType::INT8x8x32"_hash);
                break;

            case StrategyType::INT8x8x16:
421 422
                cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
423 424 425
                    "OnlyPackaStrategyType::INT8x8x16"_hash);
                break;

426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
                    dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
                    dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8,
                    dt_int32, dt_uint8, PostprocessMode::QUANTIZED,
                    "OnlyPackaStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
441
            case StrategyType::QINT8x8x32:
442
                cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
443 444 445 446 447 448
                    dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::QINT8x8x32"_hash);
                break;

            case StrategyType::QINT8x8x32x8:
449
                cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467
                    dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                    PostprocessMode::QUANTIZED,
                    "OnlyPackaStrategyType::QINT8x8x32x8"_hash);
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

#undef cb1
#undef cb2

    static std::unique_ptr<StrategyBase> make_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::MatrixMulImpl::AlgoBase::PackMode packmode,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType stype) {
        switch (packmode) {
            case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
468
                return make_default_strategy(matmul_algo, param, stype);
469 470
                break;
            case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
471
                return make_onlypacka_strategy(matmul_algo, param, stype);
472 473
                break;
            case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
474
                return make_nopack_strategy(matmul_algo, param, stype);
475 476 477 478 479 480 481
                break;
            default:
                megdnn_throw(
                        "not support packmode except default onlypackA "
                        "nopack");
                break;
        }
482
        megdnn_throw("factory make Strategy error please check your code");
483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502
    }
};

template <typename Strategy>
Strategy* StrategyDelegationStorage::get(
        fallback::MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::ConvBiasImpl::NCBKernSizeParam& param,
        StrategyType stype) {
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode =
            matmul_algo->packmode();
    //! nopack mode block_m block_n block_k is zero
    size_t block_m = 0, block_n = 0, block_k = 0;
    if (packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
        packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
        block_m = matmul_algo->get_inner_block_size().m;
        block_n = matmul_algo->get_inner_block_size().n;
        block_k = matmul_algo->get_inner_block_size().k;
    }
    StrategyHashParam sparam;
    sparam.param = param;
503
    sparam.format = param.filter_meta.format;
504 505 506 507
    sparam.packmode = packmode;
    sparam.block_m = block_m;
    sparam.block_n = block_n;
    sparam.block_k = block_k;
508 509 510 511 512 513
    sparam.kernel = param.filter_meta.spatial[0];
    sparam.stride = param.filter_meta.stride[0];
    sparam.is_square =
            param.filter_meta.spatial[0] == param.filter_meta.spatial[0];
    sparam.is_xcorr = param.filter_meta.should_flip;
    MEGDNN_LOCK_GUARD(m_mtx);
514
    if (map_strategys.find(sparam) == map_strategys.end()) {
515 516
        auto strategy =
                Factory::make_strategy(matmul_algo, packmode, param, stype);
517 518 519 520 521 522 523
        map_strategys[sparam] = std::move(strategy);
    }
    return static_cast<Strategy*>(map_strategys[sparam].get());
}
}  // namespace im2col
}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen