Commit 50db9b84 authored by Megvii Engine Team, committed by Xu Xinran

fix(gopt): fix paramfuse if the endpoint is const

GitOrigin-RevId: f666f6d70037debbff34551149d04b0bd8c256f4
Parent 35bc0e1f
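Note (added for context): before this change, ParamFusePass drove var replacement from the on_midconst_opr() hook of its ConstVarPropogateWithSizeCheck helper, which ran inside add_opr() and therefore before rewriter.auto_replace_outputs(); a constant expression that was itself a requested endpoint could end up replaced twice. The fix runs ConstVarPropogate as a plain up-front analysis and replaces const endpoint vars only after auto_replace_outputs(). The sketch below is a condensed, illustrative variant of the TestGoptInference.ParamFuseConstEndPoint test added in this commit (host_x/host_y/host_p and the *_opt names are assumed for illustration; error handling omitted); it shows the triggering graph, where z1 = (y + 3) + 4 is fully constant and is requested as an endpoint.

    // Condensed sketch of the triggering scenario (assumes host_x, host_y,
    // host_p are HostTensorND instances produced by HostTensorGenerator<>).
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;                       // isolate the pass under test
    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),    // fusible param
         y = opr::SharedDeviceTensor::make(*graph, *host_y),    // fusible param
         p = opr::Host2DeviceCopy::make(*graph, host_p),        // runtime input
         a = y + 3,          // const subexpression shared by both endpoints
         z0 = a + (p + x),   // endpoint with a non-const input
         z1 = a + 4;         // fully const endpoint -- the case being fixed

    // ParamFusePass should fold z1 into a single SharedDeviceTensor without
    // replacing the same endpoint var twice.
    SymbolVar z0_opt, z1_opt;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::ParamFusePass>()
                    .apply({{z1, z0}})
                    .endpoint_vars(),
            z1_opt, z0_opt);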
@@ -74,7 +74,7 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
             auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
             mgb_assert(ins.second || ins.first->second.first,
                     "opr output already replaced");
             // handle repeated call on the same opr
             ins.first->second.second = out1[i];
             on_var_replaced(out0[i], out1[i], nullptr);
@@ -771,7 +771,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
 /* ================ ConstVarPropogateBase ================ */
 
-ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
+ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
         OperatorNodeBase *opr) {
     using ProfFlag = OperatorNodeBase::NodeProp::Flag;
     auto &&info = m_oprinfo[opr];
@@ -834,7 +834,6 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
 #endif
         info.max_size = max_input_size;
         info.is_const = true;
-        on_midconst_opr(opr, max_input_size);
     }
     return make_ret();
 }
......
@@ -442,50 +442,6 @@ void ParamRedistributePass::apply(OptState &state) const {
 /* ================ ParamFusePass ================ */
 
-class ParamFusePass::ConstVarPropogateWithSizeCheck final:
-        public ConstVarPropogateBase
-{
-    public:
-        //! rewrite a var; reader == nullptr means needed by endpoint
-        using VarRewriter = std::function<
-            void(VarNode *var, OperatorNodeBase *reader)>;
-
-        ConstVarPropogateWithSizeCheck(
-                const ParamFusePass &pf, OptState &opt_state,
-                const VarRewriter &rewriter):
-            ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
-            m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter}
-        {
-        }
-
-    private:
-        const ParamFusePass &m_owner;
-        OptState &m_opt_state;
-        VarRewriter m_rewriter;
-
-        void on_midconst_opr(
-                OperatorNodeBase *opr, size_t max_src_size) override {
-            for (auto var: opr->output()) {
-                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
-                    continue;
-
-                auto osize = var_mem_size(var);
-                if (osize >= max_src_size &&
-                        osize - max_src_size > m_owner.m_param_grow_limit) {
-                    return;
-                }
-
-                // const oprs should be evaluated when output is used by another
-                // non-const opr or output is needed by the user
-                if (m_opt_state.graph().endpoint_contain(var)) {
-                    m_rewriter(var, nullptr);
-                }
-            }
-        }
-};
-
 /*!
  * \brief get name for new param
  */
@@ -565,9 +521,15 @@ const char* ParamFusePass::name() const {
 void ParamFusePass::apply(OptState &state) const {
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
+
+    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
+        cvprop.add_opr(opr);
+    });
+
     ThinHashSet<VarNode*> processed_var;
     VarNamer var_namer;
     // reader: null if used as endvar
     auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
         if (!processed_var.insert(var).second)
@@ -619,9 +581,8 @@ void ParamFusePass::apply(OptState &state) const {
         rewriter.replace_var(var, new_var.node(), log.c_str());
     };
 
-    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
-    auto on_opr = [&](OperatorNodeBase *opr) {
-        auto add_ret = cvprop.add_opr(opr);
+    auto replace_opr = [&](OperatorNodeBase* opr) {
+        auto add_ret = cvprop.opr_rst(opr);
         if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
             for (auto i: opr->input()) {
                 if (cvprop.is_midconst(i)) {
@@ -631,9 +592,33 @@ void ParamFusePass::apply(OptState &state) const {
             }
         }
         rewriter.auto_replace_outputs(opr);
+
+        //! we should deal with midconst var after auto_replace_outputs, as
+        //! on_midconst_opr will replace the endpoint output which may cause
+        //! double replace.
+        if (add_ret.all_const_inp) {
+            for (auto var : opr->output()) {
+                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
+                    continue;
+
+                auto osize = ConstVarPropogate::var_mem_size(var);
+                if (osize >= cvprop.max_size(opr) &&
+                        osize - cvprop.max_size(opr) > m_param_grow_limit) {
+                    return;
+                }
+
+                // const oprs should be evaluated when output is used by another
+                // non-const opr or output is needed by the user
+                if (state.graph().endpoint_contain(var)) {
+                    replace_single_var(var, nullptr);
+                }
+            }
+        }
     };
 
-    state.graph().iter(on_opr);
+    state.graph().iter(replace_opr);
     rewriter.apply_inplace();
 }
......
@@ -490,28 +490,17 @@ namespace gopt {
      * Usually you would want to use ConstVarPropogate, and this base class
      * exists to avoid virtual dtor while allowing polymorphism.
      */
-    class ConstVarPropogateBase {
-        protected:
-            ~ConstVarPropogateBase() = default;
-
-            //! memory usage of a var
-            static size_t var_mem_size(VarNode *var) {
-                return var->dtype().size(var->shape().total_nr_elems());
-            }
-
-            //! called after a const but non-source opr is visited
-            virtual void on_midconst_opr(
-                    OperatorNodeBase *opr, size_t max_src_size) {
-                MGB_MARK_USED_VAR(opr);
-                MGB_MARK_USED_VAR(max_src_size);
-            }
-
+    class ConstVarPropogate{
         public:
-            explicit ConstVarPropogateBase(ConstVarType const_var_type):
+            explicit ConstVarPropogate(ConstVarType const_var_type):
                 m_const_var_type{const_var_type}
             {
             }
 
+            ConstVarPropogate() = default;
+            ~ConstVarPropogate() = default;
+
             //! note that both attrs would be false if opr is impure or it is
             //! not allowed to be replaced
             struct AddOprResult {
@@ -527,12 +516,19 @@ namespace gopt {
             AddOprResult add_opr(OperatorNodeBase *opr);
 
+            const AddOprResult& opr_rst(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).result;
+            }
+
             bool is_const(OperatorNodeBase *opr) const {
                 return m_oprinfo.at(opr).is_const;
            }
             bool is_const(VarNode *var) const {
                 return is_const(var->owner_opr());
             }
+
+            size_t max_size(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).max_size;
+            }
 
             //! whether a var is produced by non-source const opr
             bool is_midconst(OperatorNodeBase *opr) const {
@@ -543,6 +539,11 @@ namespace gopt {
                 return is_midconst(var->owner_opr());
             }
 
+            //! memory usage of a var
+            static size_t var_mem_size(VarNode *var) {
+                return var->dtype().size(var->shape().total_nr_elems());
+            }
+
         private:
             struct OprInfo {
                 bool processed = false, is_const = false;
@@ -556,11 +557,6 @@ namespace gopt {
     };
 
-    class ConstVarPropogate final: public ConstVarPropogateBase {
-        public:
-            using ConstVarPropogateBase::ConstVarPropogateBase;
-    };
-
 } // namespace gopt
 } // namespace mgb
......
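Note (added for context): the header change above merges ConstVarPropogateBase and its trivial ConstVarPropogate subclass into a single non-virtual class, drops the on_midconst_opr() hook, and exposes the collected per-opr data through opr_rst(), max_size() and the now-public static var_mem_size(). A rough usage sketch under these assumptions follows; it mirrors what ParamFusePass::apply now does, and some_opr / some_var are illustrative names only.

    // Run the analysis once over the graph, then query it from a rewrite pass.
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
        cvprop.add_opr(opr);                   // records const-ness and max const input size
    });

    // Later, during rewriting (no second graph walk needed):
    auto &&rst = cvprop.opr_rst(some_opr);     // AddOprResult computed by add_opr()
    bool mid   = cvprop.is_midconst(some_var); // const, but not produced by a source opr
    size_t lim = cvprop.max_size(some_opr);    // largest const input, for the grow-limit check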
@@ -112,6 +112,52 @@ void warp_perspective_mat_gen(HostTensorND& mat, size_t N, size_t INP_H,
 #endif
 } // namespace
 
+TEST(TestGoptInference, ParamFuseConstEndPoint) {
+    constexpr size_t SIZE = 23;
+    HostTensorGenerator<> gen;
+    auto host_x = gen({SIZE}), host_y = gen({1}), host_p = gen({1});
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
+         y = opr::SharedDeviceTensor::make(*graph, *host_y),
+         p = opr::Host2DeviceCopy::make(*graph, host_p),
+         q = p + x,
+         a = y + 3,
+         z0 = a + q,
+         z1 = a + 4;
+
+    HostTensorND host_z0, host_z1;
+    SymbolVar z0_1, z1_1;
+    unpack_vector(
+            gopt::GraphOptimizer{}.
+            add_pass<gopt::ParamFusePass>().
+            apply({{z1, z0}}).endpoint_vars(),
+            z1_1, z0_1);
+
+    auto func = graph->compile({make_callback_copy(z0_1, host_z0),
+                                make_callback_copy(z1_1, host_z1)});
+    func->to_json()->writeto_fpath(
+            output_file("TestGoptInference.ParamFuseEndPoint.json"));
+    func->execute();
+
+    int nr_opr = 0;
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    ASSERT_EQ(8, nr_opr);
+
+    auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();
+    auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
+         pz1 = host_z1.ptr<float>()[0];
+    for (size_t i = 0; i < SIZE; ++ i) {
+        MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
+    }
+    MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
+}
+
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
@@ -144,7 +190,7 @@ TEST(TestGoptInference, ParamFuse) {
     func->execute();
 
     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*op) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
     ASSERT_EQ(6, nr_opr);
 
     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),
......