opr_impl.cpp 28.6 KB
Newer Older
1
/**
2
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3
 *
4
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
5 6 7
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
8 9
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
10
 */
M
Megvii Engine Team 已提交
11
#include "src/fallback/conv_bias/opr_impl.h"
12 13 14 15 16
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/algos.h"
17
#include "src/fallback/conv_bias/conv1x1/algos.h"
18
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
19
#include "src/fallback/conv_bias/im2col/algos.h"
20
#include "src/fallback/convolution/opr_impl.h"
21 22 23
#include "src/naive/convolution/algorithms.h"
#include "src/naive/handle.h"

24 25 26 27 28 29 30 31
#if MEGDNN_X86
#include "src/x86/conv_bias/opr_impl.h"
#elif MEGDNN_AARCH64
#include "src/aarch64/conv_bias/opr_impl.h"
#elif MEGDNN_ARMV7
#include "src/armv7/conv_bias/opr_impl.h"
#endif

32 33 34 35 36
#include <cstring>

using namespace megdnn;
using namespace fallback;

37
//! Pack size associated with a ConvBias tensor \p format: 8 for NCHW88, 4 for
//! NCHW44 / NCHW44_DOT / NCHW4, and 1 for every other format.
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
    using Format = param::ConvBias::Format;
    if (format == Format::NCHW88) {
        return 8_z;
    }
    const bool packs_four = format == Format::NCHW44 ||
                            format == Format::NCHW44_DOT ||
                            format == Format::NCHW4;
    return packs_four ? 4_z : 1_z;
}

50 51 52 53 54 55 56 57
namespace {
//! Move \p dst in place by \p delta bytes (not elements). The addition is
//! carried out on the integer representation of the pointer.
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    const auto address = reinterpret_cast<uintptr_t>(dst);
    dst = reinterpret_cast<T*>(address + delta);
}

}  // namespace

58 59 60 61 62 63 64 65 66 67 68 69 70
//! SKIP_GEMV(): statement macro meant to be used inside a loop over matmul
//! algorithms (it expands to a `continue` and reads a variable named `algo`).
//! On x86 it expands to nothing; elsewhere it skips algorithms whose algoset
//! is ALGO_TYPE_GEMV.
//! NOTE(review): no use of SKIP_GEMV() is visible in this part of the file;
//! the AlgoPack constructor spells out the same check inline.
#if MEGDNN_X86
#define SKIP_GEMV()
//! There is no direct conv for int8x8x16 yet. If gemv were disabled here, the
//! operator could fall back to the naive implementation, which may be very
//! slow, so on the x86 backend gemv stays enabled for im2col.
//! FIXME: remove this once direct conv supports int8x8x16
#else
#define SKIP_GEMV()                                                            \
    if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
        continue;                                                              \
    }
#endif

71 72 73
//! Collection of all fallback ConvBias algorithm objects.
//!
//! The constructor registers, in this order: the conv1x1 gemv algorithm, then
//! for every packed matmul algorithm a set of im2col and conv1x1 algorithms
//! with different tile sizes, and finally the naive algorithm. A lookup table
//! keyed by each algorithm's descriptor is built from that list.
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    //! owns the heap-allocated algorithm objects referenced by m_all_algos
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

    //! Take ownership of \p algo and append it to the list of algorithms.
    void adopt(AlgoBase* algo) {
        refhold.emplace_back(algo);
        m_all_algos.emplace_back(refhold.back().get());
    }

public:
    AlgoPack() {
        adopt(new AlgoConv1x1Gemv());

        static CpuOprDelegationStorage<> storage;
        auto matmul_opr = storage.get<MatrixMul>();
        auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
                                      ->get_all_packed_algo();
        for (auto&& algo : matmul_algos) {
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may
//! fallback to naive implementation, which may cause performance very low, so
//! gemv matmul algorithms are only skipped on non-x86 backends.
//! FIXME: remove it when we add direct conv support for int8x8x16
#if !MEGDNN_X86
            if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                continue;
            }
#endif

//! As we haven't riscv64 postprocess yet, im2col and conv1x1 can not pass ci
//! test. so we just disable all im2col and conv1x1 in riscv64
//! FIXME: remove it when impl postprocess for riscv64
#if !MEGDNN_RISCV64
            auto* matmul_algo = static_cast<MatrixMulImpl::AlgoBase*>(algo);
            for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
                adopt(new AlgoIm2col(matmul_algo, ohw_tile_size));
            }
            for (size_t oc_tile_size : {48, 24}) {
                adopt(new AlgoConv1x1(matmul_algo, oc_tile_size));
            }
