/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include "src/arm_common/convolution/opr_impl.h"
#endif

#include <cstring>
#include <unordered_map>

MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}

}  // namespace

class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
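        //! Forward convolution reuses the packed ConvBias kernels: each
        //! ConvBias algorithm is wrapped in an AlgoDefault so it can also be
        //! selected as a convolution algorithm.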
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }

        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)

void ConvolutionImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
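    //! Only take the fallback path when the selected algo is not the naive
    //! DEFAULT one and its workspace requirement fits the provided buffer;
    //! otherwise delegate to the naive implementation.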
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(
                src, filter, dst, preprocessed_filter, workspace);
    }
}

void ConvolutionImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing the weights
    //! before exec; src/dst are ignored, so they are just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);

    //! do not pass a workspace size limit here, otherwise a matching algo
    //! may not be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}

size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
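    //! Check the HeuristicCache first: if an algo was already chosen for this
    //! layout/param combination, reuse its recorded workspace size.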
    TensorLayoutArray layouts{src, filter, dst};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

    auto fparam = make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, dst);
    }
    return ret;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto ret_safe = ConvolutionImpl::get_all_algorithms(src, filter, dst);
    return ret_safe;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}

ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();

    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter};
}

ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.get_ref_ptr();
    ret.filter_ptr = filter.get_ref_ptr();
    ret.dst_ptr = dst.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}

void ConvolutionImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
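    //! Walk the algo categories in the suggested order; within a category,
    //! return the first preferred usable algo, otherwise fall back to the
    //! first usable algo whose workspace fits the limit.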
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute && static_cast<AlgoBase*>(i)->get_workspace(param) <=
                                            workspace_limit_in_bytes) {
                //! remember the first usable algo; it becomes the target if
                //! no preferred algo is found
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! return the first preferred algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
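    //! usable algos are collected in order, and preferred ones are moved to
    //! the front of the returned list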
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
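    //! cache the heuristic decision: rerun it only when the size param
    //! differs (byte-wise) from the one used for the previous selection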
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param = ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}

ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type()
        const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (
            src_type.enumv() == DTypeEnum::Int8 ||
            src_type.enumv() == DTypeEnum::QuantizedS8) {
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
        return ConvolutionImpl::AlgoDataType::QINT4x4x32;
    } else {
        megdnn_throw(ssprintf(
                "not support data type of %s * %s -> %s\n", src_type.name(),
                filter_type.name(), dst_type.name()));
    }
}

/* ===================== ConvolutionBackwardData ===================== */

class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvolutionBackwardDataImpl::AlgoPack& ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionBackwardDataImpl::AlgoBase*> ConvolutionBackwardDataImpl::
        get_all_packed_algo() {
    return algo_pack().all_algos();
}

void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    TensorLayoutArray layouts{filter, diff, grad};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    return ret;
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_safe(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter, diff, grad);
    megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm");
    return ret_safe;
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
}

ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::
        make_ncb_kern_size_param(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
        spatial_pos = 1;
    }
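    //! build forward-direction layouts so check_layout_fwd() can be reused:
    //! grad takes the role of the forward src and diff the forward dst, with
    //! their dtypes swapped to match that convention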
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;

    std::swap(grad_fwd.dtype, diff_fwd.dtype);

    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}

ConvolutionBackwardDataImpl::NCBKernParam ConvolutionBackwardDataImpl::
        make_ncb_kern_param(
                _megdnn_tensor_in filter, _megdnn_tensor_in diff,
                _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);

    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(
            workspace.size >= required_workspace_in_bytes,
            "required workspace: %zu; provided workspace: %zu",
            required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.get_ref_ptr();
    ret.diff_ptr = diff.get_ref_ptr();
    ret.grad_ptr = grad.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}

void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) {
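    //! the ncb_1g_* helpers work on a single group, so grouping is handled
    //! here by selecting the algorithm for a group==1 view of the problem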
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
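        //! run the single-group kernel once per group: fstrd/istrd/ostrd are
        //! the per-group byte strides of the filter/diff/grad tensors, and
        //! the *_extra_mem_size fields record how much memory remains beyond
        //! the current group's slice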
        megdnn_assert(
                p1g.filter_meta.format == Param::Format::NCHW ||
                        p1g.filter_meta.format == Param::Format::NHWC,
                "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                p1g.diff_ptr += istrd;
                p1g.filter_ptr += fstrd;
                p1g.grad_ptr += ostrd;
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}

size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
    }
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic_with_ncb(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(
                p1g, workspace_limit_in_bytes, positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(
            param, workspace_limit_in_bytes, positive_attr, negative_attr);
}

size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
        ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);

    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }

    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}

bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
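    //! matrix-mul is preferred for reasonably large channel counts, or for
    //! 1x1 convolutions with unit stride and no padding (where the conv is a
    //! plain GEMM)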
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;

    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(this, param)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        ncb_1g_get_algorithm_heuristic(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (i->contain_attribute_all(positive_attr) &&
                !i->contain_attribute_any(negative_attr)) {
                return i;
            }
        }
    }
    megdnn_assert(0, "no suitable algorithm found within given workspace limit");
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_from_desc(const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
            case Handle::HandleType::AARCH64:
            case Handle::HandleType::ARMV7:
                return arm_common::ConvolutionBackwardDataImpl::get_algo_from_desc(
                        desc);
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bwd_data_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm(
        const NCBKernSizeParam& param) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen