algos.cpp 32.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/**
 * \file dnn/src/fallback/conv_bias/im2col/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/fallback/conv_bias/im2col/algos.h"
13
#include "src/fallback/conv_bias/im2col/factory.h"
14 15 16 17 18
#include "megdnn/opr_param_defs.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/convolution/helper.h"
19

20
#include "midout.h"
21

22 23 24 25
MIDOUT_DECL(megdnn_fallback_im2col)

using namespace megdnn;
using namespace fallback;
26
using namespace im2col;
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

/*======================== AlgoIm2col=======================*/
/*!
 * \brief Indices of the sub-workspaces inside the im2col workspace bundle,
 * through which the needed pointer of each part can be fetched conveniently.
 */
struct Im2colBundelIndex {
    //! slot holding the zero-padded copy of the input (empty if PH==PW==0)
    static constexpr size_t BUNDLE_PADDING_INDEX = 0_z;
    //! slot holding the packed weight matrix (matrix A), one region per group
    static constexpr size_t BUNDLE_PACKA_INDEX = 1_z;
    //! slot holding the per-thread scratch bundles, nr_threads regions
    static constexpr size_t BUNDLE_THREAD_INDEX = 2_z;
};

using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode;

//! Process one input channel copy padding
static void copy_padding_kern(WorkspaceBundle bundle,
                              const ConvBiasImpl::NCBKernParam& param,
44
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
45 46
                              StrategyBase* im2colstrategy, size_t pack_oc_size) {
    im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size);
47
}
48

49
//! packA_kern
50 51 52 53 54 55 56 57 58
static void packA_kern(
        WorkspaceBundle bundle,
        const fallback::ConvBiasImpl::NCBKernParam& param,
        fallback::MatrixMulImpl::KernSizeParam matmulparam,
        fallback::MatrixMulImpl::AlgoBase* matmul_algo,
        const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
        StrategyBase* im2colstrategy,
        const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
        size_t pack_oc_size) {
59
    im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
60
                               ncb_index, matmul_desc, pack_oc_size);
61
}
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

/*!
 * *\brief Im2colKerns collects all the im2col kerns in it
 *
 * Each specialization matches one matmul PackMode and provides:
 *   - kerns():             the per-tile conv kernel (im2col + matmul + post)
 *   - get_thread_bundle(): the per-thread scratch workspace layout
 */

template <Pack_Mode packmode>
class Im2colKerns;

template <>
class Im2colKerns<Pack_Mode::DEFAULT> {
public:
    //! conv kernel; ndrange layout is [batch, group, ohw-tile, oc-tile]
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
            StrategyParam strategyparam,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, StrategyBase* im2colstrategy) {
        size_t OC = param.filter_meta.ocpg;
        //! the last tile along OH*OW (resp. OC) may be smaller than the
        //! nominal tile size, so clamp against the remaining extent
        size_t output_block_size = std::min(
                ohw_tile_size,
                strategyparam.ohw - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                strategyparam.oc_tile_size,
                OC - ncb_index.ndrange_id[3] * strategyparam.oc_tile_size);

        //! record this tile's position for the strategy calls below
        strategyparam.batch_id = ncb_index.ndrange_id[0];
        strategyparam.group_id = ncb_index.ndrange_id[1];
        strategyparam.oc_cur_index =
                ncb_index.ndrange_id[3] *
                strategyparam.oc_tile_size;
        strategyparam.oc_end_index = strategyparam.oc_cur_index +
                                     output_block_oc_size;
        strategyparam.ohw_cur_index =
                ncb_index.ndrange_id[2] * ohw_tile_size;
        strategyparam.output_block_oc_size = output_block_oc_size;
        strategyparam.output_block_size = output_block_size;

        bundle.set(param.workspace_ptr);
        //! each thread owns one scratch bundle inside BUNDLE_THREAD_INDEX,
        //! offset by its thread id
        bundle_thread.set(
                static_cast<int8_t*>(
                        bundle.get(Im2colBundelIndex::BUNDLE_THREAD_INDEX)) +
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id);
        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;

        //! 1.Im2col
        im2colstrategy->exec_im2col(bundle, bundle_thread, strategyparam, param,
                                    matmul_param, matmul_algo);

        //! 2.packb and matmul compute
        im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
                                    matmul_param, matmul_algo, ncb_index,
                                    matmul_desc);

        //! 3.postprocess and copy dst if need
        im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
    }

    //! per-thread scratch layout: {packb, im2col/matmul-dst, bias_temp}
    WorkspaceBundle get_thread_bundle(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::KernSizeParam im2col_kern_param,
            MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size,
            size_t oc_tile_size) {
        size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
               FW = param.filter_meta.spatial[1];
        size_t pack_oc_size = pack_size(param.filter_meta.format);
        size_t im2col = 0, packb = 0, bias_temp = 0;
        bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
        megdnn_assert(default_pack, "only support default packa");
        size_t im2col_dst_size =
                IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
        size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size *
                                 sizeof(param.bias_type);
        //! matmul_dst and im2col_dst use the same memory
        WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
        packb = wb.get_size(1);
        im2col = std::max(im2col_dst_size, matmul_dst_size);
        if (param.bias_mode == megdnn::BiasMode::BIAS) {
            //! BIAS mode needs a temporary buffer for the broadcast bias
            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
        }
        return {nullptr, {packb, im2col, bias_temp}};
    }
};

