algos.cpp 21.9 KB
Newer Older
1
#include "megdnn/opr_param_defs.h"
2

3 4
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
5
#include "src/fallback/conv_bias/im2col/algos.h"
M
Megvii Engine Team 已提交
6 7
#include "src/fallback/conv_bias/im2col/factory.h"
#include "src/fallback/conv_bias/im2col/im2col_kerns.h"
8 9
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/convolution/helper.h"
10

11
#include "midout.h"
12

13 14 15 16
MIDOUT_DECL(megdnn_fallback_im2col)

using namespace megdnn;
using namespace fallback;
17
using namespace im2col;
18

19 20
namespace {
//! Build the matmul kernel-size descriptor for one im2col output tile of
//! oc_tile_size x ohw_tile_size.
//! A is the (M x K) filter tile, B is the (K x N) im2col-ed input tile.
static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param(
        const fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t ohw_tile_size,
        size_t oc_tile_size) {
    const auto conv_format = param.filter_meta.format;
    const size_t pack_oc_size = pack_size(conv_format);

    //! map the convolution layout onto the matching packed matmul layout
    auto format = param::MatrixMul::Format::DEFAULT;
    if (conv_format == param::ConvBias::Format::NCHW44) {
        format = param::MatrixMul::Format::MK4;
    } else if (conv_format == param::ConvBias::Format::NCHW44_DOT) {
        format = param::MatrixMul::Format::MK4_DOT;
    } else if (conv_format == param::ConvBias::Format::NCHW88) {
        format = param::MatrixMul::Format::MK8;
    }

    const size_t M = oc_tile_size;
    const size_t N = ohw_tile_size;
    const size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] *
                     param.filter_meta.spatial[1];
    const size_t LDA = pack_oc_size * K;
    const size_t LDB = pack_oc_size * N;
    const size_t LDC = N * pack_oc_size;

    //! an 8bit quantized dst needs a wider accumulator, so the matmul dst
    //! type follows the bias type in that case
    const bool is_dst_8bit =
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
            (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
             param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);

    return {param.filter_type,
            param.src_type,
            is_dst_8bit ? param.bias_type : param.dst_type,
            M,
            N,
            K,
            LDA,
            LDB,
            LDC,
            false,
            false,
            param::MatrixMul::ComputeMode::DEFAULT,
            format};
}

56
static void choice_ohw_oc_block(
M
Megvii Engine Team 已提交
57 58 59
        const fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& oc_tile_size,
        size_t& ohw_tile_size, size_t block_m, size_t block_n,
        const size_t m_ohw_tile_size,
60 61 62
        fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) {
    //! calculate m_oc_tile_size in choice_ohw_oc_block() fucntion,
    //! when ohw_tile_size < this value ohw_tile_size = ohw
63
    size_t DEFAULT_OHW_MIN_TILE_SIZE = round_up(static_cast<size_t>(32), block_n);
64 65
    //! when nr_threads > 1 and round(ohw,nr_threads)>nr_threads,
    //! oc_tile_size = DEFAULT_OC_TILE_SIZE
66
    size_t DEFAULT_OC_TILE_SIZE = round_up(static_cast<size_t>(512), block_m);
67 68
    //! when oc_tile_size > this value m_oc_tile_size =
    //! DEFAULT_OC_MAX_TILE_SIZE
69
    size_t DEFAULT_OC_MAX_TILE_SIZE = round_up(static_cast<size_t>(1024), block_m);
70 71
    //! when oc_tile_size < this value oc_tile_size =
    //! DEFAULT_OC_MIN_TILE_SIZE the purpose is aligning the calculation
72
    size_t DEFAULT_OC_MIN_TILE_SIZE = round_up(static_cast<size_t>(128), block_m);
73 74 75
    size_t nr_threads = param.nr_threads;
    size_t OC = param.filter_meta.ocpg;
    size_t ohw = param.osz[0] * param.osz[1];
76 77
    oc_tile_size = DEFAULT_OC_TILE_SIZE;
    ohw_tile_size = m_ohw_tile_size;
78

79 80
    oc_tile_size = std::min(oc_tile_size, OC);
    ohw_tile_size = std::min(ohw_tile_size, ohw);
81 82

    if (nr_threads > 1) {
83 84 85 86 87 88 89 90 91
        if (ohw / ohw_tile_size < nr_threads) {
            ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
            if (ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
                ohw_tile_size = ohw;
                oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
                if (oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
                    oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
                } else if (oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
                    oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
92 93 94 95
                }
            }
        }
    } else {
96 97
        //! in no_pack mode don't do block operation when using single thread
        if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
98 99
            ohw_tile_size = ohw;
            oc_tile_size = OC;
100 101 102 103
        }
    }
}

