From 2eea00097c4b631e8535501355f4cd3c33eb2b7e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 19 May 2021 20:37:44 +0800
Subject: [PATCH] feat(mgb): add fast run batch size graph option

GitOrigin-RevId: 94e333ec805d81a279365a03a9665a69f789f522
---
 dnn/include/megdnn/oprs/utils.h               |   5 +
 dnn/src/common/conv_bias.cpp                  |  30 --
 dnn/src/common/conv_bias.h                    |   2 -
 dnn/src/common/utils.cpp                      |  30 ++
 sdk/load-and-run/src/mgblar.cpp               |  27 ++
 src/core/include/megbrain/graph/cg.h          |  23 +-
 src/opr/impl/search_policy/algo_chooser.cpp   | 302 ++++++++++++++---
 .../megbrain/opr/search_policy/algo_chooser.h |  21 +-
 .../megbrain/opr/search_policy/profiler.h     |   8 +-
 src/opr/test/algo_chooser.cpp                 | 304 ++++++++++++++++++
 test/src/helper.cpp                           |  15 +-
 test/src/include/megbrain/test/helper.h       |  18 +-
 12 files changed, 682 insertions(+), 103 deletions(-)
 create mode 100644 src/opr/test/algo_chooser.cpp

diff --git a/dnn/include/megdnn/oprs/utils.h b/dnn/include/megdnn/oprs/utils.h
index 7583427d..e2303dc7 100644
--- a/dnn/include/megdnn/oprs/utils.h
+++ b/dnn/include/megdnn/oprs/utils.h
@@ -91,6 +91,11 @@ class MaxTensorDiff : public OperatorBase {
     void check_exec(const TensorLayout& layout1, const TensorLayout& layout2,
                     size_t workspace_in_bytes);
 };
+
+
+bool check_bias_share_in_channel(const TensorLayout& bias,
+                                 const param::ConvBias::Format format);
+
 } // namespace megdnn

 #include "megdnn/internal/opr_header_epilogue.h"
diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp
index 1bf32c87..296e8876 100644
--- a/dnn/src/common/conv_bias.cpp
+++ b/dnn/src/common/conv_bias.cpp
@@ -318,36 +318,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
             megdnn_assert(false);
     }
 }
-
-bool check_bias_share_in_channel(const TensorLayout& bias,
-                                 const param::ConvBias::Format format) {
-    bool share_in_channel = false;
-    if (format == param::ConvBias::Format::NCHW ||
-        format == param::ConvBias::Format::NCHW4_NCHW) {
-        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    } else if (format == param::ConvBias::Format::NHWC) {
-        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
-                            bias[2] == 1);
-    } else if (format == param::ConvBias::Format::NCHW4 ||
-               format == param::ConvBias::Format::NCHW8 ||
-               format == param::ConvBias::Format::NCHW32 ||
-               format == param::ConvBias::Format::NCHW64 ||
-               format == param::ConvBias::Format::NCHW4_NCHW32 ||
-               format == param::ConvBias::Format::NCHW32_NCHW4) {
-        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    } else if (format == param::ConvBias::Format::NHWCD4) {
-        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
-                            bias[3] == 1);
-    } else {
-        megdnn_assert(format == param::ConvBias::Format::CHWN4);
-        share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    }
-    return share_in_channel;
-}
-
 } // namespace megdnn

 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/conv_bias.h b/dnn/src/common/conv_bias.h
index 3eeacfc1..3a55afe4 100644
--- a/dnn/src/common/conv_bias.h
+++ b/dnn/src/common/conv_bias.h
@@ -22,8 +22,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
                                const TensorND* dst_tensor,
                                const TensorND* bias_tensor);

-bool check_bias_share_in_channel(const TensorLayout& bias,
-                                 const param::ConvBias::Format format);
 } // namespace megdnn

 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp
index 6d340de6..9e056791 100644
--- a/dnn/src/common/utils.cpp
+++ b/dnn/src/common/utils.cpp
@@ -10,6 +10,7 @@
  */

 #include "src/common/utils.h"
+#include "megdnn/oprs/utils.h"
 #include "megdnn/handle.h"

 #include
@@ -344,4 +345,33 @@ size_t& CpuNDRange::operator[](size_t idx) {
     return m_dim[idx];
 }

+bool megdnn::check_bias_share_in_channel(const TensorLayout& bias,
+                                         const param::ConvBias::Format format) {
+    bool share_in_channel = false;
+    if (format == param::ConvBias::Format::NCHW ||
+        format == param::ConvBias::Format::NCHW4_NCHW) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWC) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[2] == 1);
+    } else if (format == param::ConvBias::Format::NCHW4 ||
+               format == param::ConvBias::Format::NCHW8 ||
+               format == param::ConvBias::Format::NCHW32 ||
+               format == param::ConvBias::Format::NCHW64 ||
+               format == param::ConvBias::Format::NCHW4_NCHW32 ||
+               format == param::ConvBias::Format::NCHW32_NCHW4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWCD4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[3] == 1);
+    } else {
+        megdnn_assert(format == param::ConvBias::Format::CHWN4);
+        share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    }
+    return share_in_channel;
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp
index 46b6391a..8f5cbca0 100644
--- a/sdk/load-and-run/src/mgblar.cpp
+++ b/sdk/load-and-run/src/mgblar.cpp
@@ -158,6 +158,11 @@ R"__usage__(
 R"__usage__(
  --fast-run-algo-policy
    It will read the cache file before profile, and save new fastrun in cache file.
+ --fast-run-shared-batch-size
+   Set the batch size used during fastrun. Note that it may not be the same as the actual running batch size
+ --binary-equal-between-batch
+   Each batch of output is promised to be binary equal if each batch of input is binary equal.
+   Note that if this option is turned on, `--reproducible` will also be turned on.
  --reproducible
    Enable choose algo which is reproducible. It mainly used for cudnn algos.
    See https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#reproducibility
@@ -1356,6 +1361,20 @@ Args Args::from_argv(int argc, char **argv) {
             ret.fast_run_cache_path = argv[i];
             continue;
         }
+        if (!strcmp(argv[i], "--fast-run-shared-batch-size")) {
+            ++i;
+            mgb_assert(i < argc,
+                       "value not given for --fast-run-shared-batch-size");
+            int32_t batch_size = std::stoi(argv[i]);
+            mgb_assert(batch_size >= 0);
+            graph_opt.fast_run_config.shared_batch_size = batch_size;
+            continue;
+        }
+        if (!strcmp(argv[i], "--binary-equal-between-batch")) {
+            graph_opt.fast_run_config.binary_equal_between_batch = true;
+            ret.reproducible = true;
+            continue;
+        }
         if (!strcmp(argv[i], "--reproducible")) {
             ret.reproducible = true;
             continue;
         }
@@ -1452,6 +1471,14 @@ Args Args::from_argv(int argc, char **argv) {
         return ret;
     }

+#if MGB_ENABLE_FASTRUN
+    if (graph_opt.fast_run_config.shared_batch_size) {
+        mgb_assert(ret.use_fast_run || ret.use_full_run ||
+                           !ret.fast_run_cache_path.empty(),
+                   "--fast-run-shared-batch-size should be used with "
+                   "--fast-run/--full-run/--fast-run-algo-policy");
+    }
+#endif
     return ret;
 }
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index dd731284..15467552 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -502,7 +502,28 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
         //! contains any user data associated with this graph
         UserDataContainer user_data;
-    };  // Options
+
+        //! Control parameters for fast run
+        struct FastRunConfig {
+            /*!
+             * the batch size used by fastrun
+             *
+             * A non-zero value means that fastrun uses this batch size
+             * regardless of the batch size of the model
+             *
+             * Zero means fastrun uses the batch size of the model
+             */
+            uint32_t shared_batch_size = 0;

+            /*!
+             * \brief whether the content of each output batch is promised
+             * to be binary equal whenever the content of each input batch
+             * is binary equal
+             */
+            bool binary_equal_between_batch = false;
+        } fast_run_config;
+
+    };  // Options

     Options& options() {
         return m_options;
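Both options land in ComputingGraph::Options, so they can also be set programmatically instead of through the load_and_run flags above. A minimal sketch (mirroring what the new test in src/opr/test/algo_chooser.cpp does; the batch size 20 is an arbitrary choice):

    auto graph = mgb::ComputingGraph::make();
    // profile every algorithm as if the batch size were 20, whatever the
    // model's real batch size is; cache entries become batch-size independent
    graph->options().fast_run_config.shared_batch_size = 20;
    // demand bitwise-equal results across batches; like the command line
    // flag, this restricts the search to REPRODUCIBLE algorithms
    graph->options().fast_run_config.binary_equal_between_batch = true;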
diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index 32900597..51bab364 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -60,21 +60,31 @@ std::string profile_name(Opr* opr) {
 template <typename Opr>
 std::string format_fixlayouts(
         const typename opr::AlgoChooser<Opr>::FixedTensorLayouts& layouts,
-        size_t arity_in, size_t arity_out) {
+        size_t arity_in, size_t arity_out,
+        const std::string& delimiter = " -> ") {
     std::string ret;
-    ret.append(": tensor layouts(");
-    for (size_t i = 0; i < arity_in; ++i) {
-        if (i) {
-            ret.append(", ");
+    if (arity_in) {
+        ret.append("(");
+        for (size_t i = 0; i < arity_in; ++i) {
+            if (i) {
+                ret.append(", ");
+            }
+            ret.append(layouts[i].to_string() + " ");
         }
-        ret.append(layouts[i].to_string() + " ");
+        ret.append(")");
+    }
+    if (arity_in && arity_out) {
+        ret.append(delimiter);
     }
-    ret.append(") -> (");
-    for (size_t i = 0; i < arity_out; ++i) {
-        if (i) {
-            ret.append(", ");
+    if (arity_out) {
+        ret.append("(");
+        for (size_t i = 0; i < arity_out; ++i) {
+            if (i) {
+                ret.append(", ");
+            }
+            ret.append(layouts[i + arity_in].to_string() + " ");
         }
-        ret.append(layouts[i + arity_in].to_string() + " ");
+        ret.append(")");
     }
     return ret;
 }
@@ -247,7 +257,7 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
         CircularDepsChecker& checker) {
     auto&& search_item = megdnn::Algorithm::SearchItem{
             OprTypeFromOprTrait<Opr>::opr_type, helper.param(),
-            to_layout_array<Opr>(helper.layouts())};
+            to_layout_array<Opr>(helper.fastrun_layouts())};
     checker.put(search_item);
     std::vector<megdnn::Algorithm::SearchItem> ret;
     for (auto algo_info : helper.get_all_candidates()) {
@@ -255,8 +265,9 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
                 helper.get_algorithm_from_desc(algo_info.desc);
         mgb_assert(algo, "Unknown algo description");
         std::vector<Algorithm::SearchItem>&& sub_items =
-                algo->get_subopr_list(to_layout_array<Opr>(helper.layouts()),
-                                      helper.megdnn_opr());
+                algo->get_subopr_list(
+                        to_layout_array<Opr>(helper.fastrun_layouts()),
+                        helper.megdnn_opr());

         FOREACH_OPR_TYPE_DISPATCH(sub_items, {
             auto&& megdnn_opr =
                     intl::create_megdnn_opr<_Opr>(m_cn);
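The hunk below introduces LayoutsModifier<Opr>, which rewrites the batch dimension of the layouts handed to fastrun. A sketch of its effect for a plain NCHW convolution, with shapes borrowed from the tests added later in this patch (the dst shape here is illustrative):

    // ConvolutionForward, NCHW; indices containing the batch: {0, 2} = src, dst
    //   layouts = {src {12, 3, 36, 36}, flt {4, 3, 3, 3}, dst {12, 4, 34, 34}}
    //   LayoutsModifier<megdnn::ConvolutionForward>::on(layouts, param, 20);
    //     -> src {20, 3, 36, 36}, dst {20, 4, 34, 34}   // m_fastrun_layouts
    //   LayoutsModifier<megdnn::ConvolutionForward>::on(layouts, param, 0);
    //     -> src {0, 3, 36, 36},  dst {0, 4, 34, 34}    // m_incache_layouts:
    //        the cache key no longer depends on the model's real batch size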
@@ -323,6 +334,166 @@ static Algorithm::Info::Desc deserialize_read_pod(const std::string& data,
 namespace mgb {
 namespace opr {

+template <typename Opr>
+class LayoutsModifier {
+    using FixedTensorLayouts = typename AlgoChooser<Opr>::FixedTensorLayouts;
+
+public:
+    static void on(FixedTensorLayouts&, const typename Opr::Param&, size_t) {}
+
+private:
+    //! index of the batch dimension in the tensor, e.g. 3 for CHWN4
+    static size_t index_of_batch(const typename Opr::Param&) { return 0; }
+
+    //! indices of the inputs and outputs that contain a batch dimension,
+    //! e.g. src(0) and dst(2) for conv
+    static std::vector<size_t> sm_indices_contain_batch;
+};
+template <typename Opr>
+std::vector<size_t> LayoutsModifier<Opr>::sm_indices_contain_batch = {};
+
+#define DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(opr, idxs)                       \
+    template <>                                                              \
+    class LayoutsModifier<opr> {                                             \
+    public:                                                                  \
+        using FixedTensorLayouts =                                           \
+                typename AlgoChooser<opr>::FixedTensorLayouts;               \
+        static void on(FixedTensorLayouts& layouts, const opr::Param& param, \
+                       size_t new_batch_size) {                              \
+            size_t batch_index = index_of_batch(param);                      \
+            for (size_t index : sm_indices_contain_batch) {                  \
+                layouts.at(index)[batch_index] = new_batch_size;             \
+            }                                                                \
+        }                                                                    \
+                                                                             \
+    private:                                                                 \
+        static size_t index_of_batch(const opr::Param&) { return 0; }        \
+        static std::vector<size_t> sm_indices_contain_batch;                 \
+    };                                                                       \
+    std::vector<size_t> LayoutsModifier<opr>::sm_indices_contain_batch = idxs;

+DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(megdnn::Convolution3DForward,
+                                    (std::initializer_list<size_t>{0, 2}))
+DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(megdnn::Convolution3DBackwardData,
+                                    (std::initializer_list<size_t>{1, 2}))
+DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(megdnn::Convolution3DBackwardFilter,
+                                    (std::initializer_list<size_t>{0, 1}))
+DEFAULT_OPR_WITHOUT_INPUT_BROADCAST(megdnn::BatchedMatrixMul,
+                                    (std::initializer_list<size_t>{0, 1, 2}))
+#undef DEFAULT_OPR_WITHOUT_INPUT_BROADCAST

+#define CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(opr, idxs)                     \
+    template <>                                                              \
+    class LayoutsModifier<opr> {                                             \
+    public:                                                                  \
+        using FixedTensorLayouts =                                           \
+                typename AlgoChooser<opr>::FixedTensorLayouts;               \
+        static void on(FixedTensorLayouts& layouts, const opr::Param& param, \
+                       size_t new_batch_size) {                              \
+            size_t batch_index = index_of_batch(param);                      \
+            for (size_t index : sm_indices_contain_batch) {                  \
+                layouts.at(index)[batch_index] = new_batch_size;             \
+            }                                                                \
+        }                                                                    \
+                                                                             \
+    private:                                                                 \
+        static size_t index_of_batch(const opr::Param& param) {              \
+            if (param.format == opr::Param::Format::CHWN4) {                 \
+                return 3;                                                    \
+            }                                                                \
+            return 0;                                                        \
+        }                                                                    \
+        static std::vector<size_t> sm_indices_contain_batch;                 \
+    };                                                                       \
+    std::vector<size_t> LayoutsModifier<opr>::sm_indices_contain_batch = idxs;

+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::ConvolutionForward,
+                                      (std::initializer_list<size_t>{0, 2}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::ConvolutionBackwardData,
+                                      (std::initializer_list<size_t>{1, 2}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::ConvolutionBackwardFilter,
+                                      (std::initializer_list<size_t>{0, 1}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::LocalShareForward,
+                                      (std::initializer_list<size_t>{0, 2}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::LocalShareBackwardData,
+                                      (std::initializer_list<size_t>{1, 2}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::LocalShareBackwardFilter,
+                                      (std::initializer_list<size_t>{0, 1}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::DeformableConvForward,
+                                      (std::initializer_list<size_t>{0, 2, 3,
+                                                                     4}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::DeformableConvBackwardData,
+                                      (std::initializer_list<size_t>{0, 2, 3, 4,
+                                                                     5, 6, 7}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::DeformableConvBackwardFilter,
+                                      (std::initializer_list<size_t>{0, 1, 2,
+                                                                     3}))
+CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST(megdnn::BatchConvBiasForward,
+                                      (std::initializer_list<size_t>{0, 1, 2, 3,
+                                                                     4}))
+#undef CONV_LIKE_OPR_WITHOUT_INPUT_BROADCAST

+template <>
+class LayoutsModifier<megdnn::ConvBiasForward> {
+public:
+    using FixedTensorLayouts =
+            typename AlgoChooser<megdnn::ConvBiasForward>::FixedTensorLayouts;
+    static void on(FixedTensorLayouts& layouts,
+                   const megdnn::ConvBiasForward::Param& param,
+                   size_t new_batch_size) {
+        size_t batch_index = index_of_batch(param);
+        for (size_t index : sm_indices_contain_batch) {
+            layouts.at(index)[batch_index] = new_batch_size;
+        }
+        for (size_t index : sm_indices_contain_batch_broadcast) {
+            if (!check_bias_share_in_channel(layouts.at(index), param.format)) {
+                layouts.at(index)[batch_index] = new_batch_size;
+            }
+        }
+    }
+
+private:
+    static std::vector<size_t> sm_indices_contain_batch;
+    static std::vector<size_t> sm_indices_contain_batch_broadcast;
+    static size_t index_of_batch(const megdnn::ConvBiasForward::Param& param) {
+        if (param.format == megdnn::ConvBiasForward::Param::Format::CHWN4) {
+            return 3;
+        }
+        return 0;
+    }
+};
+std::vector<size_t>
+        LayoutsModifier<megdnn::ConvBiasForward>::sm_indices_contain_batch = {
+                0, 3, 4};
+std::vector<size_t> LayoutsModifier<
+        megdnn::ConvBiasForward>::sm_indices_contain_batch_broadcast = {2};

+template <>
+class LayoutsModifier<megdnn::MatrixMul> {
+public:
+    using FixedTensorLayouts =
+            typename AlgoChooser<megdnn::MatrixMul>::FixedTensorLayouts;
+    static void on(FixedTensorLayouts& layouts,
+                   const megdnn::MatrixMul::Param& param,
+                   size_t new_batch_size) {
+        //! Because we do not know whether the batch size is in the dimension m
+        //! or the dimension n, we just ignore both m and n here.
+        // FIXME Find a way to make mgb obtain batch size information from R or
+        // automatically
+        layouts.at(2)[0] = new_batch_size;
+        layouts.at(2)[1] = new_batch_size;
+        if (param.transposeA) {
+            layouts.at(0)[1] = new_batch_size;
+        } else {
+            layouts.at(0)[0] = new_batch_size;
+        }
+        if (param.transposeB) {
+            layouts.at(1)[0] = new_batch_size;
+        } else {
+            layouts.at(1)[1] = new_batch_size;
+        }
+    }
+};
+
 ///////////////////////////// AlgoChooserHelper //////////////////////////
 template <typename Opr>
 AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
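Since the MatrixMul specialization above cannot tell whether m or n carries the batch, it overwrites both. A worked example with new_batch_size = 20 and the shapes used by the MatrixMul test added below:

    // origin (no transpose):  A {10, 12}, B {12, 12}, C {10, 12}
    //   on(layouts, param, 20) -> A {20, 12}, B {12, 20}, C {20, 20}
    // with transposeA, A is k x m, so the batch sits in dim 1:
    //   A {12, 10} -> {12, 20}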
@@ -331,14 +502,25 @@ AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
         const CompNode& cn,
         const megdnn::param::ExecutionPolicy& execution_policy,
         bool allow_weight_preprocess)
-        : m_layouts{layouts},
+        : m_fastrun_layouts{layouts},
+          m_incache_layouts{layouts},
           m_dnn_opr{megdnn_opr},
           m_param{param_str},
           m_base_mgb_opr{mgb_opr},
           m_cn{cn},
           m_execution_policy{execution_policy},
           m_allow_weight_preprocess{allow_weight_preprocess} {
-    mgb_assert(m_layouts.size() == layouts.size());
+    auto fastrun_batch_size =
+            owner_graph()->options().fast_run_config.shared_batch_size;
+
+    if (fastrun_batch_size) {
+        LayoutsModifier<Opr>::on(m_incache_layouts, m_dnn_opr->param(), 0);
+        LayoutsModifier<Opr>::on(m_fastrun_layouts, m_dnn_opr->param(),
+                                 fastrun_batch_size);
+    }
+
+    mgb_assert(m_fastrun_layouts.size() == layouts.size());
+
     static_assert(std::tuple_size<FixedTensorLayouts>::value == 3 ||
                           std::tuple_size<FixedTensorLayouts>::value == 5 ||
                           std::tuple_size<FixedTensorLayouts>::value == 8,
@@ -358,13 +540,13 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic(
     policy.algo =
             APPLY(m_dnn_opr->get_algorithm_info_heuristic(
                           args..., workspace_limit, attr.first, attr.second),
-                  m_layouts)
+                  m_fastrun_layouts)
                     .desc;

     Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
     mgb_assert(algo, "Unknown algo description");
     std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
-            to_layout_array<Opr>(m_layouts), m_dnn_opr);
+            to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr);

     FOREACH_OPR_TYPE_DISPATCH(sub_items, {
         auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
@@ -393,7 +575,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile(
         if (policy.algo.valid()) {
             return policy;
         }
-        if (!algo_usable_on_shape_change<Opr>()) {
+        if (is_matmul<Opr>()) {
             mgb_log_warn(
                     "choose algo by heuristic, which may cause performance "
                     "regression.");
@@ -442,7 +624,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
     AlgoChooserProfileCache cache(m_cn, profile_name<Opr>(m_dnn_opr).c_str());

     typename Opr::Param origin_param = m_dnn_opr->param();
-    AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(),
+    AlgoChooserProfileCache::Key cache_key{m_incache_layouts.data(),
+                                           m_incache_layouts.size(),
                                            &origin_param, sizeof(origin_param)};
     auto&& rst = cache.get(cache_key);
     if (!rst.valid())
@@ -472,21 +655,21 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
     }

     std::string layouts_str =
-            format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
+            format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out);
     if (skip_by_negative) {
         mgb_log_error(
-                "opr: %s, layouts: %s, No usable algo. There are available algos match "
+                "opr: %s, layouts: %s, No usable algo. There are available "
+                "algos match "
                 "positive strategy(%s), but filtered by negative strategy(%s).",
-                m_base_mgb_opr->dyn_typeinfo()->name,
-                layouts_str.c_str(),
+                m_base_mgb_opr->dyn_typeinfo()->name, layouts_str.c_str(),
                 Algorithm::attribute_str(target_attr.first).c_str(),
                 Algorithm::attribute_str(target_attr.second).c_str());
     } else {
         mgb_log_error(
-                "opr: %s, layouts: %s, No usable algo. algos read from cache could not "
+                "opr: %s, layouts: %s, No usable algo. algos read from cache "
+                "could not "
                 "satisfy positive strategy(%s)",
-                m_base_mgb_opr->dyn_typeinfo()->name,
-                layouts_str.c_str(),
+                m_base_mgb_opr->dyn_typeinfo()->name, layouts_str.c_str(),
                 Algorithm::attribute_str(target_attr.first).c_str());
     }
@@ -508,7 +691,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
             auto target_attr = extract_algo_attribute(selected_strategy);
             std::string layouts_str = format_fixlayouts<Opr>(
-                    m_layouts, arity_in, arity_out);
+                    m_fastrun_layouts, arity_in, arity_out);
             std::string msg = ssprintf(
                     "(opr : %s, layouts %s, with attribute(%s) and "
                     "without attribute(%s)",
@@ -535,7 +718,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
             policy.algo =
                     APPLY(m_dnn_opr->get_algorithm_info_heuristic(
                                   args..., workspace_limit, attr.first,
                                   attr.second),
-                          m_layouts)
+                          m_fastrun_layouts)
                             .desc;
             mgb_assert(policy.algo.valid(),
                        "No algo found from heuristic with strategy %u and "
@@ -548,7 +731,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
     Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
     mgb_assert(algo, "Unknown algo description");
     std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
-            to_layout_array<Opr>(m_layouts), m_dnn_opr);
+            to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr);

     FOREACH_OPR_TYPE_DISPATCH(sub_items, {
         auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
@@ -575,26 +758,32 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
 template <typename Opr>
 size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes(
-        const ImplExecutionPolicy& policy) const {
+        const ImplExecutionPolicy& policy,
+        const FixedTensorLayouts& layouts) const {
     MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes")))
     m_dnn_opr->execution_policy() = policy;
     size_t result;
+    const FixedTensorLayouts* layouts_ptr = &m_fastrun_layouts;
+    if (layouts.at(0).ndim) {
+        layouts_ptr = &layouts;
+    }
     if_constexpr<opr_supports_preprocess<Opr>()>(
             [&](auto _) {
                 auto&& opr = _(m_dnn_opr);
-                auto prep = this->construct_fake_preprocess_filter();
+                auto prep =
+                        this->construct_fake_preprocess_filter(*layouts_ptr);
                 PreprocessFilter<Opr>* prep_ptr =
                         prep.valid() ? &prep.val() : nullptr;
                 result = std::max(
                         APPLY(opr->get_preprocess_workspace_in_bytes(args...),
-                              m_layouts),
+                              *layouts_ptr),
                         APPLY(opr->get_workspace_in_bytes(args..., prep_ptr),
-                              m_layouts));
+                              *layouts_ptr));
             },
             /* else */
             [&](auto _) {
                 result = APPLY(_(m_dnn_opr)->get_workspace_in_bytes(args...),
-                               m_layouts);
+                               *layouts_ptr);
             });
     return result;
     MIDOUT_E
@@ -605,8 +794,8 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo>
 AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const {
     MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
     auto heu = choose_by_heuristic(m_execution_policy.strategy);
-    auto&& ret =
-            APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_layouts);
+    auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...),
+                       m_fastrun_layouts);
     bool found = false;
     for (size_t i = 0; i < ret.size(); ++i) {
         if (ret[i].desc == heu.algo) {
@@ -637,7 +826,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
             TimedProfiler<Opr>::Param::ExecutionPolicyBlob::serialize(policy);
     param.workspace = get_workspace_size_bytes(policy);
     for (int i = 0; i < arity; ++i) {
-        auto&& src = m_layouts[i];
+        auto&& src = m_fastrun_layouts[i];
         bool cond_normal = src.format.is_default() &&
                            (src.dtype.category() == DTypeCategory::FLOAT ||
                             src.dtype.category() == DTypeCategory::INT ||
@@ -655,9 +844,9 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
         param.dtypes[i] = src.dtype.enumv();
     }
     param.comp_node_loc = m_cn.locator();
-    mgb_assert(param.shapes.size() == m_layouts.size());
+    mgb_assert(param.shapes.size() == m_fastrun_layouts.size());
     for (size_t i = 0; i < param.shapes.size(); ++i)
-        param.shapes[i] = m_layouts[i];
+        param.shapes[i] = m_fastrun_layouts[i];
     param.opr_param = m_dnn_opr->param();
     param.allow_weight_preprocess = m_allow_weight_preprocess;
@@ -692,7 +881,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
     auto target_attr = extract_algo_attribute(selected_strategy);
     std::string layouts_str =
-            format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
+            format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out);
     double cur_timeout = 0;

     auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
@@ -761,10 +950,10 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
                               workspace_limit);
     mgb_assert(!prof_rst.empty(), "%s", msg.c_str());

-    FixedTensorLayouts origin_layouts = m_layouts;
+    FixedTensorLayouts incache_layouts = m_incache_layouts;
     typename Opr::Param origin_param = m_dnn_opr->param();
-    AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
-                                           origin_layouts.size(), &origin_param,
+    AlgoChooserProfileCache::Key cache_key{incache_layouts.data(),
+                                           incache_layouts.size(), &origin_param,
                                            sizeof(origin_param)};

     AlgoChooserProfileCache cache(m_cn, profile_name<Opr>(m_dnn_opr).c_str());
@@ -774,15 +963,20 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
 template <typename Opr>
 Maybe<PreprocessFilter<Opr>>
-AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const {
+AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter(
+        const FixedTensorLayouts& layouts) const {
     MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_fake_preprocess_filter")))
     Maybe<PreprocessFilter<Opr>> result = None;
+    const FixedTensorLayouts* layouts_ptr = &m_fastrun_layouts;
+    if (layouts.at(0).ndim) {
+        layouts_ptr = &layouts;
+    }
     if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
         if (!m_allow_weight_preprocess)
             return;
         auto opr = _(m_dnn_opr);
         auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
-                             m_layouts);
+                             *layouts_ptr);
         //! No preprocess layout means no need weight preprocess
         if (layouts.empty()) {
             return;
@@ -825,6 +1019,16 @@ AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute(
             ret.second |= AlgoAttribute::NAIVE;
     }

+    //! from graph options
+    if (owner_graph()->options().fast_run_config.shared_batch_size) {
+        ret.second |= AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
+    }
+
+    if (owner_graph()->options().fast_run_config.binary_equal_between_batch) {
+        ret.first |= AlgoAttribute::REPRODUCIBLE;
+        ret.second |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+    }
+
     return ret;
 }
@@ -854,7 +1058,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute(
     template size_t                                                          \
     AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_workspace_size_bytes(   \
             const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy&    \
-                    policy) const;                                           \
+                    policy,                                                  \
+            const FixedTensorLayouts& layouts) const;                        \
     template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo>        \
     AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_all_candidates() const; \
     template Maybe<AlgoChooserProfileCache::ResultEntry>                     \
@@ -942,10 +1147,11 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
     if (!policy.algo.valid()) {
         policy = get_policy(helper);
     }
-    size_t workspace = helper.get_workspace_size_bytes(policy);
+    size_t workspace = helper.get_workspace_size_bytes(policy, layouts);

     std::string ret;
     ret.append(mgb_opr->dyn_typeinfo()->name);
+    ret.append(": tensor layouts");
     ret += format_fixlayouts<Opr>(layouts, arity_in, arity_out);
     Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
     mgb_assert(palgo, "Unknown algo description");
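For reference, the attribute pair built by extract_algo_attribute() above is consumed as (required, forbidden): ret.first lists attributes a chosen algorithm must have, ret.second those it must not have (the "positive"/"negative" strategies in the log messages). The two graph options therefore act as filters, roughly:

    // shared_batch_size != 0:
    //   forbid USABLE_DEPEND_ON_SHAPE -- the profiled (shared) batch size may
    //   differ from the execution-time batch size, so algorithms whose
    //   usability depends on the concrete shape must be ruled out
    // binary_equal_between_batch:
    //   require REPRODUCIBLE and forbid ACCURACY_DEPEND_ON_BATCH, so one
    //   batch's result cannot depend on the other batches in the tensor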
diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
index 06dcad3c..599c9489 100644
--- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
+++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
@@ -68,7 +68,10 @@ class AlgoChooser {
 public:
     using FixedTensorLayouts = std::array<TensorLayout, arity>;
     class AlgoChooserHelper {
-        FixedTensorLayouts m_layouts;
+        //! fastrun layouts
+        FixedTensorLayouts m_fastrun_layouts;
+        //! layouts used when getting and setting cache items
+        FixedTensorLayouts m_incache_layouts;
         Opr* m_dnn_opr;
         std::string m_param;
         const cg::OperatorNodeBase* m_base_mgb_opr;
@@ -89,7 +92,7 @@ public:
         const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }

         const TensorLayout& inp_layout(size_t idx) const {
-            return m_layouts[idx];
+            return m_fastrun_layouts[idx];
         }
         cg::ComputingGraph* owner_graph() const {
             return m_base_mgb_opr->owner_graph();
@@ -109,7 +112,13 @@ public:
             return m_dnn_opr->get_algorithm_from_desc(desc);
         }

-        const FixedTensorLayouts& layouts() const { return m_layouts; }
+        const FixedTensorLayouts& fastrun_layouts() const {
+            return m_fastrun_layouts;
+        }
+
+        const FixedTensorLayouts& incache_layouts() const {
+            return m_incache_layouts;
+        }

         //! construct algo chain by heuristic
         ImplExecutionPolicy choose_by_heuristic(
@@ -141,7 +150,8 @@ public:

         //! get workspace size required for specific execution policy
         size_t get_workspace_size_bytes(
-                const ImplExecutionPolicy& policy) const;
+                const ImplExecutionPolicy& policy,
+                const FixedTensorLayouts& layouts = {}) const;

         //! get all candidate algos, and the one choose_by_heuristic() is
         //! put first
@@ -173,7 +183,8 @@ public:
                 const ExecutionStrategy& strategy) const;

     private:
-        Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
+        Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter(
+                const FixedTensorLayouts& layouts = {}) const;
     };

 template <typename Opr>
diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h
index 489fcad6..045c33f7 100644
--- a/src/opr/include/megbrain/opr/search_policy/profiler.h
+++ b/src/opr/include/megbrain/opr/search_policy/profiler.h
@@ -54,11 +54,11 @@ constexpr bool opr_contain_bias() {
     return std::is_same<Opr, megdnn::ConvBias>::value;
 }

-//! matmul and batchedMatrixMul may not be usable once shape changed
+//! whether the opr is matmul or batchedMatrixMul
 template <typename Opr>
-constexpr bool algo_usable_on_shape_change() {
-    return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
-             std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
+constexpr bool is_matmul() {
+    return std::is_same<Opr, megdnn::MatrixMul>::value ||
+           std::is_same<Opr, megdnn::BatchedMatrixMul>::value;
 }

 template <typename Opr>
diff --git a/src/opr/test/algo_chooser.cpp b/src/opr/test/algo_chooser.cpp
new file mode 100644
index 00000000..52200dc8
--- /dev/null
+++ b/src/opr/test/algo_chooser.cpp
@@ -0,0 +1,304 @@
+/**
+ * \file src/opr/test/algo_chooser.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/comp_node_env.h"
+
+#include "megbrain/opr/blas.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/test/autocheck.h"
+#include "megbrain/test/helper.h"
+#include "megbrain/test/megdnn_helper.h"
+#include "megbrain/serialization/serializer.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/gopt/inference.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megdnn/oprs/base.h"
+#include "megdnn/dtype.h"
+
+#include
+#include
+#include
+
+using namespace mgb;
+
+namespace {
+
+#if MGB_CUDA
+#if MGB_ENABLE_FASTRUN
+template <typename MgbOpr, int arith>
+struct GraphMaker;
+
+template <typename MgbOpr>
+struct GraphMaker<MgbOpr, 2> {
+    SymbolVar operator()(const std::array<cg::SymbolVar, 2>& inputs,
+                         typename MgbOpr::Param& param,
+                         typename MgbOpr::ExecutionPolicy& policy) {
+        return MgbOpr::make(inputs[0], inputs[1], param, policy);
+    }
+};
+
+template <>
+struct GraphMaker<opr::ConvolutionBackwardData, 2> {
+    SymbolVar operator()(
+            const std::array<cg::SymbolVar, 2>& inputs,
+            opr::ConvolutionBackwardData::Param& param,
+            opr::ConvolutionBackwardData::ExecutionPolicy& policy) {
+        return opr::ConvolutionBackwardData::make_deconv(inputs[0], inputs[1],
+                                                         param, policy);
+    }
+};
+
+template <>
+struct GraphMaker<opr::Convolution3DBackwardData, 2> {
+    SymbolVar operator()(
+            const std::array<cg::SymbolVar, 2>& inputs,
+            opr::Convolution3DBackwardData::Param& param,
+            opr::Convolution3DBackwardData::ExecutionPolicy& policy) {
+        return opr::Convolution3DBackwardData::make_deconv(inputs[0], inputs[1],
+                                                           param, policy);
+    }
+};
+
+template <typename MgbOpr>
+struct GraphMaker<MgbOpr, 3> {
+    SymbolVar operator()(const std::array<cg::SymbolVar, 3>& inputs,
+                         typename MgbOpr::Param& param,
+                         typename MgbOpr::ExecutionPolicy& policy) {
+        return MgbOpr::make(inputs[0], inputs[1], inputs[2], param, policy, {});
+    }
+};
+
+template <typename MgbOpr>
+struct GraphMaker<MgbOpr, 4> {
+    SymbolVar operator()(const std::array<cg::SymbolVar, 4>& inputs,
+                         typename MgbOpr::Param& param,
+                         typename MgbOpr::ExecutionPolicy& policy) {
+        return MgbOpr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
+                            policy, {});
+    }
+};
+
+template <typename MgbOpr>
+struct GraphMaker<MgbOpr, 5> {
+    SymbolVar operator()(const std::array<cg::SymbolVar, 5>& inputs,
+                         typename MgbOpr::Param& param,
+                         typename MgbOpr::ExecutionPolicy& policy) {
+        return MgbOpr::make(inputs[0], inputs[1], inputs[2], inputs[3],
+                            inputs[4], param, policy, {});
+    }
+};
+
+template <typename MgbOpr, int arith = 2, typename dtype = dtype::Float32>
+void test_fastrun_opr(std::array<TensorShape, arith> inps0,
+                      std::array<TensorShape, arith> inps1,
+                      size_t expect_nr_cache_set_inp0 = 0,
+                      size_t expect_nr_cache_set_inp1 = 0,
+                      typename MgbOpr::Param param = {}) {
+    using Policy = opr::Convolution::ExecutionPolicy;
+    using S = Policy::Strategy;
+    using InputGenerator = std::function<void(HostTensorND& dest)>;
+    using ShapeInpArray = std::array<TensorShape, arith>;
+    using CacheMem = std::pair<const void*, size_t>;
+    auto on_get = [](const std::string&, const void*, size_t, const void*,
+                     size_t) {};
+
+    std::vector<std::pair<CacheMem, CacheMem>> cache_set_history;
+    auto on_set = [&cache_set_history](const std::string&, const void* key,
+                                       size_t key_size, const void* val,
+                                       size_t val_size) {
+        cache_set_history.emplace_back(std::make_pair(key, key_size),
+                                       std::make_pair(val, val_size));
+    };
+
+    PersistentCacheHook cache_hook{on_get, on_set};
+
+    CompNode comp_node = CompNode::load("xpu0");
+    GraphMaker<MgbOpr, arith> graph_maker;
+
+    auto run = [&param, &comp_node, &graph_maker](
+                       const std::shared_ptr<ComputingGraph>& graph,
+                       const ShapeInpArray& shapes) {
+        std::array<InputGenerator, arith> inputs_generator;
+        std::array<std::shared_ptr<HostTensorND>, arith> inputs;
+        for (size_t i = 0; i < arith; ++i) {
+            inputs[i] = std::make_shared<HostTensorND>(comp_node,
+                                                       dtype());
+        }
+        HostTensorGenerator<dtype> gen_host;
+        for (size_t i = 0; i < arith; ++i) {
+            inputs[i]->resize(shapes[i]);
+            *inputs[i] = *gen_host(inputs[i]->shape(), comp_node);
+            mgb_assert(inputs[i]->shape().eq_shape(shapes[i]));
+        }
+        std::array<cg::SymbolVar, arith> sym_in;
+        for (size_t i = 0; i < arith; ++i) {
+            // to trigger graph trans
+            sym_in[i] = opr::Host2DeviceCopy::make(*graph, inputs[i],
+                                                   ssprintf("inp%zu", i));
+        }
+        Policy policy;
+        policy.strategy = S::PROFILE;
+        auto out = graph_maker(sym_in, param, policy);
+
+        std::unique_ptr<cg::AsyncExecutable> func =
+                graph->compile({{out, {}}});
+        func->execute();
+    };
+
+    std::shared_ptr<ComputingGraph> fastrun_ignore_batchsize_graph =
+            ComputingGraph::make();
+    fastrun_ignore_batchsize_graph->options()
+            .fast_run_config.shared_batch_size = 20;
+
+    run(fastrun_ignore_batchsize_graph, inps0);
+    size_t nr_set_inp0 = cache_set_history.size();
+    if (expect_nr_cache_set_inp0) {
+        ASSERT_EQ(cache_set_history.size(), expect_nr_cache_set_inp0);
+    }
+    run(fastrun_ignore_batchsize_graph, inps1);
+    size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0;
+    ASSERT_EQ(cache_set_history.size(), nr_set_total);
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) {
+    REQUIRE_GPU(1);
+
+    test_fastrun_opr<opr::Convolution>(
+            {TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}},
+            {TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}});
+
+    test_fastrun_opr<opr::ConvolutionBackwardData>(
+            {TensorShape{12, 4, 23, 29}, TensorShape{4, 5, 3, 2}},
+            {TensorShape{2, 4, 23, 29}, TensorShape{4, 5, 3, 2}});
+
+    test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>(
+            {TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},
+             TensorShape{5, 4, 3, 2}},
+            {TensorShape{2, 4, 23, 29}, TensorShape{2, 5, 21, 28},
+             TensorShape{5, 4, 3, 2}});
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeConvBias) {
+    REQUIRE_GPU(1);
+    test_fastrun_opr<opr::ConvBiasForward, 3>(
+            {TensorShape{20, 16, 50, 50}, TensorShape{24, 16, 3, 3},
+             TensorShape{1, 24, 1, 1}},
+            {TensorShape{1, 16, 50, 50}, TensorShape{24, 16, 3, 3},
+             TensorShape{1, 24, 1, 1}});
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution3D) {
+    REQUIRE_GPU(1);
+    test_fastrun_opr<opr::Convolution3D>(
+            {TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}},
+            {TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});
+
+    test_fastrun_opr<opr::Convolution3DBackwardData>(
+            {TensorShape{14, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}},
+            {TensorShape{4, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}});
+
+    test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>(
+            {TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},
+             TensorShape{16, 16, 1, 1, 1}},
+            {TensorShape{4, 16, 18, 18, 18}, TensorShape{4, 16, 18, 18, 18},
+             TensorShape{16, 16, 1, 1, 1}});
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeLocalShare) {
+    REQUIRE_GPU(1);
+    opr::LocalShare::Param local_share_param;
+    local_share_param.mode = opr::LocalShare::Param::Mode::CROSS_CORRELATION;
+    local_share_param.pad_h = local_share_param.pad_w = 1;
+    local_share_param.stride_h = local_share_param.stride_w = 1;
+    local_share_param.spatial_groups_h = local_share_param.spatial_groups_w = 2;
+    test_fastrun_opr<opr::LocalShare>(
+            {TensorShape{32, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}},
+            {TensorShape{3, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}}, 0, 0,
+            local_share_param);
+
+    test_fastrun_opr<opr::LocalShareBackwardData, 3>(
+            {TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{32, 128, 24, 24},
+             TensorShape{32, 128, 24, 24}},
+            {TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{2, 128, 24, 24},
+             TensorShape{2, 128, 24, 24}});
+
+    test_fastrun_opr<opr::LocalShareBackwardFilter, 3>(
+            {TensorShape{12, 3, 36, 36}, TensorShape{12, 4, 35, 35},
+             TensorShape{3, 3, 3, 3, 3, 4}},
+            {TensorShape{4, 3, 36, 36}, TensorShape{4, 4, 35, 35},
+             TensorShape{3, 3, 3, 3, 3, 4}});
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeDeformableConv) {
+    REQUIRE_GPU(1);
+    test_fastrun_opr<opr::DeformableConvForward, 4>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}},
+            {TensorShape{4, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{4, 18, 18, 18}, TensorShape{4, 9, 18, 18}});
+
+    test_fastrun_opr<opr::DeformableConvBackwardFilter, 5>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
+             TensorShape{12, 6, 18, 18}},
+            {TensorShape{4, 6, 20, 20},
+             TensorShape{6, 6, 3, 3},
+             TensorShape{4, 18, 18, 18},
+             TensorShape{4, 9, 18, 18},
+             TensorShape{4, 6, 18, 18}});
+
+    test_fastrun_opr<opr::DeformableConvBackwardData, 5>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
+             TensorShape{12, 6, 18, 18}},
+            {TensorShape{4, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{4, 18, 18, 18}, TensorShape{4, 9, 18, 18},
+             TensorShape{4, 6, 18, 18}});
+}
+
+TEST(TestOprDNN, FastrunIgnoreBatchSizeMatrixMul) {
+    REQUIRE_GPU(1);
+    //! fastrun_shared_batch_size == 20
+    //! {20(12), 12(1)}, {12(12), 20(1)} -> {20(12), 20(1)} origin
+    //! {12(10), 20(1)}, {12(12), 20(1)} -> {20(12), 20(1)} transA
+    //! {12(10), 20(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transA, transB
+    //! {20(12), 12(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transB
+    //!
+    //! {20(12), 12(1)}, {12(12), 20(1)} -> {20(12), 20(1)} origin duplicate
+    //! {12(4), 20(1)}, {12(12), 20(1)} -> {20(12), 20(1)} transA
+    //! {12(4), 20(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transA, transB
+    //! {20(12), 12(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transB duplicate
+    test_fastrun_opr<opr::MatrixMul>(
+            {TensorShape{10, 12}, TensorShape{12, 12}},
+            {TensorShape{4, 12}, TensorShape{12, 12}}, 4, 2);
+}
+
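In the annotations above and below, {20(12), 12(1)} is a tensor layout printed as dimension(stride). The expected set counts (4, then 2) follow from the in-cache layouts: with the batch dimension forced to 0, the origin and transB keys of the second run collide with those of the first, while the transA keys still differ through the leftover stride, for example:

    // run 1, origin: A {10, 12} -> incache key {0(12), 12(1)}
    // run 2, origin: A {4, 12}  -> incache key {0(12), 12(1)}  // duplicate
    // run 1, transA: A -> incache key {12(10), 0(1)}
    // run 2, transA: A -> incache key {12(4), 0(1)}            // new entry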
+TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) {
+    REQUIRE_GPU(1);
+
+    //! fastrun_shared_batch_size == 20
+    //! {20(48), 6(8), 8(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} origin
+    //! {20(48), 8(6), 6(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} transA
+    //! {20(48), 8(6), 6(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transA, transB
+    //! {20(48), 6(8), 8(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transB
+    //!
+    //! {20(48), 6(8), 8(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} origin duplicate
+    //! {20(48), 8(6), 6(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} transA duplicate
+    //! {20(48), 8(6), 6(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transA, transB duplicate
+    //! {20(48), 6(8), 8(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transB duplicate
+    test_fastrun_opr<opr::BatchedMatrixMul>(
+            {TensorShape{12, 6, 8}, TensorShape{12, 8, 4}},
+            {TensorShape{4, 6, 8}, TensorShape{4, 8, 4}});
+}
+
+#endif  // MGB_ENABLE_FASTRUN
+#endif  // MGB_CUDA
+
+}  // anonymous namespace
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/test/src/helper.cpp b/test/src/helper.cpp
index 9a2a0f31..9ea0c151 100644
--- a/test/src/helper.cpp
+++ b/test/src/helper.cpp
@@ -460,12 +460,13 @@ mgb::make_callback_copy(SymbolVar dev, HostTensorND &host, bool sync) {

 /* ========================== PersistentCacheHook ========================== */
 class PersistentCacheHook::HookedImpl final : public PersistentCache {
-    GetHook m_on_get;
+    Hook m_on_get, m_on_set;

 public:
     std::shared_ptr<PersistentCache> orig_impl;

-    HookedImpl(GetHook on_get) : m_on_get{std::move(on_get)} {}
+    HookedImpl(Hook on_get, Hook on_set)
+            : m_on_get{std::move(on_get)}, m_on_set{std::move(on_set)} {}

     Maybe<Blob> get(const std::string& category, const Blob& key) override {
         auto ret = orig_impl->get(category, key);
@@ -476,12 +477,18 @@ public:

     void put(const std::string& category, const Blob& key,
              const Blob& value) override {
+        m_on_set(category, key.ptr, key.size, value.ptr,
+                 value.size);
         orig_impl->put(category, key, value);
     }
 };

+PersistentCacheHook::Hook PersistentCacheHook::default_set_hook =
+        [](const std::string&, const void*, size_t, const void*, size_t) {};
+
-PersistentCacheHook::PersistentCacheHook(GetHook on_get)
-        : m_impl{std::make_shared<HookedImpl>(std::move(on_get))} {
+PersistentCacheHook::PersistentCacheHook(Hook on_get, Hook on_set)
+        : m_impl{std::make_shared<HookedImpl>(std::move(on_get),
+                                              std::move(on_set))} {
     m_impl->orig_impl = PersistentCache::set_impl(m_impl);
 }
diff --git a/test/src/include/megbrain/test/helper.h b/test/src/include/megbrain/test/helper.h
index fc566033..4b2419bf 100644
--- a/test/src/include/megbrain/test/helper.h
+++ b/test/src/include/megbrain/test/helper.h
@@ -512,17 +512,17 @@ bool check_device_type_avaiable(CompNode::DeviceType device_type);

 //! hook persistent cache get calls during the lifetime
 class PersistentCacheHook {
-    class HookedImpl;
-
-    std::shared_ptr<HookedImpl> m_impl;
-
 public:
-    //! if value is not available, \p val and \p val_size would be zero
-    using GetHook = thin_function<void(const std::string& category,
-                                       const void* key, size_t key_size,
-                                       const void* val, size_t val_size)>;
-    PersistentCacheHook(GetHook on_get);
+    using Hook = thin_function<void(const std::string& category,
+                                    const void* key, size_t key_size,
+                                    const void* val, size_t val_size)>;
+    PersistentCacheHook(Hook on_get, Hook on_set = default_set_hook);
     ~PersistentCacheHook();
+
+private:
+    static Hook default_set_hook;
+    class HookedImpl;
+    std::shared_ptr<HookedImpl> m_impl;
 };

 //! skip a testcase if xpu not available
 #define REQUIRE_XPU(n) do { \
--
GitLab