template <>
class Im2colKerns<Pack_Mode::ONLY_PACKA> {
public:
    //! conv kernel; ndrange layout is [batch, group, ohw-tile, oc-tile]
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
            StrategyParam strategyparam,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, StrategyBase* im2colstrategy) {
        size_t OC = param.filter_meta.ocpg;
        //! clamp the last (possibly partial) tile against the remaining extent
        size_t output_block_size = std::min(
                ohw_tile_size,
                strategyparam.ohw - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                strategyparam.oc_tile_size,
                OC - ncb_index.ndrange_id[3] * strategyparam.oc_tile_size);

        bundle.set(param.workspace_ptr);
        //! per-thread scratch region inside BUNDLE_THREAD_INDEX
        bundle_thread.set(
                static_cast<int8_t*>(
                        bundle.get(Im2colBundelIndex::BUNDLE_THREAD_INDEX)) +
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id);

        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;

        //! record this tile's position for the strategy calls below
        strategyparam.batch_id = ncb_index.ndrange_id[0];
        strategyparam.group_id = ncb_index.ndrange_id[1];
        strategyparam.oc_cur_index =
                ncb_index.ndrange_id[3] *
                strategyparam.oc_tile_size;
        strategyparam.oc_end_index = strategyparam.oc_cur_index +
                                     output_block_oc_size;
        strategyparam.ohw_cur_index =
                ncb_index.ndrange_id[2] * ohw_tile_size;
        strategyparam.output_block_oc_size = output_block_oc_size;
        strategyparam.output_block_size = output_block_size;

        //! 1.Im2col
        im2colstrategy->exec_im2col(bundle, bundle_thread, strategyparam, param,
                                    matmul_param, matmul_algo);

        //! 2.packb and matmul compute
        im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
                                    matmul_param, matmul_algo, ncb_index,
                                    matmul_desc);

        //! 3.postprocess and copy dst if need
        im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
    }
    //! per-thread scratch layout: {packb, im2col, matmul_dst, bias_temp};
    //! unlike DEFAULT mode, im2col and matmul dst are separate buffers here
    WorkspaceBundle get_thread_bundle(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::KernSizeParam im2col_kern_param,
            MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size,
            size_t oc_tile_size) {
        size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
               FW = param.filter_meta.spatial[1];

        size_t im2col = 0, packb = 0, matmul_dst = 0, bias_temp = 0;
        bool only_packA = matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
        megdnn_assert(only_packA, "onlysupport onlypackA mode");
        size_t im2col_dst_size =
                IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
        size_t matmul_dst_size =
                oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
        WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
        packb = wb.get_size(1);
        im2col = im2col_dst_size;
        matmul_dst = matmul_dst_size;
        if (param.bias_mode == megdnn::BiasMode::BIAS) {
            //! BIAS mode needs a temporary buffer for the broadcast bias
            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
        }

        return {nullptr, {packb, im2col, matmul_dst, bias_temp}};
    }
};