#endif

#if 0
        //! As these algos maybe very slow, it will make fastrun search slow, so
        //! we disable it, but for the test of strategyhelper, we just keep it.
        //! FIXME: I do not know a better way to do it.
            adopt(new AlgoWinogradF32(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            adopt(new AlgoWinogradF32_4x4(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            adopt(new AlgoWinogradQS8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            adopt(new AlgoWinogradQS8_8x8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
#endif
        }
        //! the naive algorithm is a direct member, so it is not put in refhold
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& registered : m_all_algos) {
            m_all_algos_map.emplace(registered->info().desc, registered);
        }
    }

    //! All registered algorithms, in registration order.
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    //! Descriptor -> algorithm lookup table.
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

142 143 144 145 146 147 148
//! Access the AlgoPack held in a function-local static, constructed on the
//! first call.
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    static AlgoPack instance;
    return instance;
}

//! Return a copy of the list of every algorithm registered in the AlgoPack.
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
    const AlgoPack& pack = algo_pack();
    return pack.all_algos();
}
150 151 152

//! Filter the packed algorithms down to those whose data type contains
//! \p target_type.data_type and whose category equals
//! \p target_type.algo_category. The relative order of get_all_packed_algo()
//! is kept.
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvBiasImpl::AlgoBase*> matched;
    for (auto&& candidate : get_all_packed_algo()) {
        auto candidate_type = candidate->get_algo_type();
        if (!contain_data_type(candidate_type.data_type, target_type.data_type)) {
            continue;
        }
        if (candidate_type.algo_category != target_type.algo_category) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}

167 168 169
//! True when \p algo is null or when its name is exactly "DEFAULT".
bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
    if (algo == nullptr) {
        return true;
    }
    return strcmp(algo->name(), "DEFAULT") == 0;
}
170

M
Megvii Engine Team 已提交
171 172 173 174 175 176 177 178 179 180 181
//! Call member function `name` with argument `param` on `algo` after casting
//! `algo` to AlgoBase*.
#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)

//! Run the forward ConvBias.
//!
//! After check_exec(), an algorithm is selected for the given workspace size.
//! The NCB kernels of that algorithm are used when it is not the naive one and
//! its workspace requirement fits into \p workspace; otherwise the call is
//! forwarded to naive::ConvBiasForwardImpl::exec. \p z is only handed to
//! check_exec and to the naive path; the NCB kernel parameters are built
//! without it.
void ConvBiasImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    check_exec(
            src.layout, filter.layout, bias.layout, z.layout, dst.layout,
            workspace.size, preprocessed_filter);
    auto kern_param =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
    auto* chosen = get_algorithm(kern_param, workspace.size);
    const bool run_ncb =
            !is_naive_algo(chosen) &&
            NCB_ALGO_FUNC(get_workspace, chosen, kern_param) <= workspace.size;
    if (!run_ncb) {
        naive::ConvBiasForwardImpl::exec(
                src, filter, bias, z, dst, preprocessed_filter, workspace);
        return;
    }
    exec_with_ncb_kern(kern_param, chosen);
}

M
Megvii Engine Team 已提交
192 193 194 195 196
void ConvBiasImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, const TensorLayout& z_layout,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
197 198 199
    //! exec_preprocess currently only support preprocess weights and bias
    //! before exec, src/dst/z will be ignored, just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
M
Megvii Engine Team 已提交
200 201
    auto fparam =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
202
    //! should not pass workspace_size limit otherwise can not find match algo
203 204
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
M
Megvii Engine Team 已提交
205
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
206 207 208
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvBiasForwardImpl::exec_preprocess(
M
Megvii Engine Team 已提交
209 210
                src_layout, filter, bias, z_layout, dst_layout, preprocessed_filter,
                workspace);
211 212 213
    }
}

214
size_t ConvBiasImpl::get_workspace_in_bytes(
M
Megvii Engine Team 已提交
215 216
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
217
        const PreprocessedFilter* preprocessed_filter) {
218 219
    TensorLayoutArray layouts{src, filter, bias, z, dst};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
M
Megvii Engine Team 已提交
220 221
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
222 223 224 225 226
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

M
Megvii Engine Team 已提交
227
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, preprocessed_filter);
228
    auto&& algo = get_algorithm(fparam);
