diff --git a/src/gopt/impl/profiler_impl.cpp b/src/gopt/impl/profiler_impl.cpp index 38f7e6ca90946def2fefcd39fa8086cd51ee87f0..a760245c562cb0047455676b90c0fb7739fb5523 100644 --- a/src/gopt/impl/profiler_impl.cpp +++ b/src/gopt/impl/profiler_impl.cpp @@ -19,6 +19,7 @@ #include "megbrain/opr/imgproc.h" #include "megbrain/opr/nn_int.h" #include "megbrain/opr/io.h" +#include "megbrain/opr/nn_int.h" #include "megbrain/plugin/base.h" #include "megbrain/serialization/sereg.h" @@ -246,6 +247,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, } auto new_opr = serialization::copy_opr_shallow( *opr, new_inps, opr->config(), {graph.get()}); + if (!m_opr_filter(opr, new_opr)) + return PROFILE_TIME_OUT; auto y = new_opr->output(0); auto mark = MarkInputContiguous::make(SymbolVar(y)); auto func = graph->compile({{mark, {}}}); @@ -338,6 +341,8 @@ float ProfilerImpl::profile_operator( !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) return PROFILE_TIME_OUT; #endif + if (!m_opr_filter(opr, y->owner_opr())) + return PROFILE_TIME_OUT; auto mark = MarkInputContiguous::make(SymbolVar(y)); auto func = graph->compile({{mark, {}}}); auto new_opr = y->owner_opr(); @@ -384,6 +389,9 @@ float ProfilerImpl::profile_var_node(const VarNode* var, auto builder = ReformatManager::instance().auto_aligned_reformat_featrue( var, base_format, key); auto y = builder({aligned_var.node()}); + if (!m_var_node_filter(var, aligned_tensor_shape, y->shape(), + TensorFormat{})) + return PROFILE_TIME_OUT; ThinHashSet set; DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); }); iter.add(y->owner_opr()); @@ -503,6 +511,40 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( } /* ================== ProfilerBase =================*/ +ProfilerBase::ProfilerBase(float opr_threshold, float var_node_threshold) + : m_opr_threshold{opr_threshold}, + m_var_node_threshold{var_node_threshold} { + m_opr_filter = [this](const OperatorNodeBase* opr, + OperatorNodeBase* new_opr) { + float comp1 = m_opr_footprint.get_computation( + const_cast(opr)); + float comp2 = m_opr_footprint.get_computation(new_opr); + if (comp2 > m_opr_threshold * comp1) + return false; + return true; + }; + m_var_node_filter = [this](const VarNode* var, TensorShape from, + TensorShape to, TensorFormat format) { + TensorFormat default_; + TensorLayout orig_ly, from_ly, to_ly; + if (format == default_) { + orig_ly = {var->shape(), var->dtype()}; + from_ly = {from, var->dtype()}; + to_ly = {to, var->dtype()}; + } else { + orig_ly = {var->shape(), var->dtype(), format}; + from_ly = {from, var->dtype(), format}; + to_ly = {to, var->dtype(), format}; + } + float orig_memory = orig_ly.span().dist_byte() * 2.f; + float reformat_memory = + from_ly.span().dist_byte() + to_ly.span().dist_byte(); + if (reformat_memory > orig_memory * m_var_node_threshold) + return false; + return true; + }; +} + std::string ProfilerBase::OperatorNodeRecord::to_string() const { auto str = ssprintf("\nopr type: %s\nopr name: %s\ninputs:\n", opr->dyn_typeinfo()->name, opr->cname()); diff --git a/src/gopt/include/megbrain/gopt/global_layout_transform.h b/src/gopt/include/megbrain/gopt/global_layout_transform.h index 9a851869a42ed3483ed55adce0a3720a2134a280..5d72971756a91a974f69ee27ab0489c22317a4ea 100644 --- a/src/gopt/include/megbrain/gopt/global_layout_transform.h +++ b/src/gopt/include/megbrain/gopt/global_layout_transform.h @@ -15,6 +15,7 @@ #include "megbrain/gopt/reformat_manager.h" #include "megbrain/gopt/subgraph_extractor.h" #include "megbrain/opr/dnn/convolution.h" +#include "megbrain/plugin/opr_footprint.h" namespace mgb { namespace gopt { @@ -218,11 +219,27 @@ public: /// A hashmap, that maps the var node to the costs of layout transform ThinHashMap var_record; }; + using OprFilter = thin_function; + using VarNodeFilter = thin_function; - ProfilerBase() = default; + ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); + ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) + : m_opr_filter{std::move(opr_filter)}, + m_var_node_filter{std::move(var_node_filter)} {} virtual ~ProfilerBase() = default; virtual ProfilingResult profile(const Problem& problem) const = 0; static std::unique_ptr make_profiler(); + +protected: + OprFilter m_opr_filter; + VarNodeFilter m_var_node_filter; + float m_opr_threshold; + float m_var_node_threshold; + +private: + OprFootprint m_opr_footprint; }; /*!