template <>
class Im2colKerns<Pack_Mode::NO_PACK> {
public:
    //! conv kernel; ndrange layout is [batch, group, ohw-tile, oc-tile]
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
            StrategyParam strategyparam,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, StrategyBase* im2colstrategy) {
        size_t OC = param.filter_meta.ocpg;
        //! clamp the last (possibly partial) tile against the remaining extent
        size_t output_block_size = std::min(
                ohw_tile_size,
                strategyparam.ohw - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                strategyparam.oc_tile_size,
                OC - ncb_index.ndrange_id[3] * strategyparam.oc_tile_size);

        //! record this tile's position for the strategy calls below
        strategyparam.batch_id = ncb_index.ndrange_id[0];
        strategyparam.group_id = ncb_index.ndrange_id[1];
        strategyparam.oc_cur_index =
                ncb_index.ndrange_id[3] *
                strategyparam.oc_tile_size;
        strategyparam.oc_end_index = strategyparam.oc_cur_index +
                                     output_block_oc_size;
        strategyparam.ohw_cur_index =
                ncb_index.ndrange_id[2] * ohw_tile_size;
        strategyparam.output_block_oc_size = output_block_oc_size;
        strategyparam.output_block_size = output_block_size;

        bundle.set(param.workspace_ptr);
        //! per-thread scratch region inside BUNDLE_THREAD_INDEX
        bundle_thread.set(
                static_cast<int8_t*>(
                        bundle.get(Im2colBundelIndex::BUNDLE_THREAD_INDEX)) +
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id);

        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;

        //! 1.Im2col
        im2colstrategy->exec_im2col(bundle, bundle_thread, strategyparam, param,
                                    matmul_param, matmul_algo);

        //! 2.packb and matmul compute
        im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
                                    matmul_param, matmul_algo, ncb_index,
                                    matmul_desc);

        //! 3.postprocess and copy dst if need
        im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
    }
    //! per-thread scratch layout:
    //! {im2col, matmul_dst, bias_temp, matmul_compute}; NO_PACK needs the
    //! matmul algo's own compute workspace instead of a packb buffer
    WorkspaceBundle get_thread_bundle(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::KernSizeParam im2col_kern_param,
            MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size,
            size_t oc_tile_size) {
        size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
               FW = param.filter_meta.spatial[1];
        size_t ohw = param.osz[0] * param.osz[1];

        size_t im2col = 0, matmul_dst = 0, bias_temp = 0, matmul_compute = 0;
        bool no_pack = matmul_algo->packmode() == Pack_Mode::NO_PACK;
        megdnn_assert(no_pack, "only support no pack");
        //! 8-bit dst needs a wider (bias_type) intermediate buffer, so the
        //! matmul result can never be written to dst directly in that case
        bool is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        size_t im2col_dst_size =
                IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
        size_t matmul_dst_size =
                oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
        im2col = im2col_dst_size;
        if (is_dst_8bit) {
            matmul_dst = matmul_dst_size;
        } else {
            //! if one tile covers the whole output, the matmul can write to
            //! dst directly and no intermediate dst buffer is needed
            matmul_dst = ohw_tile_size >= ohw ? 0 : matmul_dst_size;
        }
        matmul_compute = matmul_algo->get_workspace(im2col_kern_param);
        if (param.bias_mode == megdnn::BiasMode::BIAS) {
            //! BIAS mode needs a temporary buffer for the broadcast bias
            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
        }

        return {nullptr, {im2col, matmul_dst, bias_temp, matmul_compute}};
    }
};