104 105 106 107 108
//! Size in bytes of the packed-A (pre-packed filter) buffer for one group,
//! which depends on the matmul algo's pack mode.
static size_t packA_group_size(
        const MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::MatrixMulImpl::KernSizeParam& matmul_param,
        const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
        size_t packa_parallel_times) {
    using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
    switch (matmul_desc.packmode) {
        case PackMode::DEFAULT:
            //! one packA buffer covers the whole group
            return matmul_algo->get_bundle(matmul_param).get_size(0);
        case PackMode::ONLY_PACKA:
            //! one packA buffer per oc tile
            return packa_parallel_times *
                   matmul_algo->get_bundle(matmul_param).get_size(0);
        default:
            megdnn_assert(matmul_desc.packmode == PackMode::NO_PACK);
            //! nopack mode return 0;
            return 0;
    }
}

//! Per-thread workspace bundle, dispatched on the matmul pack mode.
static WorkspaceBundle get_thread_bundle(
        const fallback::ConvBiasImpl::NCBKernSizeParam& param,
        const MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::MatrixMulImpl::KernSizeParam& matmul_param,
        const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
        size_t oc_tile_size, size_t ohw_tile_size) {
    const auto pack_mode = matmul_desc.packmode;
    if (pack_mode == Pack_Mode::DEFAULT) {
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_bundle_dft"_hash)) {
            Im2colKerns<Pack_Mode::DEFAULT> kerns;
            return kerns.get_thread_bundle(
                    param, matmul_param, matmul_algo, ohw_tile_size, oc_tile_size);
        }
        MIDOUT_END();
    } else if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_bundle_onlypacka"_hash)) {
            Im2colKerns<Pack_Mode::ONLY_PACKA> kerns;
            return kerns.get_thread_bundle(
                    param, matmul_param, matmul_algo, ohw_tile_size, oc_tile_size);
        }
        MIDOUT_END();
    } else {
        megdnn_assert(
                pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK);
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_thread_bundle_nopack"_hash)) {
            Im2colKerns<Pack_Mode::NO_PACK> kerns;
            return kerns.get_thread_bundle(
                    param, matmul_param, matmul_algo, ohw_tile_size, oc_tile_size);
        }
        MIDOUT_END();
    }
    //! unreachable: every pack mode is handled above
    return {nullptr, {}};
}

//! Total workspace bundle for the whole algorithm:
//! {padded-src buffer, packed-A buffer, nr_threads * per-thread workspace}.
static WorkspaceBundle get_bundle(
        const fallback::ConvBiasImpl::NCBKernSizeParam& param,
        MatrixMulImpl::AlgoBase* matmul_algo, size_t oc_tile_size,
        size_t ohw_tile_size) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SW);
    MEGDNN_MARK_USED_VAR(SH);

    //! fixed: the original assigned IW2 from IH/PH and IH2 from IW/PW; only
    //! the product IH2 * IW2 is used below so the computed size was
    //! unchanged, but the names were swapped and misleading
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool no_need_padding = (PH == 0 && PW == 0);
    size_t padding = 0, packa_size = 0, packa_group_size = 0;
    size_t nr_threads = param.nr_threads;
    size_t GROUP = param.filter_meta.group;
    fallback::MatrixMulImpl::AlgoBase::MatmulDescription matmul_desc =
            matmul_algo->matmul_description();
    bool default_pack = matmul_desc.packmode == Pack_Mode::DEFAULT;

    //! packmode is default should use oc
    //! packmode is onlypackA should use oc_tile_size
    auto im2col_kern_param = get_matmul_kern_param(
            param, ohw_tile_size, default_pack ? OC : oc_tile_size);
    if (is_enable_filter_preprocess(param)) {
        //! the packed filter lives in the preprocessed tensor, so no packA
        //! workspace is needed
        packa_group_size = 0;
    } else {
        size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        packa_group_size = packA_group_size(
                matmul_algo, im2col_kern_param, matmul_desc, oc_parallel_times);
    }

    if (!no_need_padding) {
        //! NOTE(review): sizeof(param.src_type) is the size of the DType
        //! object, not the element byte width -- kept as-is to preserve the
        //! allocated workspace size; confirm whether src_type.size() was
        //! intended
        padding = (GROUP * N * IC * IH2 * IW2) * sizeof(param.src_type);
    }

    packa_size = GROUP * packa_group_size;  //! for packA  size = GROUP * a_size

    WorkspaceBundle ws = get_thread_bundle(
            param, matmul_algo, im2col_kern_param, matmul_desc, oc_tile_size,
            ohw_tile_size);
    return {nullptr, {padding, packa_size, ws.total_size_in_bytes() * nr_threads}};
}

214 215
}  // namespace

M
Megvii Engine Team 已提交
216
size_t ConvBiasImpl::AlgoIm2col::get_workspace(const NCBKernSizeParam& p) const {
217
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 0) {
218 219 220
        fallback::MatrixMulImpl::AlgoBase::MatmulDescription matmul_desc =
                m_matmul_algo->matmul_description();
        size_t oc_tile_size = 0, ohw_tile_size = 0;
M
Megvii Engine Team 已提交
221 222 223
        choice_ohw_oc_block(
                p, oc_tile_size, ohw_tile_size, matmul_desc.innerblocksize.m,
                matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);
224 225
        return get_bundle(p, m_matmul_algo, oc_tile_size, ohw_tile_size)
                .total_size_in_bytes();
226 227 228 229 230 231
    }
    MIDOUT_END();
    return 0;
}

//! Create the NCB kernels that run im2col + matmul for this convolution,
//! dispatched on the bound matmul algo's pack mode.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) {
        size_t OH = param.osz[0];
        size_t OW = param.osz[1];
        size_t OC = param.filter_meta.ocpg;
        size_t ohw = OH * OW;
        size_t oc_tile_size = 0, ohw_tile_size = 0;

        auto matmul_desc = m_matmul_algo->matmul_description();

        bool default_pack = matmul_desc.packmode == Pack_Mode::DEFAULT;
        bool no_pack = matmul_desc.packmode == Pack_Mode::NO_PACK;
        bool only_packA = matmul_desc.packmode == Pack_Mode::ONLY_PACKA;
        bool enable_filter_preprocess = is_enable_filter_preprocess(param);
        choice_ohw_oc_block(
                param, oc_tile_size, ohw_tile_size, matmul_desc.innerblocksize.m,
                matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);

        //! parallelism of the packA stage; NO_PACK has no packA stage
        size_t packa_parallel_times = 0;
        size_t pack_oc_size = pack_size(param.filter_meta.format);
        if (only_packA) {
            packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        } else if (default_pack) {
            packa_parallel_times = div_ceil<size_t>(OC, matmul_desc.innerblocksize.m);
        }

        //! packmode is default should use oc
        //! packmode is onlypackA should use oc_tile_size
        auto matmul_param = get_matmul_kern_param(
                param, ohw_tile_size, default_pack ? OC : oc_tile_size);

        WorkspaceBundle bundle =
                get_bundle(param, m_matmul_algo, oc_tile_size, ohw_tile_size);
        WorkspaceBundle bundle_thread = get_thread_bundle(
                param, m_matmul_algo, matmul_param, matmul_desc, oc_tile_size,
                ohw_tile_size);

        StrategyParam strategyparam;
        strategyparam.ohw = ohw;
        strategyparam.is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        strategyparam.is_ohw_size_bigger = (ohw_tile_size >= ohw);
        //! when one tile covers the whole output and no 8bit narrowing is
        //! needed, results can be written straight into dst
        strategyparam.skip_copy_dst =
                strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
        strategyparam.oc_tile_size = oc_tile_size;
        strategyparam.pack_oc_size = pack_oc_size;
        strategyparam.enable_filter_preprocess = enable_filter_preprocess;
        strategyparam.packA_group_size = packA_group_size(
                m_matmul_algo, matmul_param, matmul_desc, packa_parallel_times);

        //! fixed: removed an unused local `SmallVector<NCBKern> ret_kern`
        //! that was declared here but never populated or returned
        StrategyBase* im2colstrategy =
                Factory::get_im2col_strategy(param, m_matmul_algo);
        if (default_pack) {
            MIDOUT_BEGIN(
                    megdnn_fallback_im2col,
                    midout_iv("dispatch_kerns_default_pack"_hash)) {
                return Im2colKerns<Pack_Mode::DEFAULT>().get_kerns(
                        param, bundle, bundle_thread, strategyparam, matmul_param,
                        im2colstrategy, m_matmul_algo, ohw_tile_size, oc_tile_size,
                        pack_oc_size);
            }
            MIDOUT_END();
            return {};
        } else if (only_packA) {
            MIDOUT_BEGIN(
                    megdnn_fallback_im2col,
                    midout_iv("dispatch_kerns_onlypacka"_hash)) {
                return Im2colKerns<Pack_Mode::ONLY_PACKA>().get_kerns(
                        param, bundle, bundle_thread, strategyparam, matmul_param,
                        im2colstrategy, m_matmul_algo, ohw_tile_size, oc_tile_size,
                        pack_oc_size);
            }
            MIDOUT_END();
            return {};
        } else if (no_pack) {
            MIDOUT_BEGIN(
                    megdnn_fallback_im2col, midout_iv("dispatch_kerns_no_pack"_hash)) {
                return Im2colKerns<Pack_Mode::NO_PACK>().get_kerns(
                        param, bundle, bundle_thread, strategyparam, matmul_param,
                        im2colstrategy, m_matmul_algo, ohw_tile_size, oc_tile_size,
                        pack_oc_size);
            }
            MIDOUT_END();
            return {};
        }
        return {};
    }
    MIDOUT_END();
    return {};
}

