提交 12dc36a6 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb/gopt): add interface to select reproducible algorithms

GitOrigin-RevId: f341bea40b6e52f4598640b81b477184d8473421
上级 cc4e1dfd
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "./json_loader.h" #include "./json_loader.h"
#include "./npy.h" #include "./npy.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/utils/debug.h" #include "megbrain/utils/debug.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/extern_c_opr.h" #include "megbrain/serialization/extern_c_opr.h"
...@@ -144,6 +145,10 @@ R"__usage__( ...@@ -144,6 +145,10 @@ R"__usage__(
R"__usage__( R"__usage__(
--fast-run-algo-policy <path> --fast-run-algo-policy <path>
It will read the cache file before profile, and save new fastrun in cache file. It will read the cache file before profile, and save new fastrun in cache file.
--reproducible
Enable choose algo which is reproducible. It mainly used for cudnn algos.
See https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#reproducibility
for more details.
--wait-gdb --wait-gdb
Print PID and wait for a line from stdin before starting execution. Useful Print PID and wait for a line from stdin before starting execution. Useful
for waiting for gdb attach. for waiting for gdb attach.
...@@ -467,6 +472,7 @@ struct Args { ...@@ -467,6 +472,7 @@ struct Args {
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
bool use_fast_run = false; bool use_fast_run = false;
#endif #endif
bool reproducible = false;
std::string fast_run_cache_path; std::string fast_run_cache_path;
bool copy_to_host = false; bool copy_to_host = false;
int nr_run = 10; int nr_run = 10;
...@@ -647,10 +653,24 @@ void run_test_st(Args &env) { ...@@ -647,10 +653,24 @@ void run_test_st(Args &env) {
} }
mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit);
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (env.use_fast_run) if (env.use_fast_run) {
mgb::gopt::enable_opr_algo_profiling_inplace(vars); if (env.reproducible) {
strategy = S::PROFILE_REPRODUCIBLE;
} else {
strategy = S::PROFILE;
}
} else if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE;
}
#else
if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE;
}
#endif #endif
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
if (!env.fast_run_cache_path.empty()) { if (!env.fast_run_cache_path.empty()) {
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (!access(env.fast_run_cache_path.c_str(), F_OK)) { if (!access(env.fast_run_cache_path.c_str(), F_OK)) {
...@@ -1149,6 +1169,10 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -1149,6 +1169,10 @@ Args Args::from_argv(int argc, char **argv) {
ret.fast_run_cache_path = argv[i]; ret.fast_run_cache_path = argv[i];
continue; continue;
} }
if (!strcmp(argv[i], "--reproducible")) {
ret.reproducible = true;
continue;
}
if (!strcmp(argv[i], "--const-shape")) { if (!strcmp(argv[i], "--const-shape")) {
ret.load_config.const_var_shape = true; ret.load_config.const_var_shape = true;
continue; continue;
......
...@@ -104,25 +104,21 @@ SymbolVarArray gopt::optimize_for_inference( ...@@ -104,25 +104,21 @@ SymbolVarArray gopt::optimize_for_inference(
} }
namespace { namespace {
void modify_conv_policy(opr::mixin::Convolution& conv, void modify_conv_strategy(
megdnn::param::ExecutionPolicy::Strategy strategy) { opr::mixin::Convolution& conv,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
auto policy = conv.execution_policy_transient(); auto policy = conv.execution_policy_transient();
policy.strategy = strategy; policy.strategy = strategy;
conv.set_execution_policy(policy); conv.set_execution_policy(policy);
} }
template <typename Opr> template <typename Opr>
void inplace_conv_opr_profile_modifier(OperatorNodeBase& opr) { void inplace_conv_opr_modifier(
modify_conv_policy( OperatorNodeBase& opr,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
modify_conv_strategy(
opr.cast_final_safe<Opr>(), opr.cast_final_safe<Opr>(),
opr::mixin::Convolution::ExecutionPolicy::Strategy::PROFILE); strategy);
}
template <typename Opr>
void inplace_conv_opr_profile_cache_modifier(OperatorNodeBase& opr) {
modify_conv_policy(opr.cast_final_safe<Opr>(),
opr::mixin::Convolution::ExecutionPolicy::Strategy::
PROFILE_HEURISTIC);
} }
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv,
...@@ -150,12 +146,20 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, ...@@ -150,12 +146,20 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \
cb(BatchConvBiasForward), cb(BatchConvBiasForward),
void gopt::enable_opr_algo_profiling_inplace( void gopt::modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars) { const VarNodeArrayView& dest_vars,
#if MGB_ENABLE_FASTRUN opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = #if !MGB_ENABLE_FASTRUN
{ using S = opr::mixin::Convolution::ExecutionPolicy::Strategy;
#define CONV(t) {opr::t::typeinfo(), &inplace_conv_opr_profile_modifier<opr::t>} if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) {
mgb_throw(MegBrainError, "fastrun is disabled at compile time");
}
#endif
const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>>
modifiers = {
#define CONV(t) \
{opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
std::placeholders::_1, strategy)}
MGB_FOREACH_FASTRUN_OPR(CONV) MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV #undef CONV
}; };
...@@ -171,34 +175,23 @@ void gopt::enable_opr_algo_profiling_inplace( ...@@ -171,34 +175,23 @@ void gopt::enable_opr_algo_profiling_inplace(
for (auto i : dest_vars) { for (auto i : dest_vars) {
dep_iter.add(i); dep_iter.add(i);
} }
#else
mgb_throw(MegBrainError, "fastrun is disabled at compile time");
#endif
} }
void gopt::enable_opr_use_profiling_cache_inplace( void gopt::enable_opr_algo_profiling_inplace(
const VarNodeArrayView& dest_vars) { const VarNodeArrayView& dest_vars) {
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = modify_opr_algo_strategy_inplace(dest_vars,
{ opr::mixin::Convolution::ExecutionPolicy::
#define CONV(t) \ Strategy::PROFILE);
{opr::t::typeinfo(), &inplace_conv_opr_profile_cache_modifier<opr::t>} }
MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
};
auto on_opr = [&](OperatorNodeBase* opr) {
auto iter = modifiers.find(opr->dyn_typeinfo());
if (iter != modifiers.end()) {
iter->second(*opr);
}
};
cg::DepOprIter dep_iter{on_opr}; void gopt::enable_opr_use_profiling_cache_inplace(
for (auto i : dest_vars) { const VarNodeArrayView& dest_vars) {
dep_iter.add(i); modify_opr_algo_strategy_inplace(dest_vars,
} opr::mixin::Convolution::ExecutionPolicy::
Strategy::PROFILE_HEURISTIC);
} }
void gopt::set_opr_algo_workspace_limit_inplace( void gopt::set_opr_algo_workspace_limit_inplace(
const VarNodeArrayView& dest_vars, size_t workspace_limit) { const VarNodeArrayView& dest_vars, size_t workspace_limit) {
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "megbrain/gopt/framework.h" #include "megbrain/gopt/framework.h"
#include "megbrain/graph/cg.h" #include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/convolution.h"
namespace mgb { namespace mgb {
namespace gopt { namespace gopt {
...@@ -302,6 +303,17 @@ namespace gopt { ...@@ -302,6 +303,17 @@ namespace gopt {
const SymbolVarArray& dest_vars, const SymbolVarArray& dest_vars,
const OptimizeForInferenceOptions& opt = {}); const OptimizeForInferenceOptions& opt = {});
/*!
* \brief modify execution strategy for oprs with multiple
* algorithms
*
* This would modify the operators inplace. It can be used for implement
* the fast-run mode.
*/
void modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy);
/*! /*!
* \brief enable PROFILE execution strategy for oprs with multiple * \brief enable PROFILE execution strategy for oprs with multiple
* algorithms * algorithms
...@@ -315,7 +327,7 @@ namespace gopt { ...@@ -315,7 +327,7 @@ namespace gopt {
void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars); void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars);
/*! /*!
* \brief enable opr try profiling cache first, if failed, then try * \brief enable opr try profiling cache first, if failed, fallback to
* heuristic * heuristic
* *
* This would modify the operators inplace. It is usually used to enable * This would modify the operators inplace. It is usually used to enable
...@@ -324,7 +336,8 @@ namespace gopt { ...@@ -324,7 +336,8 @@ namespace gopt {
* You may want to implement TimedFuncInvoker::ForkExecImpl and/or * You may want to implement TimedFuncInvoker::ForkExecImpl and/or
* PersistentCache for better performance in an SDK. * PersistentCache for better performance in an SDK.
*/ */
void enable_opr_use_profiling_cache_inplace(const VarNodeArrayView& dest_vars); void enable_opr_use_profiling_cache_inplace(
const VarNodeArrayView& dest_vars);
/*! /*!
* \brief set workspace_limit for execution strategy for oprs with multiple * \brief set workspace_limit for execution strategy for oprs with multiple
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册