algo_chooser.cpp 4.1 KB
Newer Older
1 2 3 4
/**
 * \file src/opr/impl/search_policy/algo_chooser.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

13
#include <limits>
14
#include <unordered_set>
15

16
#include "megbrain/opr/dnn/convolution.h"
17
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
18
#include "megbrain/opr/search_policy/algo_chooser.h"
19
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
20 21
#include "megbrain/utils/invoke.h"
#include "megdnn/heuristic_cache.h"
22 23 24 25 26

#include "../internal/megdnn_opr_wrapper.inl"
#include "./workspace_need_limit_getter.inl"

using mgb::opr::intl::WorkspaceLimitGetter;
27 28
using namespace megdnn;
using namespace mgb;
29

30 31
namespace mgb {
namespace opr {
32 33

template <typename Opr>
M
Megvii Engine Team 已提交
34 35 36
size_t AlgoChooser<Opr>::setup_algo(
        const FixedTensorLayouts& layouts, Opr* megdnn_opr, const MGBOpr* mgb_opr,
        bool allow_weight_preprocess) {
37 38 39 40 41 42 43 44 45
    HeuristicCache::Key cache_key(
            megdnn_opr->handle(), megdnn_opr->get_opr_type(), layouts.data(),
            layouts.size(), &megdnn_opr->param(), sizeof(megdnn_opr->param()));
    auto rst = HeuristicCache::instance().get(cache_key);
    if (rst.policy.algo.valid()) {
        megdnn_opr->execution_policy() = rst.policy;
        return rst.workspace;
    }

46 47 48 49 50 51
    if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) {
        return 0;
    }

    std::string param_str;
    Algorithm::serialize_write_pod(megdnn_opr->param(), param_str);
52 53 54 55 56 57 58 59 60 61 62

    auto cg = mgb_opr->owner_graph();
    rdnn::AlgoChooserDesc desc;
    desc.shared_batch_size = cg->options().fast_run_config.shared_batch_size;
    desc.binary_equal_between_batch =
            cg->options().fast_run_config.binary_equal_between_batch;
    desc.no_profiling_on_shape_change = cg->options().no_profiling_on_shape_change;
    desc.get_workspace_limit = [&](CompNode cn, size_t old_limit) {
        return WorkspaceLimitGetter::get_workspace_limit(cg, cn, old_limit);
    };

M
Megvii Engine Team 已提交
63
    AlgoChooserHelper helper(
64 65
            layouts, megdnn_opr, param_str, mgb_opr->comp_node(),
            mgb_opr->execution_policy(), allow_weight_preprocess, desc);
66 67 68 69

    ImplExecutionPolicy policy;
    if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
        policy = algo_choose_hook(mgb_opr);
70 71
        auto strategy = rdnn::ExecutionStrategy::HEURISTIC |
                        rdnn::ExecutionStrategy::REPRODUCIBLE;
72 73
        bool retrive_from_cache = false;
        helper.construct_execution_policy(strategy, policy, retrive_from_cache);
74 75
    }
    if (!policy.algo.valid()) {
76
        policy = Base::get_policy(helper);
77
    }
78
    size_t workspace = helper.get_workspace_size_bytes(policy, layouts);
79 80 81

    std::string ret;
    ret.append(mgb_opr->dyn_typeinfo()->name);
82
    ret.append(": tensor layouts");
83
    ret += Base::format_fixlayouts(layouts);
84 85 86
    Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
    mgb_assert(palgo, "Unknown algo description");
    ret.append("): algo=" + std::string(palgo->name()));
M
Megvii Engine Team 已提交
87 88 89
    ret.append(ssprintf(
            " workspace=%.2fMiB attribute=%d", workspace / (1024 * 1024.0),
            static_cast<uint32_t>(palgo->attribute())));
90 91 92
    mgb_log_debug("%s", ret.c_str());

    megdnn_opr->execution_policy() = policy;
93

94
    if (mgb_opr->execution_policy().strategy & rdnn::ExecutionStrategy::HEURISTIC) {
95 96 97
        HeuristicCache::Result cache_result{policy, workspace};
        HeuristicCache::instance().put(cache_key, cache_result);
    }
98 99 100
    return workspace;
}

101 102 103
// INST(DnnOpr) expands to an explicit instantiation of
// AlgoChooser<megdnn::DnnOpr>::setup_algo. It is handed to
// MGB_FOREACH_FASTRUN_OPR (defined elsewhere; presumably it applies the macro
// once per fast-run capable operator -- confirm against its definition) and
// is undefined again right after use.
#define INST(DnnOpr)                                            \
    template size_t AlgoChooser<megdnn::DnnOpr>::setup_algo(    \
            const FixedTensorLayouts&, megdnn::DnnOpr*, const MGBOpr*, bool);

MGB_FOREACH_FASTRUN_OPR(INST)
#undef INST
108

109 110 111 112
}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}