bool ConvBiasImpl::AlgoIm2col::usable(
M
Megvii Engine Team 已提交
326
        const NCBKernSizeParam& param,
327 328
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) {
329
        auto format = param.filter_meta.format;
330 331
        auto matmul_desc = m_matmul_algo->matmul_description();
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
332
        if (format != param::ConvBias::Format::NCHW &&
333
            format != param::ConvBias::Format::NCHW44 &&
334 335
            format != param::ConvBias::Format::NCHW44_DOT &&
            format != param::ConvBias::Format::NCHW88) {
336 337
            return false;
        }
338 339 340 341 342
        if (format == param::ConvBias::Format::NCHW88) {
            if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
                return false;
            }
        }
343 344 345 346 347 348
        if (format == param::ConvBias::Format::NCHW44 ||
            format == param::ConvBias::Format::NCHW44_DOT) {
            //! current NCHW44 im2col only support DEFAULT mode matmul
            if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
                return false;
                //! nchw44 hybird mode and channel wise is not support
M
Megvii Engine Team 已提交
349 350 351
            } else if (
                    param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
                    param.filter_meta.ocpg == 1) {
352 353 354 355
                return false;
            }
        }
#else
356 357
        if (format != param::ConvBias::Format::NCHW &&
            format != param::ConvBias::Format::NCHW44) {
358 359
            return false;
        }
360 361 362 363 364 365 366 367 368 369 370
        if (format == param::ConvBias::Format::NCHW44) {
            //! current NCHW44 im2col only support DEFAULT mode matmul
            if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
                return false;
                //! nchw44 hybird mode and channel wise is not support
            } else if (
                    param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
                    param.filter_meta.ocpg == 1) {
                return false;
            }
        }
371 372 373 374 375
#endif
        if (param.src_type.enumv() != param.filter_type.enumv() ||
            (param.src_type.enumv() != DTypeEnum::Int8 &&
             param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
             param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
376
#if !MEGDNN_DISABLE_FLOAT16
377
             param.src_type.enumv() != DTypeEnum::Float16 &&
378
#endif
379
             param.src_type.enumv() != DTypeEnum::Float32)) {
380 381
            return false;
        }
382 383 384 385 386 387 388 389

        //! x86 disable  Quntized8Asymm
#if MEGDNN_X86
        if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
            return false;
        }
#endif

390 391 392 393 394 395 396 397 398
        //! 8x8x32 and 8x8x8 and NO_PACK is not supported
        if (matmul_desc.packmode == Pack_Mode::NO_PACK &&
            param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
            param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
            (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.dst_type.enumv() == DTypeEnum::QuantizedS32)) {
            return false;
        }

399
        //! make sure 8x8x16 and 8x8x32 biasmode is  nobias and nonlineMode is
400 401
        //! identity otherwise return false mean that 8x8x32 and 8x8x16 not
        //! support PostProcess
402 403 404
        if (param.dst_type.enumv() == DTypeEnum::Int16 ||
            param.dst_type.enumv() == DTypeEnum::Int32 ||
            param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
405
            if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
406 407
                return false;
            }
408
        }
409
        size_t oc_tile_size = 0, ohw_tile_size = 0;
M
Megvii Engine Team 已提交
410 411 412
        choice_ohw_oc_block(
                param, oc_tile_size, ohw_tile_size, matmul_desc.innerblocksize.m,
                matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);
413
        fallback::MatrixMulImpl::KernSizeParam matmul_param =
414
                get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
415 416
        bool matmulusable = m_matmul_algo->usable(matmul_param);
        return matmulusable &&
