Commit f3378100 authored by Megvii Engine Team

feat(mgb): enable output dynamic memory alloc

GitOrigin-RevId: c809629034f87dfeacd6586ca96aa9c110e0a3c9
Parent e82fa4ec
@@ -14,6 +14,8 @@
 using namespace mgb::cg;

+MGB_TYPEINFO_OBJ_IMPL(OutputVarsUserData);
+
 GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
     m_owner_graph{owner_graph}
 {
@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                        CallbackCallerKey::Hash>
             opr2vars;
+    using F = VarNode::Flag;
+    if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
+        for (auto&& i : dest_vars) {
+            if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
+                                 F::NO_SYS_STATIC_MEM_ALLOC)) {
+                mgb_assert(
+                        !i->contain_flag(
+                                F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
+                        "Can not force graph output dynamic alloc with "
+                        "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
+                        i->cname());
+                i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
+            }
+            i->add_flag(F::NO_MEM_RECLAIM);
+        }
+    }
     for (size_t i = 0; i < out_spec.size(); ++i) {
         auto&& cb = out_spec[i].second;
         if (cb) {
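The block added above decides, for each requested output var, whether its allocation can be deferred to runtime: vars still eligible for static system allocation get `NO_SYS_STATIC_MEM_ALLOC` (after asserting that runtime-forced dynamic allocation is not disallowed), and every output additionally gets `NO_MEM_RECLAIM` so its buffer survives execution for the caller to read. A self-contained sketch modeling that decision — plain C++, not MegBrain code; the flag names mirror `VarNode::Flag` but the numeric values are made up:

```cpp
#include <cassert>
#include <cstdio>

// Bitmask model of the relevant VarNode::Flag values (names mirror the
// real enum; the numeric values here are arbitrary).
enum Flag : unsigned {
    NO_SYS_MEM_ALLOC = 1u << 0,         // var manages its storage itself
    NO_SYS_STATIC_MEM_ALLOC = 1u << 1,  // storage allocated at runtime
    NO_MEM_RECLAIM = 1u << 2,           // storage not reclaimed after exec
    DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC = 1u << 3,
};

// Mirrors the per-output update in compile_prepare(): outputs still
// eligible for static system allocation are switched to runtime
// allocation, and all outputs keep their buffers after execute().
unsigned force_output_dynamic_alloc(unsigned flags) {
    if (!(flags & (NO_SYS_MEM_ALLOC | NO_SYS_STATIC_MEM_ALLOC))) {
        assert(!(flags & DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC));
        flags |= NO_SYS_STATIC_MEM_ALLOC;
    }
    return flags | NO_MEM_RECLAIM;
}

int main() {
    unsigned f = force_output_dynamic_alloc(0);
    std::printf("%d %d\n", !!(f & NO_SYS_STATIC_MEM_ALLOC),
                !!(f & NO_MEM_RECLAIM));  // prints: 1 1
}
```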
@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     init_opr_seq();
 #endif  // MGB_ENABLE_SUBLINEAR

-    return {std::move(extra_info), opr_seq};
+    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
 }

 std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
         CompileState state) {
     auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
     comp_seq->extra_info = std::move(state.extra_info);
+    comp_seq->set_output_vars(state.dest_vars);
     auto opr_seq = state.opr_seq;
     auto&& cmpnt = components();
@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
         //! extra info that must be set in the ComputingSequence
         CompSeqExtraInfo extra_info;
         const OprNodeArray* opr_seq = nullptr;
+        VarNodeArray dest_vars;
     };

     struct CallbackCallerKey {
@@ -67,9 +67,10 @@ namespace static_infer {
 };

 using GraphError = mgb::GraphError;

 class VarNode;
 class OperatorNodeBase;
 class ComputingGraph;
+using VarNodeArray = mgb::SmallVector<VarNode*>;

 /*!
  * \brief Base class for a node in the graph.
  *
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
     }
 };

+class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
+    MGB_TYPEINFO_OBJ_DECL;
+
+private:
+    VarNodeArray m_output_vars;
+
+public:
+    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }
+
+    const VarNodeArray& get_output_vars() const { return m_output_vars; }
+};
+
 /*!
  * \brief an object that executes asynchronously
  */
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
     UserDataContainer& user_data() {
         return m_user_data;
     }

+    void set_output_vars(const VarNodeArray& vars) {
+        std::shared_ptr<OutputVarsUserData> ud =
+                std::make_shared<OutputVarsUserData>();
+        ud->set_output_vars(vars);
+        m_user_data.add_user_data(ud);
+    }
+
+    const VarNodeArray& get_output_vars() const {
+        auto output_vars_pair =
+                m_user_data.get_user_data<OutputVarsUserData>();
+        return (*(output_vars_pair.first))->get_output_vars();
+    }
 };
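For context on the accessors above: `UserDataContainer::get_user_data<T>()` returns a pair of (array of records, record count), which is why `get_output_vars()` dereferences `output_vars_pair.first` once more to reach the single record stored by `set_output_vars()`. A standalone miniature of that storage pattern — `MiniUserDataContainer` is hypothetical, built on C++ RTTI rather than mgb's `Typeinfo` registry:

```cpp
#include <cassert>
#include <memory>
#include <typeindex>
#include <unordered_map>
#include <vector>

struct UserData {
    virtual ~UserData() = default;
};

class MiniUserDataContainer {
    std::unordered_map<std::type_index,
                       std::vector<std::shared_ptr<UserData>>> m_store;
    mutable std::vector<UserData*> m_view;  // backs the returned array

public:
    template <typename T>
    void add_user_data(std::shared_ptr<T> ud) {
        m_store[typeid(T)].push_back(std::move(ud));
    }

    // returns (pointer to first record, number of records of type T)
    template <typename T>
    std::pair<UserData* const*, size_t> get_user_data() const {
        m_view.clear();
        auto it = m_store.find(typeid(T));
        if (it != m_store.end())
            for (auto&& p : it->second)
                m_view.push_back(p.get());
        return {m_view.data(), m_view.size()};
    }
};

struct OutputVars final : UserData {
    int payload = 42;
};

int main() {
    MiniUserDataContainer c;
    c.add_user_data(std::make_shared<OutputVars>());
    auto pair = c.get_user_data<OutputVars>();
    assert(pair.second == 1);
    // same double dereference as get_output_vars() above
    assert(static_cast<OutputVars*>(*pair.first)->payload == 42);
}
```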
@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! force dynamic memory alloc for all vars
             bool force_dynamic_alloc = false;

+            /*!
+             * force dynamic memory alloc for output vars, which are used
+             * as CallbackCaller inputs when compile() is called
+             */
+            bool force_output_dynamic_alloc = false;
+
             //! whether to perform var sanity check on first run
             bool var_sanity_check_first_run = true;
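Combined with the `AsyncExecutable` accessors above, this option lets a caller compile with null callbacks and read each output directly from its var after execution. A hedged usage sketch distilled from the new `TestGraph.DynamicOutput` test further down — the header paths and the CPU comp node are assumptions, not part of this commit:

```cpp
#include "megbrain/graph.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"

using namespace mgb;

void read_outputs_without_callbacks() {
    auto graph = cg::ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;

    // host input filled with 0..15
    auto host_x = std::make_shared<HostTensorND>(
            CompNode::load("cpu0"), TensorShape{16}, dtype::Float32());
    for (size_t i = 0; i < 16; ++i)
        host_x->ptr<float>()[i] = float(i);

    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::add(x, x);

    // nullptr callback: no copy-back is inserted; the output var keeps
    // its dynamically allocated dev_tensor alive (NO_MEM_RECLAIM)
    auto func = graph->compile({{y, nullptr}});
    func->execute().wait();

    HostTensorND result;
    result.copy_from(func->get_output_vars()[0]->dev_tensor()).sync();
}
```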
@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
             std::make_unique<Node>(std::forward<Args>(args)...))->output(0);
 }

 }  // namespace cg
 }  // namespace mgb
@@ -34,7 +34,7 @@ namespace static_infer {
 class StaticInferManagerImpl;
 }

 class VarNode;
 class VarDevMemDefragmenter;
 class EagerEvalManager;
@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
     return static_cast<bool>(m_flag & flag);
 }

-using VarNodeArray = mgb::SmallVector<VarNode*>;
 using VarNodeSet = ThinHashSet<VarNode*>;

 DType MemAllocPlan::dtype() const {
@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
     }
 }

+TEST(TestGraph, DynamicOutput) {
+    using namespace opr;
+    REQUIRE_GPU(1);
+    auto cn0 = CompNode::load("gpu0");
+
+    constexpr size_t C1 = 20, C2 = 20;
+    constexpr size_t C = C1 + C2;
+    HostTensorGenerator<> gen;
+    auto host_opr0 = gen({C}, cn0);
+
+    auto graph = ComputingGraph::make();
+    graph->options().force_output_dynamic_alloc = true;
+
+    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0);
+    auto spl_0 = opr::Split::make(
+            opr0, Split::Options::make_partition(opr0, 0, {C1, C2}));
+    auto sum = opr::add(spl_0[1], spl_0[1]);
+
+    HostTensorND expect_sum, expect_spl_0_0, result_sum, result_spl_0_0;
+    auto func1 = graph->compile({make_callback_copy(sum, expect_sum),
+                                 make_callback_copy(spl_0[0], expect_spl_0_0)});
+    func1->execute().wait();
+
+    auto func2 = graph->compile({{sum, nullptr}, {spl_0[0], nullptr}});
+    auto&& dest_vars = func2->get_output_vars();
+    func2->execute().wait();
+
+    result_sum.copy_from(dest_vars[0]->dev_tensor()).sync();
+    MGB_ASSERT_TENSOR_NEAR(expect_sum, result_sum, 1e-4);
+    result_spl_0_0.copy_from(dest_vars[1]->dev_tensor()).sync();
+    MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}