//! Build the matmul KernSizeParam that views this conv tile as a GEMM:
//! M = oc tile, N = ohw tile, K = IC * FH * FW.
fallback::MatrixMulImpl::KernSizeParam
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
                                                 size_t ohw_tile_size,
                                                 size_t oc_tile_size) const {
    const auto conv_format = param.filter_meta.format;
    //! NCHW44 maps to the MK4 packed matmul format, NCHW44_DOT to MK4_DOT,
    //! anything else to the default row-major format
    auto format = param::MatrixMul::Format::DEFAULT;
    if (conv_format == param::ConvBias::Format::NCHW44) {
        format = param::MatrixMul::Format::MK4;
    } else if (conv_format == param::ConvBias::Format::NCHW44_DOT) {
        format = param::MatrixMul::Format::MK4_DOT;
    }
    const size_t pack_oc_size = pack_size(conv_format);
    const size_t gemm_m = oc_tile_size;
    const size_t gemm_n = ohw_tile_size;
    const size_t gemm_k = param.filter_meta.icpg *
                          param.filter_meta.spatial[0] *
                          param.filter_meta.spatial[1];
    //! 8-bit dst accumulates into the wider bias_type instead of dst_type
    const bool is_dst_8bit =
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
            (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
             param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
    return {param.filter_type,
            param.src_type,
            is_dst_8bit ? param.bias_type : param.dst_type,
            gemm_m,
            gemm_n,
            gemm_k,
            pack_oc_size * gemm_k,  //! LDA
            pack_oc_size * gemm_n,  //! LDB
            gemm_n * pack_oc_size,  //! LDC
            false,
            false,
            param::MatrixMul::ComputeMode::DEFAULT,
            format};
}

//! Choose the OH*OW tile size and OC tile size.
//! Heuristic: start from the configured defaults, clamp to the problem size,
//! then — with multiple threads — shrink tiles so every thread gets work;
//! with a single thread in NO_PACK mode, skip tiling entirely.
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
        const NCBKernSizeParam& param, size_t& oc_tile_size,
        size_t& ohw_tile_size, size_t block_m, size_t block_n,
        fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const {
    size_t nr_threads = param.nr_threads;
    size_t OC = param.filter_meta.ocpg;
    size_t ohw = param.osz[0] * param.osz[1];
    oc_tile_size = DEFAULT_OC_TILE_SIZE;
    ohw_tile_size = m_ohw_tile_size;

    //! a tile can never exceed the actual problem extent
    oc_tile_size = std::min(oc_tile_size, OC);
    ohw_tile_size = std::min(ohw_tile_size, ohw);

    if (nr_threads > 1) {
        //! not enough ohw tiles to occupy every thread: split ohw finer,
        //! rounded up to the matmul inner block size block_n
        if (ohw / ohw_tile_size < nr_threads) {
            ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
            if (ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
                //! ohw tiles would be too small to be efficient; instead
                //! parallelize over OC, clamped to [MIN, MAX] tile sizes
                ohw_tile_size = ohw;
                oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
                if (oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
                    oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
                } else if (oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
                    oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
                }
            }
        }
    } else {
        //! in no_pack mode don't do block operation when using single thread
        if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
            ohw_tile_size = ohw;
            oc_tile_size = OC;
        }
    }
}