M
Megvii Engine Team 已提交
417
               (!(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
418
                  param.filter_meta.spatial[0] == 1 &&
419 420
                  param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                  param.filter_meta.stride[0] == 1)) &&
M
Megvii Engine Team 已提交
421
               (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
422 423 424 425 426 427 428
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
    }
    MIDOUT_END();
    return false;
}

M
Megvii Engine Team 已提交
429
//! Layout of the pre-packed filter tensor: one {GROUP, packa_group_size}
//! byte buffer; empty when the matmul does no packing.
SmallVector<TensorLayout> ConvBiasImpl::AlgoIm2col::deduce_preprocessed_filter_layout(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_fallback_im2col,
            midout_iv("deduce_preprocessed_filter_layout"_hash)) {
        auto matmul_desc = m_matmul_algo->matmul_description();

        //! only support default_pack and only_packa mode
        if (matmul_desc.packmode == Pack_Mode::NO_PACK) {
            return {};
        }

        bool default_pack = matmul_desc.packmode == Pack_Mode::DEFAULT;
        size_t GROUP = param.filter_meta.group;
        size_t OC = param.filter_meta.ocpg;

        size_t oc_tile_size = 0, ohw_tile_size = 0;
        choice_ohw_oc_block(
                param, oc_tile_size, ohw_tile_size, matmul_desc.innerblocksize.m,
                matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);
        auto matmul_param = get_matmul_kern_param(
                param, ohw_tile_size, default_pack ? OC : oc_tile_size);

        size_t packa_parallel_times = div_ceil<size_t>(
                OC, default_pack ? matmul_desc.innerblocksize.m : oc_tile_size);
        size_t packa_group_size = packA_group_size(
                m_matmul_algo, matmul_param, matmul_desc, packa_parallel_times);

        //! the packed filter is opaque bytes, hence dtype::Int8
        SmallVector<TensorLayout> preprocessed_layouts;
        preprocessed_layouts.push_back({{GROUP, packa_group_size}, dtype::Int8()});
        return preprocessed_layouts;
    }
    MIDOUT_END();
    return {};
}

M
Megvii Engine Team 已提交
466
//! Kernels that pre-pack the filter (packA) ahead of execution; empty for
//! pack modes with nothing to preprocess.
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_preprocess_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 3) {
        size_t OC = param.filter_meta.ocpg;
        size_t GROUP = param.filter_meta.group;
        size_t oc_tile_size = 0, ohw_tile_size = 0;
        fallback::MatrixMulImpl::AlgoBase::MatmulDescription matmul_desc =
                m_matmul_algo->matmul_description();
        choice_ohw_oc_block(
                param, oc_tile_size, ohw_tile_size, matmul_desc.innerblocksize.m,
                matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);
        WorkspaceBundle bundle =
                get_bundle(param, m_matmul_algo, oc_tile_size, ohw_tile_size);

        bool default_pack = matmul_desc.packmode == Pack_Mode::DEFAULT;
        bool only_packA = matmul_desc.packmode == Pack_Mode::ONLY_PACKA;
        size_t packa_parallel_times = 0;
        if (only_packA) {
            packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        } else if (default_pack) {
            packa_parallel_times = div_ceil<size_t>(OC, matmul_desc.innerblocksize.m);
        } else {
            //! NO_PACK: nothing to preprocess
            return {};
        }

        auto matmul_param = get_matmul_kern_param(
                param, ohw_tile_size, default_pack ? OC : oc_tile_size);

        StrategyParam strategyparam;
        strategyparam.enable_filter_preprocess = is_enable_filter_preprocess(param);
        strategyparam.packA_group_size = packA_group_size(
                m_matmul_algo, matmul_param, matmul_desc, packa_parallel_times);
        StrategyBase* im2colstrategy =
                Factory::get_im2col_strategy(param, m_matmul_algo);

        //! mutable: the captured bundle copy is re-pointed at each
        //! invocation's workspace before running packA
        auto kern_packA = [bundle, matmul_algo = m_matmul_algo, matmul_param,
                           im2colstrategy, strategyparam = strategyparam,
                           matmul_desc = matmul_desc](
                                  const NCBKernParam& param,
                                  const NCBKernIndex& ncb_index) mutable {
            bundle.set(param.workspace_ptr);
            im2colstrategy->packA_kern(
                    bundle, param, matmul_param, matmul_algo, ncb_index, matmul_desc,
                    strategyparam);
        };

        SmallVector<ConvBiasImpl::NCBKern> ret_kern;
        ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
        return ret_kern;
    }
    MIDOUT_END();
    return {};
}

520
// vim: syntax=cpp.doxygen