From b06b58996068e42db2478d5066b631da11adf43a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 23 Apr 2021 17:20:53 +0800 Subject: [PATCH] feat(mgb): get static graph memory info GitOrigin-RevId: f31745f8df67e6f239aa66f18dd12546081cd3e5 --- .../python/megengine/tools/svg_viewer.html | 154 +++++++++ src/core/impl/graph/cg_impl_seq.cpp | 32 ++ src/core/impl/graph/cg_impl_seq.h | 4 + .../graph/var_node_mem_mgr/seq_mem_opt.cpp | 17 + .../graph/var_node_mem_mgr/static_mem_alloc.h | 1 + .../static_mem_alloc/impl.cpp | 16 + src/core/include/megbrain/graph/bases.h | 5 + src/plugin/impl/static_mem_record.cpp | 319 ++++++++++++++++++ .../megbrain/plugin/static_mem_record.h | 85 +++++ 9 files changed, 633 insertions(+) create mode 100644 imperative/python/megengine/tools/svg_viewer.html create mode 100644 src/plugin/impl/static_mem_record.cpp create mode 100644 src/plugin/include/megbrain/plugin/static_mem_record.h diff --git a/imperative/python/megengine/tools/svg_viewer.html b/imperative/python/megengine/tools/svg_viewer.html new file mode 100644 index 000000000..885091c17 --- /dev/null +++ b/imperative/python/megengine/tools/svg_viewer.html @@ -0,0 +1,154 @@ + + +Visualizer + + + + + + + +

desc

+

info

+

+