//! Compute the top-level workspace layout:
//! {padding buffer, packA buffer (all groups), per-thread scratch bundles}.
WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
        const NCBKernSizeParam& param) const {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SW);
    MEGDNN_MARK_USED_VAR(SH);

    //! padded input spatial sizes.
    //! Fixed: the original had the names swapped (IW2 = IH + 2 * PH and
    //! IH2 = IW + 2 * PW); only the product IH2 * IW2 is used below, so
    //! behavior is unchanged, but the names now match their meaning.
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool no_need_pading = (PH == 0 && PW == 0);
    size_t padding = 0, packa_size = 0, packa_group_size = 0;
    size_t nr_threads = param.nr_threads;
    size_t GROUP = param.filter_meta.group;
    fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
            m_matmul_algo->matmul_description();
    bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT;
    bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA;
    size_t oc_tile_size = 0, ohw_tile_size = 0;
    choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
                        mdesc.innerblocksize.m, mdesc.innerblocksize.n,
                        mdesc.packmode);
    if (need_pack || only_packA) {
        //! DEFAULT packs all of OC at once; ONLY_PACKA packs per oc tile and
        //! therefore needs oc_parallel_times copies of the packA region
        auto im2col_kern_param = get_matmul_kern_param(
                param, ohw_tile_size, only_packA ? oc_tile_size : OC);
        size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param);
        packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
                                      : wb.get_size(0);
    } else {  //! not support pack,not need pack
        packa_group_size = 0;
    }

    if (no_need_pading) {
        padding = 0;  //! not need  padding
    } else {
        //! NOTE(review): sizeof(param.src_type) is the size of the DType
        //! object, not the element size — presumably a safe over-estimate;
        //! confirm against param.src_type.size()
        padding = (GROUP * N * IC * IH2 * IW2) *
                  sizeof(param.src_type);  //! for padding
    }

    packa_size = GROUP * packa_group_size;  //! for packA  size = GROUP * a_size
    WorkspaceBundle ws = {nullptr, {}};
    auto im2col_kern_param =
            get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);

    //! the per-thread scratch layout depends on the matmul pack mode
    if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_bundle_dft"_hash)) {
            Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
            ws = defaultkern.get_thread_bundle(param, im2col_kern_param,
                                               m_matmul_algo, ohw_tile_size,
                                               oc_tile_size);
        }
        MIDOUT_END();
    } else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_bundle_packa"_hash)) {
            Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
            ws = onlypackakern.get_thread_bundle(param, im2col_kern_param,
                                                 m_matmul_algo, ohw_tile_size,
                                                 oc_tile_size);
        }
        MIDOUT_END();
    } else {
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::get_bundle_other"_hash)) {
            Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
            ws = nopackkern.get_thread_bundle(param, im2col_kern_param,
                                              m_matmul_algo, ohw_tile_size,
                                              oc_tile_size);
        }
        MIDOUT_END();
    }

    return {nullptr,
            {padding, packa_size, ws.total_size_in_bytes() * nr_threads}};
}

//! Total workspace in bytes: padding + packA + per-thread scratch bundles
//! (see get_bundle for the layout).
size_t ConvBiasImpl::AlgoIm2col::get_workspace(
        ConvBiasImpl*, const NCBKernSizeParam& p) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 0) {
        return get_bundle(p).total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}