229
    if (is_naive_algo(algo)) {
230 231
        return naive::ConvBiasForwardImpl::get_workspace_in_bytes(
                src, filter, bias, z, dst, preprocessed_filter);
232
    } else {
233 234 235 236 237
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvBiasImpl::get_preprocess_workspace_in_bytes(
M
Megvii Engine Team 已提交
238 239
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
240
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
241
    auto&& algo = get_algorithm(fparam);
242 243 244 245 246 247 248 249 250
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, bias, z, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

//! Layouts of the preprocessed filter tensors: taken from the selected
//! algorithm, or from the naive implementation when the naive algorithm is
//! selected.
SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto size_param = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto* chosen = get_algorithm(size_param);
    if (!is_naive_algo(chosen)) {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, chosen, size_param);
    }
    return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
            src, filter, bias, z, dst);
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms(
M
Megvii Engine Team 已提交
264 265
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
266
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
267 268
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
M
Megvii Engine Team 已提交
269 270
        return naive::ConvBiasForwardImpl::get_all_algorithms_safe(
                src, filter, bias, z, dst);
271 272 273
    }
    return ret;
}
274
std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe(
M
Megvii Engine Team 已提交
275 276 277
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto ret_safe = ConvBiasImpl::get_all_algorithms(src, filter, bias, z, dst);
278 279
    return ret_safe;
}
280 281

//! Heuristically pick an algorithm for the given layouts. The NCB heuristic is
//! tried first; when it yields nothing, the naive implementation's heuristic
//! is asked with the same arguments.
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto size_param = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    if (auto* ncb_choice = get_algorithm_heuristic_with_ncb(
                size_param, workspace_limit_in_bytes, positive_attr, negative_attr)) {
        return ncb_choice;
    }
    return naive::ConvBiasForwardImpl::get_algorithm_heuristic(
            src, filter, bias, z, dst, workspace_limit_in_bytes, positive_attr,
            negative_attr);
}

296 297
//! Heuristic selection among the NCB algorithms.
//!
//! Returns nullptr for the NHWCD4 format. Otherwise the categories suggested
//! by suggest_algo_category_order() are visited in order. Within a category,
//! an algorithm qualifies when usable_attribute() accepts it and its workspace
//! does not exceed \p workspace_limit_in_bytes. The first qualifying algorithm
//! that is also preferred wins; failing that, the first qualifying algorithm
//! of the category is returned. nullptr means no category had a qualifying
//! algorithm.
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    //! `param` shadows the operator's param() accessor, hence the qualification
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    auto data_type = param.deduce_algo_data_type();
    auto category_order = suggest_algo_category_order(param);
    for (auto category : category_order) {
        auto&& candidates = select_algo_type({data_type, category});
        ConvBiasImpl::Algorithm* first_usable = nullptr;
        for (auto candidate : candidates) {
            auto* ncb_algo = static_cast<AlgoBase*>(candidate);
            if (!ncb_algo->usable_attribute(
                        param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                        negative_attr)) {
                continue;
            }
            if (ncb_algo->get_workspace(param) > workspace_limit_in_bytes) {
                continue;
            }
            //! remember the first qualifying algo as the fallback choice
            if (!first_usable) {
                first_usable = candidate;
            }
            //! the first preferred qualifying algo is returned immediately
            if (candidate->is_preferred(param)) {
                return candidate;
            }
        }
        if (first_usable) {
            return first_usable;
        }
    }
    return nullptr;
}