+ + + + \ No newline at end of file diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index 2a2587e77..926bf8212 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { return *this; } +void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( + const std::string& svg_name) { + check_not_finalized(); + auto& recorder = StaticMemRecorder::Instance(); + recorder.active(); + ExecContext exec_ctx{this}; + // regist weights + size_t addr_base = recorder.peak_mem_size(); + size_t chunk_id = recorder.set_weight_chunk_id(); + for (auto&& i : *(this->m_opr_seq)) { + auto op = i->output(); + for (auto&& j : op) { + auto& mp = j->mem_plan(); + if (mp.valid()) { + auto& mc = mp.chunk(); + if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { + recorder.regist_memory_chunk( + {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), + addr_base, addr_base + mc.size(), 0, false, + mc.owner_var->name()}); + addr_base += mc.size(); + } + } + } + } + recorder.set_sum_mem_size(addr_base); + mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); + mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, + "svg_name must be end with \".svg\"\n"); + recorder.show(svg_name); +} + AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { do_wait(true); return *this; diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h index a999542b2..8cdc3bee7 100644 --- a/src/core/impl/graph/cg_impl_seq.h +++ b/src/core/impl/graph/cg_impl_seq.h @@ -16,6 +16,7 @@ #include "megbrain/comp_node_env.h" #include "megbrain/plugin/var_sanity_check.h" #include "megbrain/utils/arith_helper.h" +#include "megbrain/plugin/static_mem_record.h" namespace mgb { namespace cg { @@ -169,6 +170,9 @@ public: } std::unique_ptr as_recorded_seq(); + + void 
get_static_memory_alloc_info( + const std::string& svg_name = "static_mem_record.svg") override; }; class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { diff --git a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp index 348df67de..c5f41656a 100644 --- a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp +++ b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp @@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() { ThinHashMap chk2interval; // get all memory chunks + if (StaticMemRecorder::Instance().valid()) { + StaticMemRecorder::Instance().clear_opr_seq(); + } + for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { OperatorNodeBase *opr = m_cur_seq_full->at(idx); + if (StaticMemRecorder::Instance().valid()) { + StaticMemRecorder::Instance().regist_opr_seq( + {idx, 0, opr->name()}); + } + auto &&dep_map = opr->node_prop().dep_map(); if (in_sys_alloc(opr)) { @@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( chk.chunk->mem_alloc_status.set_static_offset( allocator->get_start_addr(&chk)); } + auto& recorder = StaticMemRecorder::Instance(); + if (recorder.valid()) { + for (size_t i = 0; i < chunks.size(); i++) { + recorder.regist_memory_chunk_owner_var_name( + i, chunks.at(i).chunk->owner_var->name()); + } + recorder.regist_peak_mem_size(size); + } } return should_realloc; diff --git a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h index aae006a06..77c0ad4e9 100644 --- a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h +++ b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h @@ -11,6 +11,7 @@ #pragma once +#include "megbrain/plugin/static_mem_record.h" #include "megbrain_build_config.h" #include diff --git a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp index 2834aa353..693080084 100644 
--- a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp +++ b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp @@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { check_result_and_calc_lower_bound(); + if (StaticMemRecorder::Instance().valid()) { + StaticMemRecorder::Instance().clear_memory_chunk(); + for (auto&& i : m_interval) { + size_t overwrite_dest_id = 0; + bool is_overwrite = !i->is_overwrite_root(); + if (is_overwrite) { + overwrite_dest_id = i->overwrite_dest_root()->id; + } + + StaticMemRecorder::Instance().regist_memory_chunk( + {i->id, i->size_orig, i->time_begin, i->time_end, + i->addr_begin, i->addr_end(), overwrite_dest_id, + is_overwrite, ""}); + } + } + return *this; } diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index f8ee265c8..ed8b58882 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable, m_user_data.get_user_data(); return (*(output_vars_pair.first))->get_output_vars(); } + + virtual void get_static_memory_alloc_info(const std::string& svg_name) { + mgb_assert(svg_name.length() < 0, + "can't call this function directly\n"); + } }; diff --git a/src/plugin/impl/static_mem_record.cpp b/src/plugin/impl/static_mem_record.cpp new file mode 100644 index 000000000..cb2285698 --- /dev/null +++ b/src/plugin/impl/static_mem_record.cpp @@ -0,0 +1,319 @@ +/** + * \file src/plugin/impl/static_mem_record.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megbrain/plugin/static_mem_record.h" +#include +#include + +using namespace mgb; +using namespace cg; + +namespace { +#define SVG_WIDTH 20000.0 +#define SVG_HEIGHT 15000.0 +#define OPR_RECT_WIDTH 40.0 +#define OPR_RECT_HEIGHT 20.0 + +const std::string rect = + ""; +const std::string text = "{}"; +const std::string polyline = + ""; +const std::string opr_info = + "mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\""; +const std::string chunk_info = + "mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" " + "mge:size=\"{}\" mge:owner_var_name=\"{}\""; +const std::string animate = + "\n"; + +std::string& replace_by_parameter(std::string& original_str, size_t index) { + return original_str; +} + +template +std::string& replace_by_parameter(std::string& original_str, size_t index, + const std::string& parameter, + const Args&... args) { + index = original_str.find("{}", index); + original_str.replace(index, 2, parameter); + index += parameter.length(); + replace_by_parameter(original_str, index, args...); + return original_str; +} + +std::string set_opr_info(std::string id, std::string size, std::string name, + std::string info = opr_info) { + return replace_by_parameter(info, 0, id, size, name); +} + +std::string set_chunk_info(std::string id, std::string time, std::string addr, + std::string size, std::string owner_var_name, + std::string info = chunk_info) { + return replace_by_parameter(info, 0, id, time, addr, size, owner_var_name); +} + +std::string draw_rect(std::string x, std::string y, std::string widith, + std::string height, std::string color, std::string info, + std::string r = rect) { + return replace_by_parameter(r, 0, x, y, widith, height, color, info); +} + +std::string draw_text(std::string x, std::string y, std::string font_size, + std::string txt, std::string t = text) { + return replace_by_parameter(t, 0, x, y, font_size, txt); +} + +std::string draw_polyline(std::string point_seq, std::string color, + 
std::string width, std::string p = polyline) { + return replace_by_parameter(p, 0, point_seq, color, width); +} +} // namespace + +void StaticMemRecorder::dump_svg(std::string svg_name) { + float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, + opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; + float address_scale = 1; + size_t opr_nr = m_opr_seq_recorder.size(); + if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) { + svg_width = SVG_WIDTH; + opr_rect_width = svg_width / opr_nr; + opr_rect_height = opr_rect_width / 2; + } else { + opr_rect_width = OPR_RECT_WIDTH; + svg_width = opr_nr * opr_rect_width; + } + if (m_sum_mem_size > SVG_HEIGHT) { + svg_height = SVG_HEIGHT; + address_scale = svg_height / m_sum_mem_size; + } else { + svg_height = m_sum_mem_size; + } + + // Rescale + float aspect_ratio = SVG_WIDTH / SVG_HEIGHT; + if (svg_width / svg_height < 1) { + svg_width = svg_height * aspect_ratio; + opr_rect_width = svg_width / opr_nr; + opr_rect_height = opr_rect_width / 2; + } else if (svg_width / svg_height > aspect_ratio) { + svg_height = svg_width / aspect_ratio; + address_scale = svg_height / m_sum_mem_size; + } + + svg_height = svg_height + opr_rect_height * 2; + + std::ofstream outfile; + outfile.open(svg_name); + outfile << "" << std::endl; + outfile << "" + << std::endl; + outfile << "" + << std::endl; + + float base_height = svg_height - opr_rect_height; + std::string peak_mem_polyline = + "0," + + std::to_string(base_height - m_peak_mem_size * address_scale) + + " " + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + + "," + std::to_string(base_height - m_peak_mem_size * address_scale); + std::string sum_mem_polyline = + "0," + + std::to_string(base_height - m_sum_mem_size * address_scale) + " " + + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + "," + + std::to_string(base_height - m_sum_mem_size * address_scale); + std::string memory_polyline = ""; + for (size_t i = 0; i < m_opr_seq_recorder.size(); i++) { + auto&& opr 
= m_opr_seq_recorder.at(i); + memory_polyline += + std::to_string((i + 0.5) * opr_rect_width) + "," + + std::to_string(base_height - opr.size * address_scale) + " "; + + outfile << draw_text(std::to_string(i * opr_rect_width), + std::to_string(svg_height - opr_rect_height * 0.5), + std::to_string(opr_rect_height * 0.5), + "opr" + std::to_string(i)) + << std::endl; + std::string opr_info = + set_opr_info( + std::to_string(opr.id), + std::to_string(opr.size) + "B(" + + std::to_string(opr.size / 1024.0 / 1024.0) + + "MiB)", + opr.name) + + " opacity=\"0\""; + outfile << draw_rect(std::to_string(i * opr_rect_width), + std::to_string(base_height), + std::to_string(opr_rect_width), + std::to_string(opr_rect_height), "white", opr_info) + << std::endl; + } + + for (size_t i = 0; i < m_memory_chunk_recorder.size(); i++) { + auto&& chunk = m_memory_chunk_recorder.at(i); + std::string chunk_info = set_chunk_info( + std::to_string(chunk.id), + "[" + std::to_string(chunk.time_begin) + "," + + std::to_string(chunk.time_end) + ")", + "[" + std::to_string(chunk.addr_begin) + "," + + std::to_string(chunk.addr_end) + ")", + std::to_string(chunk.addr_end - chunk.addr_begin) + "B(" + + std::to_string((chunk.addr_end - chunk.addr_begin) / + 1024.0 / 1024.0) + + "MiB)", + chunk.owner_var_name); + + outfile << draw_rect( + std::to_string(chunk.time_begin * opr_rect_width), + std::to_string(base_height - + chunk.addr_end * address_scale), + std::to_string((chunk.time_end - chunk.time_begin) * + opr_rect_width), + std::to_string((chunk.addr_end - chunk.addr_begin) * + address_scale), + "gray", chunk_info) + << std::endl; + outfile << draw_text(std::to_string(chunk.time_begin * opr_rect_width), + std::to_string(base_height - + chunk.addr_end * address_scale + 9), + std::to_string(9), + "chunk" + std::to_string(chunk.id)) + << std::endl; + } + + outfile << draw_text("0", + std::to_string(base_height - + m_peak_mem_size * address_scale + + opr_rect_height * 0.5), + 
std::to_string(opr_rect_height * 0.5), + "peak_memory_size:" + std::to_string(m_peak_mem_size) + + "B(" + + std::to_string(m_peak_mem_size / 1024.0 / + 1024.0) + + "MiB)") + << std::endl; + outfile << draw_text("0", + std::to_string(base_height - + m_sum_mem_size * address_scale + + opr_rect_height * 0.5), + std::to_string(opr_rect_height * 0.5), + "sum_memory_size:" + std::to_string(m_sum_mem_size) + + "B(" + + std::to_string(m_sum_mem_size / 1024.0 / + 1024.0) + + "MiB)") + << std::endl; + outfile << draw_polyline(memory_polyline, "blue", + std::to_string(opr_rect_height * 0.1)) + << std::endl; + outfile << draw_polyline(peak_mem_polyline, "green", + std::to_string(opr_rect_height * 0.1)) + << std::endl; + outfile << draw_polyline(sum_mem_polyline, "red", + std::to_string(opr_rect_height * 0.1)) + << std::endl; + outfile << "" + << std::endl; + outfile << "" << std::endl; + outfile.close(); +} + +void StaticMemRecorder::show(std::string svg_name) { + for (auto&& i : m_memory_chunk_recorder) { + if (i.id >= m_weight_chunk_id) { + break; + } + size_t begin = i.time_begin, end = i.time_end; + if (i.is_overwrite) { + begin++; + } + for (size_t j = begin; j < end; j++) { + m_opr_seq_recorder.at(j).size += i.size_orig; + } + } + + // log peak memory size, where it is reached and which chunks constitute it. 
+ mgb_log("peak_mem_size = %zu\n", m_peak_mem_size); + size_t max_size = 0; + std::vector opr_ids; + for (auto&& i : m_opr_seq_recorder) { + if (i.size == max_size) { + opr_ids.push_back(i.id); + } else if (i.size > max_size) { + max_size = i.size; + opr_ids.clear(); + opr_ids.push_back(i.id); + } + } + + auto opr2chunk = get_chunk_construct(opr_ids); + mgb_log("oprs reach the peak memory:\n"); + for (auto&& i : opr_ids) { + mgb_log("opr id = %zu\n", i); + } + mgb_log("More details:\n"); + for (size_t i = 0; i < opr2chunk.size(); i++) { + mgb_log("opr id = %zu\n", opr_ids.at(i)); + if (i + 1 < opr2chunk.size() && + opr2chunk.at(i) == opr2chunk.at(i + 1)) { + continue; + } + for (size_t j = 0; j < opr2chunk.at(i).size(); j++) { + auto&& chunk = m_memory_chunk_recorder.at(opr2chunk.at(i).at(j)); + mgb_log("[memory_chunk_id=%zu, size=%zu B, " + "[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n", + chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end, + m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); + } + } + dump_svg(svg_name); +} + +std::vector> StaticMemRecorder::get_chunk_construct( + std::vector opr_ids) { + std::vector> chunk_ids; + chunk_ids.resize(opr_ids.size()); + for (auto&& i : m_memory_chunk_recorder) { + if (i.id >= m_weight_chunk_id) { + break; + } + size_t begin = i.time_begin, end = i.time_end; + if (i.is_overwrite) { + begin = begin + 1; + } + if (opr_ids.front() >= end || opr_ids.back() < begin) { + continue; + } + for (size_t k = 0; k < opr_ids.size(); k++) { + if (opr_ids.at(k) >= end) { + break; + } else if (opr_ids.at(k) >= begin) { + chunk_ids.at(k).push_back(i.id); + } + } + } + return chunk_ids; +} \ No newline at end of file diff --git a/src/plugin/include/megbrain/plugin/static_mem_record.h b/src/plugin/include/megbrain/plugin/static_mem_record.h new file mode 100644 index 000000000..6276227d7 --- /dev/null +++ b/src/plugin/include/megbrain/plugin/static_mem_record.h @@ -0,0 +1,85 @@ +/** + * \file 
src/plugin/include/megbrain/plugin/static_mem_record.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/utils/metahelper.h" + +namespace mgb { +namespace cg { + +class StaticMemRecorder : public NonCopyableObj { +public: + static StaticMemRecorder& Instance() { + static StaticMemRecorder StaticMemRecorder; + return StaticMemRecorder; + } + + struct opr_record { + size_t id, size; + std::string name; + }; + struct memory_chunk_record { + size_t id, size_orig, time_begin, time_end, addr_begin, + addr_end, overwrite_dest_id; + bool is_overwrite; + std::string owner_var_name; + }; + + void active() { m_is_record = true; } + + bool valid() { return m_is_record; } + + void clear_opr_seq() { m_opr_seq_recorder.clear(); } + + void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); } + + void clear_memory_chunk() { m_memory_chunk_recorder.clear(); } + + void regist_memory_chunk(memory_chunk_record mcr) { + m_memory_chunk_recorder.push_back(mcr); + } + + void regist_memory_chunk_owner_var_name(size_t id, std::string name) { + m_memory_chunk_recorder.at(id).owner_var_name = name; + } + + void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } + + const size_t& peak_mem_size() { return m_peak_mem_size; } + + void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } + + const size_t& sum_mem_size() { return m_sum_mem_size; } + + const size_t& set_weight_chunk_id() { + m_weight_chunk_id = m_memory_chunk_recorder.size(); + return m_weight_chunk_id; + } + + const size_t& weight_chunk_id() { return m_weight_chunk_id; } + + void dump_svg(std::string svg_name); + + void show(std::string 
svg_name); + +private: + bool m_is_record = false; + // All chunks from m_memory_chunk_recorder.at(m_weight_chunk_id) onward are + // weight memory chunks + size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; + std::vector m_opr_seq_recorder; + std::vector m_memory_chunk_recorder; + std::vector> get_chunk_construct( + std::vector opr_ids); +}; +} // namespace cg +} // namespace mgb -- GitLab