提交 d782edf8 编写于 作者: M Megvii Engine Team

refactor(mgb): decouple node insertion from ComputingGraphImpl

GitOrigin-RevId: 59b45fcb17be8ca9b94c1cb4eabf7ed949967f08
上级 d42cf4cd
......@@ -18,11 +18,9 @@ GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
m_owner_graph{owner_graph}
{
mgb_assert(owner_graph, "owner graph not given");
auto id = static_cast<ComputingGraphImpl*>(owner_graph)->next_node_id();
m_id = id;
m_id = owner_graph->next_node_id();
}
AsyncExecutable::~AsyncExecutable() noexcept = default;
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -267,6 +267,14 @@ void ComputingGraphImpl::cleanup() {
m_opr_refkeeper.clear();
}
void* ComputingGraphImpl::alloc_varnode_storage() {
return m_var_node_pool.alloc_raw();
};
void ComputingGraphImpl::free_varnode_storage(void *ptr) {
m_var_node_pool.free_raw(ptr);
};
OperatorNodeBase* ComputingGraphImpl::insert_opr(
std::unique_ptr<OperatorNodeBase> opr_uniqp) {
auto opr = opr_uniqp.get();
......
......@@ -142,6 +142,10 @@ public:
OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) override;
void* alloc_varnode_storage() override;
void free_varnode_storage(void *ptr) override;
const VarReceiverInfo& var_receiver_in_current_comp_seq(
const VarNode* var) const override;
......@@ -161,7 +165,7 @@ public:
TopoSorter& topo_sorter() { return components().topo_sorter; }
size_t next_node_id() { return (*m_node_id_counter)++; }
size_t next_node_id() override { return (*m_node_id_counter)++; }
VarNodeMemManager& var_node_mem_manager() {
return components().var_node_mem_manager;
......
......@@ -93,10 +93,8 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner,
}
OperatorNodeBase::~OperatorNodeBase() noexcept {
auto &&pool = ComputingGraphImpl::cast(
owner_graph())->var_node_pool();
for (auto i: m_output) {
pool.free(i);
owner_graph()->free_varnode(i);
}
}
......@@ -264,8 +262,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe<std::string> &name) {
mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(),
"add output on opr after it has been inserted into graph");
auto ptr = ComputingGraphImpl::cast(
owner_graph())->var_node_pool().alloc(
auto ptr = owner_graph()->alloc_varnode(
name.valid() ? this->name() + ":" + name.val() : name, this);
m_output.push_back(ptr);
return ptr;
......
......@@ -174,6 +174,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
return m_id;
}
virtual size_t next_node_id() = 0;
static std::shared_ptr<ComputingGraph> make();
//! assert that refcnt for ptr is one and destories the ptr
......@@ -235,6 +237,26 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
virtual OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) = 0;
/*!
* \brief used by OperatorNodeBase to allocate its outputs
*/
template<typename... Args>
VarNode* alloc_varnode(Args&&... args) {
return new(alloc_varnode_storage()) VarNode(std::forward<Args>(args)...);
}
inline void free_varnode(VarNode* var) {
var->~VarNode();
free_varnode_storage(var);
}
protected:
/*!
* \brief provided by impl to support alloc_varnode
*/
virtual void* alloc_varnode_storage() = 0;
virtual void free_varnode_storage(void *ptr) = 0;
public:
/*!
* \brief get current computing sequence
*/
......
......@@ -86,12 +86,17 @@ namespace mgb {
};
using UniquePtr = std::unique_ptr<T, Deleter>;
void* alloc_raw() {
return m_storage.alloc(Const<>::ELEM_SIZE);
}
void free_raw(void *ptr) {
m_storage.free(ptr);
}
template<typename...Args>
T* alloc(Args&&... args) {
auto ptr = static_cast<T*>(
m_storage.alloc(Const<>::ELEM_SIZE));
new(ptr) T(std::forward<Args>(args)...);
return ptr;
return new(alloc_raw()) T(std::forward<Args>(args)...);
}
template<typename...Args>
......@@ -102,7 +107,7 @@ namespace mgb {
void free(T *ptr) {
ptr->~T();
m_storage.free(ptr);
free_raw(ptr);
}
//! reorder free list for cache friendly in future alloc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册