//! Build the kernel list to run: optionally a packA pass and a padding pass,
//! then the compute kernel over the ndrange [N, GROUP, ohw-tiles, oc-tiles].
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
        ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) {
        UNPACK_CONV_F32_NCB_KERN_SIZES(param);
        MEGDNN_MARK_USED_VAR(SH);
        MEGDNN_MARK_USED_VAR(SW);
        MEGDNN_MARK_USED_VAR(IH);
        MEGDNN_MARK_USED_VAR(IW);
        MEGDNN_MARK_USED_VAR(FH);
        MEGDNN_MARK_USED_VAR(FW);
        size_t oc_tile_size = 0, ohw_tile_size = 0;
        size_t ohw = OH * OW;
        size_t GROUP = param.filter_meta.group;
        WorkspaceBundle bundle = get_bundle(param);
        WorkspaceBundle bundle_thread = {nullptr, {}};
        bool need_padding = (PH != 0 || PW != 0);

        fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
                m_matmul_algo->matmul_description();

        Pack_Mode packmode = mdesc.packmode;
        bool default_pack = packmode == Pack_Mode::DEFAULT;
        bool no_pack = packmode == Pack_Mode::NO_PACK;
        bool only_packA = packmode == Pack_Mode::ONLY_PACKA;

        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
                            mdesc.innerblocksize.m, mdesc.innerblocksize.n,
                            mdesc.packmode);

        size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
        size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        size_t packa_parallel_times = 0;
        size_t pack_oc_size = pack_size(param.filter_meta.format);

        //! ONLY_PACKA packs per oc tile; DEFAULT packs per matmul inner block
        if (only_packA) {
            packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
        } else if (default_pack) {
            packa_parallel_times = div_ceil<size_t>(OC, mdesc.innerblocksize.m);
        }

        auto matmul_param = get_matmul_kern_param(
                param, ohw_tile_size, only_packA ? oc_tile_size : OC);
        //! per-thread scratch layout depends on the pack mode
        if (mdesc.packmode == Pack_Mode::DEFAULT) {
            Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
            bundle_thread = defaultkern.get_thread_bundle(
                    param, matmul_param, m_matmul_algo, ohw_tile_size,
                    oc_tile_size);
        } else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) {
            Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
            bundle_thread = onlypackakern.get_thread_bundle(
                    param, matmul_param, m_matmul_algo, ohw_tile_size,
                    oc_tile_size);
        } else {
            Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
            bundle_thread = nopackkern.get_thread_bundle(
                    param, matmul_param, m_matmul_algo, ohw_tile_size,
                    oc_tile_size);
        }

        //! parameters shared by all compute-kernel invocations; per-tile
        //! fields (batch_id, oc_cur_index, ...) are filled in inside kerns()
        StrategyParam strategyparam;
        strategyparam.ohw = ohw;
        strategyparam.is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        strategyparam.is_ohw_size_bigger = (ohw_tile_size >= ohw);
        //! the dst copy can be skipped only when a single tile covers the
        //! whole output and no 8-bit re-quantization is needed
        strategyparam.skip_copy_dst =
                strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
        strategyparam.oc_tile_size = oc_tile_size;
        strategyparam.pack_oc_size = pack_oc_size;

        SmallVector<ConvBiasImpl::NCBKern> ret_kern;
        MIDOUT_BEGIN(
                megdnn_fallback_im2col,
                midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) {
            StrategyBase* im2colstrategy =
                    Factory::get_im2col_strategy(param, m_matmul_algo);
            //! all lambdas capture the bundles by value; the actual workspace
            //! pointer is bound later via bundle.set(param.workspace_ptr)
            auto kern_padding = [bundle, im2colstrategy,
                                 pack_oc_size = pack_oc_size](
                                        const NCBKernParam& param,
                                        const NCBKernIndex& ncb_index) {
                copy_padding_kern(bundle, param, ncb_index, im2colstrategy,
                                  pack_oc_size);
            };

            auto kern_packA = [bundle, matmul_algo = m_matmul_algo,
                               matmul_param, im2colstrategy,
                               pack_oc_size = pack_oc_size,
                               mdesc = mdesc](const NCBKernParam& param,
                                              const NCBKernIndex& ncb_index) {
                packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
                           im2colstrategy, mdesc, pack_oc_size);
            };
            if (default_pack) {
                auto kern_compute_default =
                        [bundle, bundle_thread, matmul_param,
                         matmul_algo = m_matmul_algo,
                         ohw_tile_size = ohw_tile_size,
                         strategyparam = strategyparam, matmul_desc = mdesc,
                         im2colstrategy](const NCBKernParam& param,
                                         const NCBKernIndex& ncb_index) {
                            Im2colKerns<Pack_Mode::DEFAULT>::kerns(
                                    bundle, bundle_thread, param, matmul_param,
                                    matmul_algo, matmul_desc, strategyparam,
                                    ncb_index, ohw_tile_size, im2colstrategy);
                        };
                //! order matters: packA, then padding, then compute
                ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});

                if (need_padding) {
                    //! NCHW44 pads pack_oc_size channels per invocation
                    ret_kern.push_back({kern_padding,
                                        {param.n, GROUP, IC / pack_oc_size}});
                }
                ret_kern.push_back(
                        {kern_compute_default,
                         {N, GROUP, ohw_parallel_times, oc_parallel_times}});
            } else if (only_packA) {
                auto kern_compute_onlypackA =
                        [bundle, bundle_thread, matmul_param,
                         matmul_algo = m_matmul_algo,
                         strategyparam = strategyparam,
                         ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
                         im2colstrategy](const NCBKernParam& param,
                                         const NCBKernIndex& ncb_index) {
                            Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns(
                                    bundle, bundle_thread, param, matmul_param,
                                    matmul_algo, matmul_desc, strategyparam,
                                    ncb_index, ohw_tile_size, im2colstrategy);
                        };
                ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
                if (need_padding) {
                    ret_kern.push_back({kern_padding, {param.n, GROUP, IC}});
                }
                ret_kern.push_back(
                        {kern_compute_onlypackA,
                         {N, GROUP, ohw_parallel_times, oc_parallel_times}});
            } else if (no_pack) {
                auto kern_compute_nopack =
                        [bundle, bundle_thread, matmul_param,
                         matmul_algo = m_matmul_algo,
                         strategyparam = strategyparam,
                         ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
                         im2colstrategy](const NCBKernParam& param,
                                         const NCBKernIndex& ncb_index) {
                            Im2colKerns<Pack_Mode::NO_PACK>::kerns(
                                    bundle, bundle_thread, param, matmul_param,
                                    matmul_algo, matmul_desc, strategyparam,
                                    ncb_index, ohw_tile_size, im2colstrategy);
                        };

                //! NO_PACK has no packA pass
                if (need_padding) {
                    ret_kern.push_back({kern_padding, {param.n, GROUP, IC}});
                }
                ret_kern.push_back(
                        {kern_compute_nopack,
                         {N, GROUP, ohw_parallel_times, oc_parallel_times}});
            }
            return ret_kern;
        }
        MIDOUT_END();
        return {};
    }
    MIDOUT_END();
    return {};
}

