Commit 2ec7c167 authored by Megvii Engine Team

feat(mgb/gopt): profiler support opr filter and var node filter

GitOrigin-RevId: 5f8d86687f6316a80cd89601361b308c13a62f59
Parent 50ea5ae8
@@ -19,6 +19,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
@@ -246,6 +247,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
    }
    auto new_opr = serialization::copy_opr_shallow(
            *opr, new_inps, opr->config(), {graph.get()});
    // give the user-supplied filter a chance to veto the candidate; a vetoed
    // operator is reported as timed out, so it can never win the profiling
    if (!m_opr_filter(opr, new_opr))
        return PROFILE_TIME_OUT;
    auto y = new_opr->output(0);
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
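The hook above gives callers a veto over each candidate operator before it is compiled and timed; a rejected candidate is scored as timed out, so the layout-transform search can never select it. A minimal sketch of a conforming filter (illustrative only, not part of the commit) that keeps only candidates preserving the original operator's output dtype:

    ProfilerBase::OprFilter dtype_preserving =
            [](const cg::OperatorNodeBase* opr, cg::OperatorNodeBase* new_opr) {
                // accept the candidate only if its first output keeps the dtype
                // of the original operator's first output
                return opr->output(0)->dtype() == new_opr->output(0)->dtype();
            };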
@@ -338,6 +341,8 @@ float ProfilerImpl::profile_operator(
            !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
        return PROFILE_TIME_OUT;
#endif
    // the same veto hook as in the shallow-copy overload above
    if (!m_opr_filter(opr, y->owner_opr()))
        return PROFILE_TIME_OUT;
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto new_opr = y->owner_opr();
@@ -384,6 +389,9 @@ float ProfilerImpl::profile_var_node(const VarNode* var,
    auto builder = ReformatManager::instance().auto_aligned_reformat_featrue(
            var, base_format, key);
    auto y = builder({aligned_var.node()});
    // let the var-node filter veto this layout transform; the default
    // TensorFormat{} is passed through here
    if (!m_var_node_filter(var, aligned_tensor_shape, y->shape(),
                           TensorFormat{}))
        return PROFILE_TIME_OUT;
    ThinHashSet<OperatorNodeBase*> set;
    DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); });
    iter.add(y->owner_opr());
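The var-node path gets the same treatment: the filter sees the source var, the shape it is reformatted from, the resulting shape, and the tensor format, and may veto the transform. An illustrative VarNodeFilter (not from the commit) that rejects any transform whose aligned target shape holds more than four times the elements of the source shape:

    ProfilerBase::VarNodeFilter pad_limit =
            [](const VarNode*, TensorShape from, TensorShape to, TensorFormat) {
                // reject reformats whose padding overhead dominates
                return to.total_nr_elems() <= 4 * from.total_nr_elems();
            };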
@@ -503,6 +511,40 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
}

/* ================== ProfilerBase =================*/
ProfilerBase::ProfilerBase(float opr_threshold, float var_node_threshold)
        : m_opr_threshold{opr_threshold},
          m_var_node_threshold{var_node_threshold} {
    m_opr_filter = [this](const OperatorNodeBase* opr,
                          OperatorNodeBase* new_opr) {
        // reject a candidate whose estimated computation exceeds that of
        // the original operator by more than the threshold factor
        float comp1 = m_opr_footprint.get_computation(
                const_cast<OperatorNodeBase*>(opr));
        float comp2 = m_opr_footprint.get_computation(new_opr);
        if (comp2 > m_opr_threshold * comp1)
            return false;
        return true;
    };
    m_var_node_filter = [this](const VarNode* var, TensorShape from,
                               TensorShape to, TensorFormat format) {
        TensorFormat default_;
        TensorLayout orig_ly, from_ly, to_ly;
        if (format == default_) {
            orig_ly = {var->shape(), var->dtype()};
            from_ly = {from, var->dtype()};
            to_ly = {to, var->dtype()};
        } else {
            orig_ly = {var->shape(), var->dtype(), format};
            from_ly = {from, var->dtype(), format};
            to_ly = {to, var->dtype(), format};
        }
        // the original var is read and written once, while the reformat
        // reads the source layout and writes the target layout; reject the
        // transform if its memory traffic exceeds the threshold
        float orig_memory = orig_ly.span().dist_byte() * 2.f;
        float reformat_memory =
                from_ly.span().dist_byte() + to_ly.span().dist_byte();
        if (reformat_memory > orig_memory * m_var_node_threshold)
            return false;
        return true;
    };
}
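To make the default var-node filter concrete, a worked example with assumed numbers: for a float32 feature map of shape (1, 32, 56, 56) whose reformat pads the channels to 64, orig_memory is one read plus one write of the original var, 2 × 401408 = 802816 bytes, while reformat_memory is the source read plus the padded write, 401408 + 802816 = 1204224 bytes. Since 1204224 ≤ 2 × 802816 = 1605632, the transform survives the default var_node_threshold of 2.f.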

std::string ProfilerBase::OperatorNodeRecord::to_string() const {
    auto str = ssprintf("\nopr type: %s\nopr name: %s\ninputs:\n",
                        opr->dyn_typeinfo()->name, opr->cname());
@@ -15,6 +15,7 @@
#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
namespace mgb {
namespace gopt {
@@ -218,11 +219,27 @@ public:
        /// A hashmap that maps each var node to the costs of its layout
        /// transforms
        ThinHashMap<VarNode*, VarNodeRecord> var_record;
    };

    using OprFilter = thin_function<bool(const cg::OperatorNodeBase*,
                                         cg::OperatorNodeBase*)>;
    using VarNodeFilter = thin_function<bool(const VarNode*, TensorShape,
                                             TensorShape, TensorFormat)>;

    ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f);
    ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {})
            : m_opr_filter{std::move(opr_filter)},
              m_var_node_filter{std::move(var_node_filter)} {}

    virtual ~ProfilerBase() = default;
    virtual ProfilingResult profile(const Problem& problem) const = 0;

    static std::unique_ptr<ProfilerBase> make_profiler();

protected:
    OprFilter m_opr_filter;
    VarNodeFilter m_var_node_filter;
    float m_opr_threshold;
    float m_var_node_threshold;

private:
    OprFootprint m_opr_footprint;
};
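A hedged usage sketch of the new filter constructor; the subclass wiring is assumed (ProfilerImpl appears in this commit, but whether it forwards constructor arguments is not shown), and make_profiler() as declared above takes no filter arguments:

    // illustrative only: profile convolution-like operators exclusively,
    // leaving the VarNodeFilter at its default of {}
    ProfilerBase::OprFilter only_conv_bias =
            [](const cg::OperatorNodeBase* opr, cg::OperatorNodeBase*) {
                return opr->dyn_typeinfo() == opr::ConvBias::typeinfo();
            };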