331
//! Build the NCBKernSizeParam describing a ConvBias problem from its layouts.
//!
//! \param src, filter, bias, dst  tensor layouts of the operator
//! \param preprocessed_filter     forwarded into the result after a pointer
//!                                cast; may be nullptr
//!
//! The bias mode is derived from \p bias: NO_BIAS when bias.ndim == 0, BIAS
//! when its shape equals dst's, otherwise BROADCAST_CHANNEL_BIAS (only the
//! ndim is asserted here). Asserts on an unsupported format and when a size
//! does not fit into 32 bits.
ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) {
    //! checked narrowing: the conversion is explicit and guarded by the assert
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return static_cast<uint32_t>(v);
    };

    //! index of the first spatial dimension in src/dst; initialized so it is
    //! never read indeterminate even if the assert below were to return
    size_t spatial_pos = 0;
    switch (param().format) {
        case Param::Format::NCHW88:
        case Param::Format::NCHW8:
        case Param::Format::NCHW4:
        case Param::Format::NCHW44:
        case Param::Format::NCHW44_DOT:
        case Param::Format::NCHW:
        case Param::Format::NCHW32:
        case Param::Format::NCHW64:
            spatial_pos = 2;
            break;
        case Param::Format::NHWC:
        case Param::Format::NHWCD4:
            spatial_pos = 1;
            break;
        default:
            megdnn_assert(
                    0, "invalid conv format %d", static_cast<int>(param().format));
    }

    BiasMode bias_mode;
    if (bias.ndim == 0) {
        bias_mode = BiasMode::NO_BIAS;
    } else if (bias.eq_shape(dst)) {
        bias_mode = BiasMode::BIAS;
    } else {
        //! just check the ndim, the detail shape check is in check_exec
        megdnn_assert(bias.ndim == dst.ndim);
        bias_mode = BiasMode::BROADCAST_CHANNEL_BIAS;
    }

    static_assert(
            sizeof(CanonizedFilterMeta) == sizeof(ConvolutionImpl::CanonizedFilterMeta),
            "sizeof CanonizedFilterMeta in convolution and conv_bias "
            "should be equal");
    auto&& fm = check_layout_fwd(src, filter, dst);
    //! reinterpreted as the convolution operator's filter meta; the
    //! static_assert above only checks that the two types have equal size
    auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);

    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    //! the inner brace list fills the convolution size-param part, the
    //! trailing four values are the ConvBias-specific fields
    return {{safe_u32(src[0]),
             {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
             {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
             conv_fm,
             src.dtype,
             filter.dtype,
             dst.dtype,
             src.stride[0],
             dst.stride[0],
             {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
             {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
             param().compute_mode,
             nr_threads,
             reinterpret_cast<const ConvolutionForward::PreprocessedFilter*>(
                     preprocessed_filter)},
            bias.dtype,
            bias.stride[0],
            bias_mode,
            param().nonlineMode};
}

//! Build the NCBKernParam for an execution: the size part comes from
//! make_ncb_kern_size_param() on the tensors' layouts, then the tensor
//! pointers and the workspace pointer/size are filled in.
ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        const PreprocessedFilter* preprocessed_filter) {
    NCBKernParam kern_param;
    //! assign only the NCBKernSizeParam base part of the object
    NCBKernSizeParam& size_part = kern_param;
    size_part = make_ncb_kern_size_param(
            src.layout, filter.layout, bias.layout, dst.layout, preprocessed_filter);

    kern_param.src_ptr = src.get_ref_ptr();
    kern_param.filter_ptr = filter.get_ref_ptr();
    kern_param.bias_ptr = bias.get_ref_ptr();
    kern_param.dst_ptr = dst.get_ref_ptr();

    kern_param.workspace_ptr = workspace.raw_ptr;
    kern_param.workspace_size = workspace.size;
    return kern_param;
}

M
Megvii Engine Team 已提交
414 415
void ConvBiasImpl::exec_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
416
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
417
    for (auto&& kernel : ncb_kerns) {
418
        auto run = [kernel, param](size_t index, size_t thread_id) {
419
            CpuNDRange ndrange_id(kernel.global_size, index);
420
            kernel.kern(param, {thread_id, ndrange_id});
421 422 423 424 425 426
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
}

427 428
void ConvBiasImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
429
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
430 431 432 433 434 435 436 437
    for (auto&& kernel : ncb_kerns) {
        auto run = [kernel, param](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
438 439 440 441 442 443 444
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    MEGDNN_MARK_USED_VAR(param);
    std::vector<Algorithm*> algos;
    std::vector<Algorithm*> prefer_algos;
445
    for (auto&& algo : get_all_packed_algo()) {
446 447
        if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (algo->is_preferred(param)) {
448 449 450 451 452 453 454 455 456 457 458
                prefer_algos.push_back(algo);
            } else {
                algos.push_back(algo);
            }
        }
    }
    //! Prefer algo inserted from begin
    algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
    return algos;
}

459 460
//! Resolve an algorithm descriptor to an algorithm object.
//!
//! Returns nullptr for an invalid descriptor. Otherwise the lookup is routed
//! by desc.handle_type: the fallback pack's map, the architecture-specific
//! ConvBiasImpl selected at compile time, or the naive handle's default
//! algorithm. An unknown handle type raises via megdnn_throw.
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    }

    switch (desc.handle_type) {
        case Handle::HandleType::FALLBACK: {
            const auto& fallback_map = algo_pack().all_algos_map();
            megdnn_assert(fallback_map.find(desc) != fallback_map.end());
            return fallback_map.at(desc);
        }

#if MEGDNN_X86
        case Handle::HandleType::X86:
            return x86::ConvBiasImpl::get_algo_from_desc(desc);
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
        case Handle::HandleType::ARM_COMMON:
            return arm_common::ConvBiasImpl::get_algo_from_desc(desc);
#if MEGDNN_AARCH64
        case Handle::HandleType::AARCH64:
            return aarch64::ConvBiasImpl::get_algo_from_desc(desc);
#else
        case Handle::HandleType::ARMV7:
            return armv7::ConvBiasImpl::get_algo_from_desc(desc);
#endif
#endif
        case Handle::HandleType::NAIVE: {
            auto naive_algo = static_cast<naive::HandleImpl*>(handle())
                                      ->default_conv_bias_fwd_algo();
            megdnn_assert(naive_algo->info().desc == desc);
            return naive_algo;
        }
        default:
            megdnn_throw("Unknown handle type");
            return nullptr;
    }
}

498 499
//! Select the algorithm to run for \p param.
//!
//! Returns nullptr for the NHWCD4 format. An algorithm named by the execution
//! policy takes priority. Otherwise the result of the heuristic is cached in
//! m_prev_selected_algo together with the parameters it was chosen for, and
//! is recomputed when the parameters differ bytewise from the cached ones.
//! NOTE(review): only the bytes of NCBKernSizeParam take part in the cache
//! check; \p workspace_size is not compared — confirm this is intended.
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    //! `param` shadows the operator's param() accessor, hence the qualification
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    if (auto policy_algo = get_algorithm_from_desc(execution_policy().algo)) {
        return policy_algo;
    }
    const bool cache_stale =
            !m_prev_selected_algo ||
            memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam));
    if (cache_stale) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

