factory.h 22.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/**
 * \file dnn/src/fallback/conv_bias/im2col/factory.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <unordered_map>
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/conv_bias/opr_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_im2col_factory_make_strategy)

namespace megdnn {
namespace fallback {
namespace im2col {

//! Identifies which data-type combination an im2col strategy handles.
//!
//! Every enumerator has an explicit value, so the numbering stays the same
//! regardless of which entries are compiled out by the preprocessor guards.
enum class StrategyType : uint32_t {
    FLOAT = 0,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    FLOAT_FP16 = 1,
#else
#if !MEGDNN_DISABLE_FLOAT16
    FLOAT16_FLOAT16 = 2,
#endif
#endif
    INT8x8x32 = 3,
    INT8x8x16 = 4,
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
    //! asymmetric quantized uint8 variants, only built for ARM targets
    QUINT8x8x32 = 5,
    QUINT8x8x32x8 = 6,
#endif
    QINT8x8x32 = 7,
    QINT8x8x32x8 = 8
};

//! Key describing one strategy configuration; hashed by
//! StrategyHashParamHash and compared by StrategyHashParamEqual.
struct StrategyHashParam {
    bool is_xcorr;
    bool is_square;  //! kernel_h == kernel_w, stride_h == stride_w
    //! inner block sizes of the matmul algorithm
    size_t block_m;
    size_t block_n;
    size_t block_k;
    //! kernel size and stride along the first spatial dimension
    size_t kernel;
    size_t stride;

    //! full kernel size parameter; only its dtypes take part in hashing and
    //! equality
    fallback::ConvBiasImpl::NCBKernSizeParam param;
    param::ConvBias::Format format;
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
};

struct StrategyHashParamHash {
58 59 60 61
    uint64_t operator()(const StrategyHashParam& sparam) const {
        constexpr uint64_t base = 1;  //! avoid hashkey is zero
        uint64_t result =
                static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
62
        result = result ^
63
                 ((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
64 65
                  << 3);
        result = result ^
66
                 ((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
67 68 69
                   base)
                  << 6);
        result = result ^
70
                 ((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
71
                  << 9);
72
        result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
73
        result = result ^
74 75 76 77 78 79 80 81 82
                 ((static_cast<uint64_t>(sparam.packmode) + base) << 15);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
        result =
                result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
        result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
        result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
83
        result = result ^
84
                 ((static_cast<uint64_t>(sparam.is_square) + base) << 35);
85
        result = result ^
86
                 ((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
87 88 89 90 91
        return result;
    };
};

struct StrategyHashParamEqual {
92 93
    bool operator()(const StrategyHashParam& param1,
                    const StrategyHashParam& param2) const {
94 95 96 97 98 99 100 101 102 103
        bool flags = true;
        flags = param1.param.src_type == param2.param.src_type && flags;
        flags = param1.param.filter_type == param2.param.filter_type && flags;
        flags = param1.param.bias_type == param2.param.bias_type && flags;
        flags = param1.param.dst_type == param2.param.dst_type && flags;
        flags = param1.format == param2.format && flags;
        flags = param1.packmode == param2.packmode && flags;
        flags = param1.block_m == param2.block_m && flags;
        flags = param1.block_n == param2.block_n && flags;
        flags = param1.block_k == param2.block_k && flags;
104 105 106 107
        flags = param1.kernel == param2.kernel && flags;
        flags = param1.stride == param2.stride && flags;
        flags = param1.is_square == param2.is_square && flags;
        flags = param1.is_xcorr == param2.is_xcorr && flags;
108 109 110 111 112 113 114 115 116 117 118 119 120 121
        return flags;
    };
};

class StrategyDelegationStorage {
    std::mutex m_mtx;
    std::unordered_map<StrategyHashParam, std::unique_ptr<StrategyBase>,
                       StrategyHashParamHash, StrategyHashParamEqual>
            map_strategys;

public:
    ~StrategyDelegationStorage() = default;

    template <typename Strategy>
122
    Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
123 124 125 126 127 128 129 130
                  const fallback::ConvBiasImpl::NCBKernSizeParam& param,
                  StrategyType stype);
};

class Factory {
public:
    //! Look up (creating on first use) the im2col strategy for \p param and
    //! \p matmul_algo.
    //!
    //! The returned pointer is non-owning: the strategy lives in a
    //! function-local static StrategyDelegationStorage.
    static StrategyBase* get_im2col_strategy(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
        static StrategyDelegationStorage storage;
        const StrategyType strategytype = get_strategy_type(param);
        return storage.get<StrategyBase>(matmul_algo, param, strategytype);
    }

    //! Map the dtypes in \p param to a StrategyType.
    //!
    //! The checks run in the order written and the first match is returned.
    //! Float types are selected from the filter dtype alone (cb1); integer
    //! and quantized types additionally require filter and src dtypes to be
    //! equal and the dst dtype to match (cb2). Throws through megdnn_throw
    //! when nothing matches.
    static StrategyType get_strategy_type(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
//! _post_ctype is accepted for symmetry with the factory macros but unused
#define cb1(_dt, _post_ctype, _strategytype)                   \
    if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
        return _strategytype;                                  \
    }

//! only _i_src_type and _i_dst_type take part in the check
#define cb2(_i_src_type, _i_bias_type, _i_dst_type, _src_ctype, _bias_ctype, \
            _dst_ctype, _strategytype)                                       \
    if (param.filter_type.enumv() == param.src_type.enumv() &&               \
        param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&          \
        param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {          \
        return _strategytype;                                                \
    }

        cb1(dt_float32, dt_float32, StrategyType::FLOAT);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16);
#else
#if !MEGDNN_DISABLE_FLOAT16
        cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16);
#endif
#endif

        cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
            StrategyType::INT8x8x32);

        cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
            StrategyType::INT8x8x16);

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32);

        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm,
            dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8);
#endif
        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32);

        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8,
            dt_int8, dt_int32, dt_int8, StrategyType::QINT8x8x32x8);
#undef cb1
#undef cb2
        megdnn_throw("not support datatype in im2col strategy\n");
    }

184 185 186 187 188 189 190 191 192 193 194 195
//! Helper macros shared by the make_*_strategy functions below.
//!
//! cb1: when the filter dtype equals _dt, return a newly created Strategy
//! that uses _dt for src/bias/dst and _post_ctype for the two post-process
//! ctypes.
//! cb2: when filter and src dtypes are equal, src matches _i_src_type and dst
//! matches _i_dst_type, return a newly created Strategy with the given
//! ctypes.
//!
//! Both macros end with an unconditional `return {};`, so a case body that
//! invokes them always leaves the function: with a strategy when the dtype
//! check matches, with an empty pointer otherwise. Statements written after
//! an invocation on the same path are never reached.
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode,  \
            _midout_tag)                                              \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,        \
                 midout_iv(_midout_tag)) {                            \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {    \
            return std::make_unique<                                  \
                    Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
                             _postprocess_mode, PackMode::_packmode,  \
                             FormatMode::_format>>();                 \
        }                                                             \
    }                                                                 \
    MIDOUT_END();                                                     \
    return {};

#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
            _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode,     \
            _midout_tag)                                                \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,          \
                 midout_iv(_midout_tag)) {                              \
        if (param.filter_type.enumv() == param.src_type.enumv() &&      \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            return std::make_unique<Strategy<                           \
                    _src_ctype, _bias_ctype, _dst_ctype,                \
                    DTypeTrait<_i_bias_type>::ctype,                    \
                    DTypeTrait<_i_dst_type>::ctype, _postprocess_mode,  \
                    PackMode::_packmode, FormatMode::_format>>();       \
        }                                                               \
    }                                                                   \
    MIDOUT_END();                                                       \
    return {};

    static std::unique_ptr<StrategyBase> make_default_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
219
            StrategyType strategytype) {
220
        MEGDNN_MARK_USED_VAR(matmul_algo);
221
        param::ConvBias::Format format = param.filter_meta.format;
222 223
        switch (strategytype) {
            case StrategyType::FLOAT:
224 225
                cb1(NCHW, DEFAULT, dt_float32, dt_float32,
                    PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash);
226
                break;
227 228 229 230 231 232
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "DefaultStrategyType::FLOAT_FP16"_hash);
                break;
#else
233 234
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
235
                cb1(NCHW, DEFAULT, dt_float16, dt_float16,
236 237 238
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::FLOAT16_FLOAT16"_hash);
                break;
239
#endif
240 241
#endif
            case StrategyType::INT8x8x32:
242 243 244 245 246 247 248 249 250 251 252 253
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }

254 255 256
                break;

            case StrategyType::INT8x8x16:
257 258
                cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
259 260
                    "DefaultStrategyType::INT8x8x16"_hash);
                break;
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "DefaultStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
276
            case StrategyType::QINT8x8x32:
277 278 279 280 281 282 283 284 285 286 287 288 289
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                        PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }
290 291 292
                break;

            case StrategyType::QINT8x8x32x8:
293 294 295 296 297 298 299 300 301 302 303 304 305
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                        PostprocessMode::QUANTIZED,
                        "DefaultStrategyType::QINT8x8x32x8"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
                        dt_int32, dt_int8, PostprocessMode::QUANTIZED,
                        "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
                } else {
                    megdnn_throw("not support format except nchw44 and nchw\n");
                }
306 307 308 309 310 311 312 313
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

    static std::unique_ptr<StrategyBase> make_nopack_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
314
            StrategyType strategytype) {
315 316 317
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
318 319
                cb1(NCHW, NO_PACK, dt_float32, dt_float32,
                    PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
320
                break;
321 322 323 324 325 326
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "NoPackStrategyType::FLOAT_FP16"_hash);
                break;
#else
327 328
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
329 330
                cb1(NCHW, NO_PACK, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
331 332
                    "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
                break;
333
#endif
334 335
#endif
            case StrategyType::INT8x8x32:
336 337
                cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
338 339 340 341
                    "NoPackStrategyType::INT8x8x32"_hash);
                break;

            case StrategyType::INT8x8x16:
342 343
                cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
344 345 346
                    "NoPackStrategyType::INT8x8x16"_hash);
                break;

347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "NoPackStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
362
            case StrategyType::QINT8x8x32:
363
                cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
364 365 366 367 368 369
                    dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::QINT8x8x32"_hash);
                break;

            case StrategyType::QINT8x8x32x8:
370
                cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
371 372 373 374 375 376 377 378 379 380 381
                    dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                    PostprocessMode::QUANTIZED,
                    "NoPackStrategyType::QINT8x8x32x8"_hash);
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

    static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
382
            StrategyType strategytype) {
383 384 385
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
386 387
                cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32,
                    PostprocessMode::FLOAT,
388 389
                    "OnlyPackaStrategyType::FLOAT"_hash);
                break;
390 391 392 393 394 395 396
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, ONLY_PACKA, dt_float16, __fp16,
                    PostprocessMode::FLOAT,
                    "OnlyPackaStrategyType::FLOAT_FP16"_hash);
                break;
#else
397 398
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
399
                cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16,
400 401 402
                    PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash);
                break;
403
#endif
404 405
#endif
            case StrategyType::INT8x8x32:
406 407
                cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
408 409 410 411
                    "OnlyPackaStrategyType::INT8x8x32"_hash);
                break;

            case StrategyType::INT8x8x16:
412 413
                cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
414 415 416
                    "OnlyPackaStrategyType::INT8x8x16"_hash);
                break;

417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
                    dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::QUINT8x8x32"_hash);
                break;

            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
                    dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8,
                    dt_int32, dt_uint8, PostprocessMode::QUANTIZED,
                    "OnlyPackaStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
432
            case StrategyType::QINT8x8x32:
433
                cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
434 435 436 437 438 439
                    dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "OnlyPackaStrategyType::QINT8x8x32"_hash);
                break;

            case StrategyType::QINT8x8x32x8:
440
                cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
                    dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                    PostprocessMode::QUANTIZED,
                    "OnlyPackaStrategyType::QINT8x8x32x8"_hash);
                break;
        }
        megdnn_throw("error not support strategy type ");
    }

#undef cb1
#undef cb2

    //! Dispatch on \p packmode to the matching make_*_strategy function.
    //!
    //! \param matmul_algo forwarded to the selected factory function
    //! \param packmode pack mode that selects the factory function
    //! \param param kernel size parameter, forwarded unchanged
    //! \param stype data-type combination, forwarded unchanged
    //! \return whatever the selected factory function returns
    //!
    //! Throws through megdnn_throw for any other pack mode.
    static std::unique_ptr<StrategyBase> make_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::MatrixMulImpl::AlgoBase::PackMode packmode,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType stype) {
        switch (packmode) {
            case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
                return make_default_strategy(matmul_algo, param, stype);
            case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
                return make_onlypacka_strategy(matmul_algo, param, stype);
            case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
                return make_nopack_strategy(matmul_algo, param, stype);
            default:
                megdnn_throw(
                        "not support packmode except default onlypackA "
                        "nopack");
                break;
        }
        megdnn_throw("factory make Strategy error please check your code");
    }
};

//! Return the cached strategy matching \p matmul_algo and \p param, creating
//! and storing it through Factory::make_strategy on first use.
//!
//! \param matmul_algo supplies the pack mode and, for DEFAULT / ONLY_PACKA,
//!        the inner block sizes that become part of the cache key
//! \param param kernel size parameter; its dtypes, format and filter
//!        geometry become part of the cache key
//! \param stype strategy type passed to the factory when a new strategy has
//!        to be created
//! \return non-owning pointer to the stored strategy, cast to \c Strategy*;
//!         this is null when the factory returned an empty pointer
//!
//! The map is accessed with m_mtx held.
template <typename Strategy>
Strategy* StrategyDelegationStorage::get(
        fallback::MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::ConvBiasImpl::NCBKernSizeParam& param,
        StrategyType stype) {
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode =
            matmul_algo->packmode();
    //! nopack mode block_m block_n block_k is zero
    size_t block_m = 0, block_n = 0, block_k = 0;
    if (packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
        packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
        block_m = matmul_algo->get_inner_block_size().m;
        block_n = matmul_algo->get_inner_block_size().n;
        block_k = matmul_algo->get_inner_block_size().k;
    }

    const auto& filter_meta = param.filter_meta;
    StrategyHashParam sparam;
    sparam.param = param;
    sparam.format = filter_meta.format;
    sparam.packmode = packmode;
    sparam.block_m = block_m;
    sparam.block_n = block_n;
    sparam.block_k = block_k;
    sparam.kernel = filter_meta.spatial[0];
    sparam.stride = filter_meta.stride[0];
    //! kernel_h == kernel_w and stride_h == stride_w. Only index 0 of kernel
    //! and stride is stored in the key, so this flag is what separates
    //! square from non-square configurations. Comparing spatial[0] with
    //! itself would always be true.
    sparam.is_square = filter_meta.spatial[0] == filter_meta.spatial[1] &&
                       filter_meta.stride[0] == filter_meta.stride[1];
    sparam.is_xcorr = filter_meta.should_flip;

    MEGDNN_LOCK_GUARD(m_mtx);
    //! single lookup; the strategy is built before insertion, so a throwing
    //! factory leaves the map unchanged
    auto iter = map_strategys.find(sparam);
    if (iter == map_strategys.end()) {
        auto strategy =
                Factory::make_strategy(matmul_algo, packmode, param, stype);
        iter = map_strategys.emplace(sparam, std::move(strategy)).first;
    }
    return static_cast<Strategy*>(iter->second.get());
}
}  // namespace im2col
}  // namespace fallback
}  // namespace megdnn
515 516

// vim: syntax=cpp.doxygen