From 6070f1272dcd760cc3964af9844b6365d7fdb9f9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Jun 2021 10:19:28 +0800 Subject: [PATCH] fix(mgb): fix getting static memory alloc info GitOrigin-RevId: dfc69c3b3f95b11d708ada0891526db50e4b382c --- src/core/impl/graph/cg_impl_seq.cpp | 55 +++++++++++-------- src/core/impl/graph/cg_impl_seq.h | 5 +- src/core/include/megbrain/graph/bases.h | 3 +- src/plugin/impl/static_mem_record.cpp | 8 +-- .../megbrain/plugin/static_mem_record.h | 23 ++++++-- 5 files changed, 61 insertions(+), 33 deletions(-) diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index b5fa89e36..e88510f74 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -12,6 +12,7 @@ #include "./cg_impl_seq.h" #include "megbrain/graph/exc_extra_info.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/utils/arith_helper.h" using namespace mgb; using namespace cg; @@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( } exec_ctx.perform(&m_exec_env); +#ifndef __IN_TEE_ENV__ + do_regist(); +#endif } void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { @@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { } #ifndef __IN_TEE_ENV__ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( - const std::string& svg_name) { - check_not_finalized(); + const std::string& svg_name) const { auto& recorder = StaticMemRecorder::Instance(); recorder.active(); - ExecContext exec_ctx{this}; + recorder.set_svg_name(svg_name); +} + +void ComputingGraphImpl::ComputingSequence::do_regist() const { // regist weights - size_t addr_base = recorder.peak_mem_size(); - size_t chunk_id = recorder.set_weight_chunk_id(); - for (auto&& i : *(this->m_opr_seq)) { - auto op = i->output(); - for (auto&& j : op) { - auto& mp = j->mem_plan(); - if (mp.valid()) { - auto& mc = mp.chunk(); - if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { - recorder.regist_memory_chunk( - {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), - addr_base, addr_base + mc.size(), 0, false, - mc.owner_var->name()}); - addr_base += mc.size(); + auto& recorder = StaticMemRecorder::Instance(); + if (recorder.valid()) { + size_t addr_base = recorder.peak_mem_size(); + size_t chunk_id = recorder.set_weight_chunk_id(); + for (auto&& i : *(this->m_opr_seq)) { + auto op = i->output(); + for (auto&& j : op) { + auto& mp = j->mem_plan(); + if (mp.valid()) { + auto& mc = mp.chunk(); + if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { + auto size = mgb::get_aligned_power2( + mc.size(), + j->comp_node().get_mem_addr_alignment()); + + recorder.regist_memory_chunk( + {chunk_id++, size, 0, this->m_opr_seq->size(), + addr_base, addr_base + size, 0, false, + mc.owner_var->name()}); + + addr_base += size; + } } } } + recorder.set_sum_mem_size(addr_base); + recorder.show(); } - recorder.set_sum_mem_size(addr_base); - mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); - mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, - "svg_name must be end with \".svg\"\n"); - recorder.show(svg_name); } #endif AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h index f13e50e94..1de43f6aa 100644 --- a/src/core/impl/graph/cg_impl_seq.h +++ b/src/core/impl/graph/cg_impl_seq.h @@ -174,7 +174,10 @@ public: std::unique_ptr as_recorded_seq(); #ifndef __IN_TEE_ENV__ void get_static_memory_alloc_info( - const std::string& svg_name = "static_mem_record.svg") override; + const std::string& svg_name = + "static_mem_record.svg") const override; + + void do_regist() const; #endif }; diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index 7d13b527b..8a239ba8e 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, return (*(output_vars_pair.first))->get_output_vars(); } #ifndef __IN_TEE_ENV__ - virtual void get_static_memory_alloc_info(const std::string& svg_name) { + virtual void get_static_memory_alloc_info( + const std::string& svg_name) const { mgb_assert(svg_name.length() < 0, "can't call this function directly\n"); } diff --git a/src/plugin/impl/static_mem_record.cpp b/src/plugin/impl/static_mem_record.cpp index 14e86ad9b..53629e288 100644 --- a/src/plugin/impl/static_mem_record.cpp +++ b/src/plugin/impl/static_mem_record.cpp @@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, } } // namespace -void StaticMemRecorder::dump_svg(std::string svg_name) { +void StaticMemRecorder::dump_svg() { float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; float address_scale = 1; @@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { svg_height = svg_height + opr_rect_height * 2; std::ofstream outfile; - outfile.open(svg_name); + outfile.open(m_svg_name); outfile << "" << std::endl; outfile << "" @@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { outfile.close(); } -void StaticMemRecorder::show(std::string svg_name) { +void StaticMemRecorder::show() { for (auto&& i : m_memory_chunk_recorder) { if (i.id >= m_weight_chunk_id) { break; @@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); } } - dump_svg(svg_name); + dump_svg(); } std::vector> StaticMemRecorder::get_chunk_construct( diff --git a/src/plugin/include/megbrain/plugin/static_mem_record.h b/src/plugin/include/megbrain/plugin/static_mem_record.h index 78e583594..f892e96ff 100644 --- a/src/plugin/include/megbrain/plugin/static_mem_record.h +++ b/src/plugin/include/megbrain/plugin/static_mem_record.h @@ -54,25 +54,38 @@ public: void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } - const size_t& peak_mem_size() { return m_peak_mem_size; } + const size_t& peak_mem_size() const { return m_peak_mem_size; } void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } - const size_t& sum_mem_size() { return m_sum_mem_size; } + const size_t& sum_mem_size() const { return m_sum_mem_size; } const size_t& set_weight_chunk_id() { m_weight_chunk_id = m_memory_chunk_recorder.size(); return m_weight_chunk_id; } - const size_t& weight_chunk_id() { return m_weight_chunk_id; } + const size_t& weight_chunk_id() const { return m_weight_chunk_id; } - void dump_svg(std::string svg_name); + void dump_svg(); - void show(std::string svg_name); + void show(); + + void set_svg_name(const std::string& svg_name) { + mgb_assert(svg_name.length() > 4, + "svg_name must be end with \".svg\"\n"); + mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, + "svg_name must be end with \".svg\"\n"); + m_svg_name = svg_name; + } + + const std::string& get_svg_name() const{ + return m_svg_name; + } private: bool m_is_record = false; + std::string m_svg_name; // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are // weights memory chunks size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; -- GitLab