515 516 517 518 519 520 521
//! Order in which algorithm categories are tried by the heuristic.
//!
//! IM2COL goes first when the filter is 1x1, or when the matmul path is
//! favoured: for quantized sources that is decided by
//! is_matmul_quantized_prefer(), otherwise by icpg >= 32 || ocpg >= 32.
//! In all other cases DIRECT goes first. NAIVE is always last.
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    //! TODO: now winograd only support in fast-run
    const auto& meta = param.filter_meta;

    //! quantized algo use matmul when direct algo is unusable
    const bool is_quantized = param.src_type.category() == DTypeCategory::QUANTIZED;
    const bool matmul_favoured = is_quantized
                                       ? is_matmul_quantized_prefer(param)
                                       : (meta.icpg >= 32 || meta.ocpg >= 32);
    const bool is_conv1x1 = meta.spatial[0] == 1 && meta.spatial[1] == 1;

    if (matmul_favoured || is_conv1x1) {
        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
    }
    return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, AlgoCategory::NAIVE};
}

538 539 540 541 542
//! Name of this algorithm set.
const char* ConvBiasImpl::get_algorithm_set_name() const {
    // fallback version 0
    static constexpr const char* kAlgorithmSetName = "F0";
    return kAlgorithmSetName;
}

543
namespace megdnn {
544
namespace fallback {
545

546
size_t ConvBiasImpl::NCBKernParam::src_offset(
M
Megvii Engine Team 已提交
547 548
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
549
    size_t batch_offset = batch_id * inp_bs * src_type.size();
M
Megvii Engine Team 已提交
550 551 552 553
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg * isz[0] *
                          isz[1] * src_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * isz[0] * isz[1] * src_type.size();
554
    return (batch_offset + group_offset + channel_offset);
555 556 557
}

template <typename T>
558 559 560 561 562 563 564 565 566 567 568
const T* ConvBiasImpl::NCBKernParam::src(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(src_ptr.get_ptr()) +
            src_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}

size_t ConvBiasImpl::NCBKernParam::filter_offset(
M
Megvii Engine Team 已提交
569
        size_t group_pack_id, size_t pack_group_size) const {
570 571 572
    size_t group_offset = 0_z;
    switch (filter_meta.format) {
        case Param::Format::NCHW: {
573
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
574 575 576 577 578 579 580 581 582
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        case Param::Format::NCHW88: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four format of weight layout
583 584 585
            //! 1. {oc/8, ic/8, fh, fw, 8, 8},
            //! 2. {g, oc/8, ic/8, fh, fw, 8, 8},
            //! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8}
M
Megvii Engine Team 已提交
586 587 588 589 590 591
            megdnn_assert(
                    (icpg % 8 == 0 && ocpg % 8 == 0) ||
                            (group % 8 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 8 == 0),
                    "The filter shepe is not right of nchw88");
592 593 594 595 596 597
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();

            break;
        }
598
        case Param::Format::NCHW44_DOT:
599 600 601 602 603 604 605
        case Param::Format::NCHW44: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four format of weight layout
            //! 1. {oc/4, ic/4, fh, fw, 4, 4},
            //! 2. {g, oc/4, ic/4, fh, fw, 4, 4},
606 607
            //! 3. {g/4, fh, fw, 1, 1, 4},
            //! 4. {oc/4, fh, fw, ic, 4}
M
Megvii Engine Team 已提交
608 609 610 611 612 613
            megdnn_assert(
                    (icpg % 4 == 0 && ocpg % 4 == 0) ||
                            (group % 4 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 4 == 0),
                    "The filter shepe is not right of nchw44");
614
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
615 616 617 618 619 620
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();

            break;
        }
        default:
621
            megdnn_assert(0, "other filter format is not support yet");
622
    }
623
    return group_offset;
624 625 626
}

template <typename T>
627 628 629 630 631 632 633 634
const T* ConvBiasImpl::NCBKernParam::filter(
        size_t group_pack_id, size_t pack_group_size) const {
    size_t group_offset = filter_offset(group_pack_id, pack_group_size);
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(filter_ptr.get_ptr()) + group_offset);
}

