diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index b5fa89e363997e00682f779b0c25841c4b2456cb..e88510f74e1b95f396f7fb264f63f640f8db93dc 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -12,6 +12,7 @@ #include "./cg_impl_seq.h" #include "megbrain/graph/exc_extra_info.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/utils/arith_helper.h" using namespace mgb; using namespace cg; @@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( } exec_ctx.perform(&m_exec_env); +#ifndef __IN_TEE_ENV__ + do_regist(); +#endif } void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { @@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { } #ifndef __IN_TEE_ENV__ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( - const std::string& svg_name) { - check_not_finalized(); + const std::string& svg_name) const { auto& recorder = StaticMemRecorder::Instance(); recorder.active(); - ExecContext exec_ctx{this}; + recorder.set_svg_name(svg_name); +} + +void ComputingGraphImpl::ComputingSequence::do_regist() const { // regist weights - size_t addr_base = recorder.peak_mem_size(); - size_t chunk_id = recorder.set_weight_chunk_id(); - for (auto&& i : *(this->m_opr_seq)) { - auto op = i->output(); - for (auto&& j : op) { - auto& mp = j->mem_plan(); - if (mp.valid()) { - auto& mc = mp.chunk(); - if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { - recorder.regist_memory_chunk( - {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), - addr_base, addr_base + mc.size(), 0, false, - mc.owner_var->name()}); - addr_base += mc.size(); + auto& recorder = StaticMemRecorder::Instance(); + if (recorder.valid()) { + size_t addr_base = recorder.peak_mem_size(); + size_t chunk_id = recorder.set_weight_chunk_id(); + for (auto&& i : *(this->m_opr_seq)) { + auto op = i->output(); + for (auto&& j : op) { + auto& mp = j->mem_plan(); + if (mp.valid()) { + auto& mc = mp.chunk(); + if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { + auto size = mgb::get_aligned_power2( + mc.size(), + j->comp_node().get_mem_addr_alignment()); + + recorder.regist_memory_chunk( + {chunk_id++, size, 0, this->m_opr_seq->size(), + addr_base, addr_base + size, 0, false, + mc.owner_var->name()}); + + addr_base += size; + } } } } + recorder.set_sum_mem_size(addr_base); + recorder.show(); } - recorder.set_sum_mem_size(addr_base); - mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); - mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, - "svg_name must be end with \".svg\"\n"); - recorder.show(svg_name); } #endif AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h index f13e50e94be3d0f71cab98118d9287c4c494b608..1de43f6aab0445032f027c811fd56e2e5768805d 100644 --- a/src/core/impl/graph/cg_impl_seq.h +++ b/src/core/impl/graph/cg_impl_seq.h @@ -174,7 +174,10 @@ public: std::unique_ptr as_recorded_seq(); #ifndef __IN_TEE_ENV__ void get_static_memory_alloc_info( - const std::string& svg_name = "static_mem_record.svg") override; + const std::string& svg_name = + "static_mem_record.svg") const override; + + void do_regist() const; #endif }; diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index 7d13b527bc5bfb5d4ba02742a6a58638693b1b1b..8a239ba8eff750c58e1a9417f626ed3716d7ff20 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, return (*(output_vars_pair.first))->get_output_vars(); } #ifndef __IN_TEE_ENV__ - virtual void get_static_memory_alloc_info(const std::string& svg_name) { + virtual void get_static_memory_alloc_info( + const std::string& svg_name) const { mgb_assert(svg_name.length() < 0, "can't call this function directly\n"); } diff --git a/src/plugin/impl/static_mem_record.cpp b/src/plugin/impl/static_mem_record.cpp index 14e86ad9b4989790a05f994452fa916a05d96018..53629e2884445f1f7cb20dccf1cbefee0e4b6b08 100644 --- a/src/plugin/impl/static_mem_record.cpp +++ b/src/plugin/impl/static_mem_record.cpp @@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, } } // namespace -void StaticMemRecorder::dump_svg(std::string svg_name) { +void StaticMemRecorder::dump_svg() { float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; float address_scale = 1; @@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { svg_height = svg_height + opr_rect_height * 2; std::ofstream outfile; - outfile.open(svg_name); + outfile.open(m_svg_name); outfile << "" << std::endl; outfile << "" @@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { outfile.close(); } -void StaticMemRecorder::show(std::string svg_name) { +void StaticMemRecorder::show() { for (auto&& i : m_memory_chunk_recorder) { if (i.id >= m_weight_chunk_id) { break; @@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); } } - dump_svg(svg_name); + dump_svg(); } std::vector> StaticMemRecorder::get_chunk_construct( diff --git a/src/plugin/include/megbrain/plugin/static_mem_record.h b/src/plugin/include/megbrain/plugin/static_mem_record.h index 78e5835945a86c07249d44b4c94eb78eb4958d78..f892e96ff96141e87e14b78a858c749adc74a178 100644 --- a/src/plugin/include/megbrain/plugin/static_mem_record.h +++ b/src/plugin/include/megbrain/plugin/static_mem_record.h @@ -54,25 +54,38 @@ public: void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } - const size_t& peak_mem_size() { return m_peak_mem_size; } + const size_t& peak_mem_size() const { return m_peak_mem_size; } void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } - const size_t& sum_mem_size() { return m_sum_mem_size; } + const size_t& sum_mem_size() const { return m_sum_mem_size; } const size_t& set_weight_chunk_id() { m_weight_chunk_id = m_memory_chunk_recorder.size(); return m_weight_chunk_id; } - const size_t& weight_chunk_id() { return m_weight_chunk_id; } + const size_t& weight_chunk_id() const { return m_weight_chunk_id; } - void dump_svg(std::string svg_name); + void dump_svg(); - void show(std::string svg_name); + void show(); + + void set_svg_name(const std::string& svg_name) { + mgb_assert(svg_name.length() > 4, + "svg_name must be end with \".svg\"\n"); + mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, + "svg_name must be end with \".svg\"\n"); + m_svg_name = svg_name; + } + + const std::string& get_svg_name() const{ + return m_svg_name; + } private: bool m_is_record = false; + std::string m_svg_name; // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are // weights memory chunks size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id;