//! Whether the im2col algorithm can handle this conv-bias problem.
bool ConvBiasImpl::AlgoIm2col::usable(
        ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) {
        //! only NCHW / NCHW44 / NCHW44_DOT layouts are supported
        if (opr->param().format != param::ConvBias::Format::NCHW &&
            opr->param().format != param::ConvBias::Format::NCHW44_DOT &&
            opr->param().format != param::ConvBias::Format::NCHW44) {
            return false;
        }

        //! src and filter must share the same dtype
        if(param.src_type.enumv() != param.filter_type.enumv()) {
            return false;
        }

        //! supported src dtypes: int8 / qint8 / quint8 / fp16 (if enabled) / fp32
        if (param.src_type.enumv() != DTypeEnum::Int8 &&
            param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
            param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
            param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
            param.src_type.enumv() != DTypeEnum::Float32) {
            return false;
        }
        //! make sure 8x8x16 and 8x8x32 bias mode is no-bias and nonlineMode
        //! is identity, otherwise return false, meaning that 8x8x32 and
        //! 8x8x16 do not support PostProcess
        if (param.dst_type.enumv() == DTypeEnum::Int16 ||
            param.dst_type.enumv() == DTypeEnum::Int32 ||
            param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
            if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
                param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
                return false;
            }
        }
        fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
                m_matmul_algo->matmul_description();
        if (opr->param().format == param::ConvBias::Format::NCHW44 ||
            opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
            //! current NCHW44 im2col only support DEFAULT mode matmul
            if (mdesc.packmode != Pack_Mode::DEFAULT) {
                return false;
                //! nchw44 hybrid mode and channel wise are not supported
                //! (note: icpg < 4 already implies icpg == 1; the extra term
                //! is redundant but harmless)
            } else if (param.filter_meta.icpg < 4_z ||
                       param.filter_meta.icpg == 1 ||
                       param.filter_meta.ocpg == 1) {
                return false;
            }
        }

        size_t oc_tile_size = 0, ohw_tile_size = 0;
        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
                            mdesc.innerblocksize.m, mdesc.innerblocksize.n,
                            m_matmul_algo->packmode());
        fallback::MatrixMulImpl::KernSizeParam matmul_param =
                get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
        bool matmulusable = m_matmul_algo->usable(matmul_param);
        //! additionally reject: 1x1 stride-1 filters (presumably left to a
        //! dedicated conv1x1 algo — confirm), any dilation != 1, and
        //! non-default compute mode
        return matmulusable &&
               (!(param.filter_meta.spatial[0] ==
                          param.filter_meta.spatial[1] &&
                  param.filter_meta.spatial[0] == 1 &&
                  param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                  param.filter_meta.stride[0] == 1)) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
    }
    MIDOUT_END();
    return false;
}

// vim: syntax=cpp.doxygen