size_t ConvBiasImpl::NCBKernParam::bias_offset(
M
Megvii Engine Team 已提交
635 636
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
637 638
    size_t batch_offset = 0_z;
    size_t group_offset = 0_z;
639
    size_t channel_offset = 0_z;
640 641
    if (bias_mode == BiasMode::BIAS) {
        batch_offset = batch_id * bias_bs * bias_type.size();
M
Megvii Engine Team 已提交
642 643
        group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                       osz[1] * bias_type.size();
644 645
        channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] *
                         bias_type.size();
646
    } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
M
Megvii Engine Team 已提交
647 648
        group_offset =
                group_pack_size * group_pack_id * filter_meta.ocpg * bias_type.size();
649
        channel_offset = channel_pack_size * channel_pack_id * bias_type.size();
650
    }
651
    return (batch_offset + group_offset + channel_offset);
652 653 654
}

template <typename T>
655 656 657 658 659 660 661 662 663 664 665
const T* ConvBiasImpl::NCBKernParam::bias(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(bias_ptr.get_ptr()) +
            bias_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}

size_t ConvBiasImpl::NCBKernParam::dst_offset(
M
Megvii Engine Team 已提交
666 667
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
668
    size_t batch_offset = batch_id * out_bs * dst_type.size();
M
Megvii Engine Team 已提交
669 670 671 672
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                          osz[1] * dst_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * osz[0] * osz[1] * dst_type.size();
673 674 675 676 677 678 679
    return (batch_offset + group_offset + channel_offset);
}

template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
M
Megvii Engine Team 已提交
680
    return reinterpret_cast<T*>(
681 682 683 684
            reinterpret_cast<ptrdiff_t>(dst_ptr.get_ptr()) +
            dst_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
685 686
}

687 688 689 690 691 692 693 694 695 696 697 698
#define INST(T)                                                      \
    template const T* ConvBiasImpl::NCBKernParam::src<T>(            \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const; \
    template const T* ConvBiasImpl::NCBKernParam::bias<T>(           \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const; \
    template const T* ConvBiasImpl::NCBKernParam::filter<T>(         \
            size_t group_id, size_t group_pack_size) const;          \
    template T* ConvBiasImpl::NCBKernParam::dst<T>(                  \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const;
699 700 701 702

#define INST_DT(d) INST(DTypeTrait<d>::ctype)

MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
703
INST(void)
704 705 706 707 708
#undef INST
#undef INST_DT
}  // namespace fallback
}  // namespace megdnn

709
// vim: syntax=cpp.doxygen