diff --git a/src/core/impl/graph/bases.cpp b/src/core/impl/graph/bases.cpp index 020167c636f85fda9e3b9548ad1b2fb2c844f7a5..ccb18c3fccfb78b2276c2c93bb7ad34f32536cf5 100644 --- a/src/core/impl/graph/bases.cpp +++ b/src/core/impl/graph/bases.cpp @@ -18,11 +18,9 @@ GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph): m_owner_graph{owner_graph} { mgb_assert(owner_graph, "owner graph not given"); - auto id = static_cast(owner_graph)->next_node_id(); - m_id = id; + m_id = owner_graph->next_node_id(); } AsyncExecutable::~AsyncExecutable() noexcept = default; // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index 308dde9c76faa0292a0714f631532c3e6bfc7e59..4641562a14094e2cbb3f019d6e055b6581f18af4 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -267,6 +267,14 @@ void ComputingGraphImpl::cleanup() { m_opr_refkeeper.clear(); } +void* ComputingGraphImpl::alloc_varnode_storage() { + return m_var_node_pool.alloc_raw(); +}; + +void ComputingGraphImpl::free_varnode_storage(void *ptr) { + m_var_node_pool.free_raw(ptr); +}; + OperatorNodeBase* ComputingGraphImpl::insert_opr( std::unique_ptr opr_uniqp) { auto opr = opr_uniqp.get(); diff --git a/src/core/impl/graph/cg_impl.h b/src/core/impl/graph/cg_impl.h index c859cd6b58d1d462dde06ba2288e141c7e589e85..4f7457dd04b68437849a5ee4c369d51e2c434be9 100644 --- a/src/core/impl/graph/cg_impl.h +++ b/src/core/impl/graph/cg_impl.h @@ -142,6 +142,10 @@ public: OperatorNodeBase* insert_opr( std::unique_ptr opr) override; + void* alloc_varnode_storage() override; + + void free_varnode_storage(void *ptr) override; + const VarReceiverInfo& var_receiver_in_current_comp_seq( const VarNode* var) const override; @@ -161,7 +165,7 @@ public: TopoSorter& topo_sorter() { return components().topo_sorter; } - size_t next_node_id() { return (*m_node_id_counter)++; } + size_t next_node_id() override { return (*m_node_id_counter)++; } VarNodeMemManager& var_node_mem_manager() { return components().var_node_mem_manager; diff --git a/src/core/impl/graph/operator_node.cpp b/src/core/impl/graph/operator_node.cpp index 2bf06a1c62df7650c4ceafcb46bbed51d18fac33..85500162fa3ef594e82beaef72fdbb1159d9c8c0 100644 --- a/src/core/impl/graph/operator_node.cpp +++ b/src/core/impl/graph/operator_node.cpp @@ -93,10 +93,8 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner, } OperatorNodeBase::~OperatorNodeBase() noexcept { - auto &&pool = ComputingGraphImpl::cast( - owner_graph())->var_node_pool(); for (auto i: m_output) { - pool.free(i); + owner_graph()->free_varnode(i); } } @@ -264,8 +262,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe &name) { mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(), "add output on opr after it has been inserted into graph"); - auto ptr = ComputingGraphImpl::cast( - owner_graph())->var_node_pool().alloc( + auto ptr = owner_graph()->alloc_varnode( name.valid() ? this->name() + ":" + name.val() : name, this); m_output.push_back(ptr); return ptr; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 4f47d2f654c5f90b44cde58265ca927b2b9599f7..7949b1c66ab1b67dde05f0ff67db15cce0f3a7fa 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -174,6 +174,8 @@ class ComputingGraph : public std::enable_shared_from_this, return m_id; } + virtual size_t next_node_id() = 0; + static std::shared_ptr make(); //! assert that refcnt for ptr is one and destories the ptr @@ -235,6 +237,26 @@ class ComputingGraph : public std::enable_shared_from_this, virtual OperatorNodeBase* insert_opr( std::unique_ptr opr) = 0; + /*! + * \brief used by OperatorNodeBase to allocate its outputs + */ + template + VarNode* alloc_varnode(Args&&... args) { + return new(alloc_varnode_storage()) VarNode(std::forward(args)...); + } + + inline void free_varnode(VarNode* var) { + var->~VarNode(); + free_varnode_storage(var); + } + protected: + /*! + * \brief provided by impl to support alloc_varnode + */ + virtual void* alloc_varnode_storage() = 0; + + virtual void free_varnode_storage(void *ptr) = 0; + public: /*! * \brief get current computing sequence */ diff --git a/src/core/include/megbrain/utils/mempool.h b/src/core/include/megbrain/utils/mempool.h index c7dfce276d3709e16b29f2135544560ba52a63e7..0b8b3b3dec832dd5ed9d10d6f0b495725bcacaad 100644 --- a/src/core/include/megbrain/utils/mempool.h +++ b/src/core/include/megbrain/utils/mempool.h @@ -86,12 +86,17 @@ namespace mgb { }; using UniquePtr = std::unique_ptr; + void* alloc_raw() { + return m_storage.alloc(Const<>::ELEM_SIZE); + } + + void free_raw(void *ptr) { + m_storage.free(ptr); + } + template T* alloc(Args&&... args) { - auto ptr = static_cast( - m_storage.alloc(Const<>::ELEM_SIZE)); - new(ptr) T(std::forward(args)...); - return ptr; + return new(alloc_raw()) T(std::forward(args)...); } template @@ -102,7 +107,7 @@ namespace mgb { void free(T *ptr) { ptr->~T(); - m_storage.free(ptr); + free_raw(ptr); } //! reorder free list for cache friendly in future alloc