提交 b06b5899 编写于 作者: M Megvii Engine Team

feat(mgb): get static graph memory info

GitOrigin-RevId: f31745f8df67e6f239aa66f18dd12546081cd3e5
上级 0cf4ff70
<html>
<title>Visualizer</title>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
</head>
<script>
// Visualizer bootstrap. Runs once the page has loaded, so #board, #desc,
// #info and #fileInput exist. It wires up three interactions:
//   * choosing an SVG file   -> the file is injected into #board
//   * mouse-down on a shape  -> its "mge:*" attributes are listed in #desc
//   * ctrl + mouse wheel     -> the board is zoomed around the cursor
window.onload = () => {
    const board = document.getElementById('board');
    const fileInput = document.getElementById('fileInput');
    const desc = document.getElementById('desc');
    // Looked up explicitly instead of relying on the implicit window.info
    // global that browsers create for elements with an id.
    const info = document.getElementById('info');
    // NOTE(review): this page has no element with id "hRange" or "vRange",
    // so both are null here; they are only read by autoScaleBoard(), which
    // nothing in this script calls.
    const hRange = document.getElementById('hRange');
    const vRange = document.getElementById('vRange');

    let lastColor = null;      // fill of the selected element before selection
    let lastElem = undefined;  // currently selected element, if any
    let scale = 1;             // accumulated ctrl+wheel zoom factor
    let svg = undefined;       // the <svg> element of the loaded file
    let svgWidth = 0;          // width attribute of the loaded <svg>
    let svgHeight = 0;         // height attribute of the loaded <svg>

    // Replace the content of `target` with one <p>name: value</p> per
    // attribute of `elem` whose name starts with `prefix`. textContent is
    // used so that attribute values are shown literally and never parsed
    // as markup.
    const showPrefixedAttributes = (elem, prefix, target) => {
        target.textContent = '';
        for (const attrName of elem.getAttributeNames()) {
            if (!attrName.startsWith(prefix)) {
                continue;
            }
            const line = document.createElement('p');
            line.textContent =
                attrName.slice(prefix.length) + ': ' + elem.getAttribute(attrName);
            target.appendChild(line);
        }
    };

    // Show the "mge:*" attributes of an annotated element in #desc.
    // Elements without "mge:type" leave #desc untouched.
    const loadDesc = svgElem => {
        if (!svgElem.hasAttribute('mge:type')) {
            return;
        }
        showPrefixedAttributes(svgElem, 'mge:', desc);
    };

    // Highlight an element in green, remembering its previous fill
    // (null when it had no fill attribute).
    const selectElem = svgElem => {
        loadDesc(svgElem);
        lastColor = svgElem.getAttribute('fill');
        lastElem = svgElem;
        svgElem.setAttribute('fill', 'green');
    };

    // Restore the fill of the previously selected element, if any.
    const unselectLast = () => {
        if (lastElem) {
            if (lastColor === null) {
                lastElem.removeAttribute('fill');
            } else {
                lastElem.setAttribute('fill', lastColor);
            }
        }
        lastElem = undefined;
        lastColor = null;
    };

    // Install the selection handler on svgElem and all of its descendants.
    function recLoadSVG(svgElem) {
        if (svgElem.children === undefined) {
            return;
        }
        svgElem.onmousedown = e => {
            if (!svgElem.hasAttribute('mge:type')) {
                return;
            }
            unselectLast();
            selectElem(svgElem);
            e.stopPropagation();
        };
        for (const child of svgElem.children) {
            recLoadSVG(child);
        }
    }

    // Read the chosen file as text and inject it into #board.
    function loadSVG() {
        const file = fileInput.files[0];
        if (!file) {
            // The file dialog was cancelled; keep the current board.
            return;
        }
        const reader = new FileReader();
        reader.onload = e => {
            board.innerHTML = '<p style="margin: 0;">' + e.target.result + '</p>';
            // The previous selection and zoom belonged to the replaced SVG.
            lastElem = undefined;
            lastColor = null;
            scale = 1;
            svg = board.querySelector('svg') || undefined;
            if (!svg) {
                return;
            }
            svgWidth = Number(svg.getAttribute('width'));
            svgHeight = Number(svg.getAttribute('height'));
            for (const child of board.children) {
                recLoadSVG(child);
            }
            // The legend lives on whichever element carries "svg:info";
            // show all of that element's "svg:*" attributes in #info.
            for (const elem of board.querySelectorAll('*')) {
                if (elem.hasAttribute('svg:info')) {
                    showPrefixedAttributes(elem, 'svg:', info);
                    break;
                }
            }
        };
        reader.readAsText(file, 'UTF-8');
    }

    // Scale the loaded SVG by (x, y) and resize #board to the scaled size.
    function scaleBoard(x, y) {
        if (!svg) {
            // Nothing loaded yet (e.g. ctrl+wheel before choosing a file).
            return;
        }
        svg.setAttribute('transform', 'scale(' + x + ',' + y + ')');
        board.style.width = (svgWidth * x) + 'px';
        board.style.height = (svgHeight * y) + 'px';
    }

    // Scale from the hRange / vRange inputs (see the NOTE above: those
    // inputs are absent from this page and this function is never called).
    function autoScaleBoard() {
        const hRangeValue = Math.sqrt(Number(hRange.value) / 10);
        const vRangeValue = Math.sqrt(Number(vRange.value) / 10);
        scaleBoard(Number(hRangeValue), Number(vRangeValue));
    }

    fileInput.onchange = loadSVG;

    // Multiply the current zoom by dScale and apply it uniformly.
    const zoomBoard = dScale => {
        scale *= dScale;
        scaleBoard(scale, scale);
    };

    // Registered as non-passive so preventDefault() can be called on the
    // ctrl+wheel event.
    window.addEventListener('wheel', e => {
        if (!e.ctrlKey) {
            return;
        }
        e.preventDefault();
        e.stopPropagation();
        let factor = 1;
        if (e.deltaY < 0) {
            factor = 1.1;
        } else if (e.deltaY > 0) {
            factor = 1 / 1.1;
        }
        zoomBoard(factor);
        // Scroll so that the page position under the cursor, multiplied by
        // the zoom factor, ends up at the same viewport coordinates.
        window.scrollTo({
            top: e.pageY * factor - e.y,
            left: e.pageX * factor - e.x,
        });
    }, { 'passive': false });
};
</script>
<body>
<p id="desc" style="position: fixed;bottom: 0; background-color: white;">desc</p>
<p id="info" style="position: fixed;top: 0; right: 0; background-color: white;">info</p>
<p id="board"
style="white-space: nowrap; display: flex; justify-content: center; align-content: center; align-items: center; margin: 0;opacity: 0.7;">
</p>
<input type='file' id='fileInput' style="position: fixed; top: 0; background-color: white;"></input>
</body>
</html>
\ No newline at end of file
...@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { ...@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
return *this; return *this;
} }
/*!
 * \brief Record the static memory layout of this sequence and write it to
 *      \p svg_name through StaticMemRecorder.
 *
 * Activates the recorder, appends one chunk record per output var whose
 * memory plan is valid and whose chunk is marked is_from_owner_var() (the
 * original author calls these "weights"), stacking them above the recorded
 * peak size, and finally asks the recorder to log and dump the SVG.
 *
 * \param svg_name output file name; must be longer than 4 characters and end
 *      with ".svg" (checked with mgb_assert after the records are collected)
 */
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
        const std::string& svg_name) {
    check_not_finalized();

    auto& recorder = StaticMemRecorder::Instance();
    recorder.active();
    // NOTE(review): exec_ctx is never referenced again; presumably its
    // construction is what triggers the static allocation that fills the
    // recorder -- confirm against ExecContext.
    ExecContext exec_ctx{this};

    // Chunks registered below are placed directly above the recorded peak
    // and numbered from the current end of the recorder's chunk list.
    size_t next_addr = recorder.peak_mem_size();
    size_t next_chunk_id = recorder.set_weight_chunk_id();

    for (auto&& opr : *(this->m_opr_seq)) {
        auto outputs = opr->output();
        for (auto&& var : outputs) {
            auto& plan = var->mem_plan();
            if (!plan.valid()) {
                continue;
            }
            auto& chunk = plan.chunk();
            if (!chunk.mem_alloc_status.is_from_owner_var()) {
                continue;
            }
            // The record spans the whole sequence: [0, number of oprs).
            recorder.regist_memory_chunk(
                    {next_chunk_id++, chunk.size(), 0, this->m_opr_seq->size(),
                     next_addr, next_addr + chunk.size(), 0, false,
                     chunk.owner_var->name()});
            next_addr += chunk.size();
        }
    }
    recorder.set_sum_mem_size(next_addr);

    mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n");
    mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
               "svg_name must be end with \".svg\"\n");
    recorder.show(svg_name);
}
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {
do_wait(true); do_wait(true);
return *this; return *this;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/plugin/var_sanity_check.h" #include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/utils/arith_helper.h" #include "megbrain/utils/arith_helper.h"
#include "megbrain/plugin/static_mem_record.h"
namespace mgb { namespace mgb {
namespace cg { namespace cg {
...@@ -169,6 +170,9 @@ public: ...@@ -169,6 +170,9 @@ public:
} }
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); std::unique_ptr<RecordedComputingSequence> as_recorded_seq();
//! Dump the static memory allocation info of this sequence as an SVG file.
//! Overrides a virtual declared in a base class that is not visible in this
//! fragment; \p svg_name defaults to "static_mem_record.svg".
void get_static_memory_alloc_info(
const std::string& svg_name = "static_mem_record.svg") override;
}; };
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj {
......
...@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() { ...@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval;
// get all memory chunks // get all memory chunks
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_opr_seq();
}
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) {
OperatorNodeBase *opr = m_cur_seq_full->at(idx); OperatorNodeBase *opr = m_cur_seq_full->at(idx);
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().regist_opr_seq(
{idx, 0, opr->name()});
}
auto &&dep_map = opr->node_prop().dep_map(); auto &&dep_map = opr->node_prop().dep_map();
if (in_sys_alloc(opr)) { if (in_sys_alloc(opr)) {
...@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( ...@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
chk.chunk->mem_alloc_status.set_static_offset( chk.chunk->mem_alloc_status.set_static_offset(
allocator->get_start_addr(&chk)); allocator->get_start_addr(&chk));
} }
auto& recorder = StaticMemRecorder::Instance();
if (recorder.valid()) {
for (size_t i = 0; i < chunks.size(); i++) {
recorder.regist_memory_chunk_owner_var_name(
i, chunks.at(i).chunk->owner_var->name());
}
recorder.regist_peak_mem_size(size);
}
} }
return should_realloc; return should_realloc;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#pragma once #pragma once
#include "megbrain/plugin/static_mem_record.h"
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#include <cstddef> #include <cstddef>
......
...@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { ...@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
check_result_and_calc_lower_bound(); check_result_and_calc_lower_bound();
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_memory_chunk();
for (auto&& i : m_interval) {
size_t overwrite_dest_id = 0;
bool is_overwrite = !i->is_overwrite_root();
if (is_overwrite) {
overwrite_dest_id = i->overwrite_dest_root()->id;
}
StaticMemRecorder::Instance().regist_memory_chunk(
{i->id, i->size_orig, i->time_begin, i->time_end,
i->addr_begin, i->addr_end(), overwrite_dest_id,
is_overwrite, ""});
}
}
return *this; return *this;
} }
......
...@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable, ...@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable,
m_user_data.get_user_data<OutputVarsUserData>(); m_user_data.get_user_data<OutputVarsUserData>();
return (*(output_vars_pair.first))->get_output_vars(); return (*(output_vars_pair.first))->get_output_vars();
} }
/*!
 * \brief Default implementation of the static-memory dump hook.
 *
 * std::string::length() is unsigned, so the condition `length() < 0` can
 * never hold: the assertion below is always handed a false condition.
 * Together with its message this indicates that the base version is not
 * meant to be called and that subclasses are expected to override it
 * (presumably mgb_assert raises an error on a false condition -- confirm).
 *
 * \param svg_name output file name; only used to form the condition here
 */
virtual void get_static_memory_alloc_info(const std::string& svg_name) {
mgb_assert(svg_name.length() < 0,
"can't call this function directly\n");
}
}; };
......
/**
* \file src/plugin/impl/static_mem_record.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/plugin/static_mem_record.h"
#include <fstream>
#include <iostream>
using namespace mgb;
using namespace cg;
namespace {
// Upper bounds of the generated canvas and the default size of the box drawn
// for one operator, in SVG user units. Typed constants instead of macros so
// they obey scoping rules; the names are unchanged for existing users.
constexpr double SVG_WIDTH = 20000.0;
constexpr double SVG_HEIGHT = 15000.0;
constexpr double OPR_RECT_WIDTH = 40.0;
constexpr double OPR_RECT_HEIGHT = 20.0;

// Element / attribute templates. Every "{}" is a placeholder that
// replace_by_parameter() fills in from left to right.
const std::string rect =
        "<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" "
        " {}></rect>";
const std::string text = "<text x=\"{}\" y=\"{}\" font-size=\"{}\">{}</text>";
const std::string polyline =
        "<polyline points=\"{}\" style=\"fill:none;stroke:{};stroke-width:{}\" "
        "/>";
const std::string opr_info =
        "mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\"";
const std::string chunk_info =
        "mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" "
        "mge:size=\"{}\" mge:owner_var_name=\"{}\"";
// Not referenced by any helper in this namespace; kept for compatibility.
const std::string animate =
        "<animate attributeName=\"opacity\" from=\"0\" to=\"1\" "
        "begin=\"{}.mouseover\" fill=\"freeze\" dur=\"1s\"/>\n<animate "
        "attributeName=\"opacity\" from=\"1\" to=\"0\" begin=\"{}.mouseout\" "
        "fill=\"freeze\" dur=\"1s\"/>";

//! Escape the characters that may not appear verbatim inside a double-quoted
//! XML attribute value (& < > "), so arbitrary names cannot break the markup.
std::string escape_xml_attr(const std::string& raw) {
    std::string escaped;
    escaped.reserve(raw.size());
    for (char c : raw) {
        switch (c) {
            case '&':
                escaped += "&amp;";
                break;
            case '<':
                escaped += "&lt;";
                break;
            case '>':
                escaped += "&gt;";
                break;
            case '"':
                escaped += "&quot;";
                break;
            default:
                escaped += c;
        }
    }
    return escaped;
}

//! Recursion end: no parameters left, the string is returned unchanged.
std::string& replace_by_parameter(std::string& original_str, size_t /*index*/) {
    return original_str;
}

/*!
 * \brief Replace successive "{}" placeholders in \p original_str, searching
 *      from \p index, with \p parameter and then with each of \p args.
 *
 * The search resumes right after the text just inserted, so a "{}" that is
 * part of a parameter is never substituted again. If the template runs out
 * of placeholders the remaining parameters are ignored (previously
 * std::string::replace was called with npos and threw std::out_of_range).
 *
 * \return reference to the modified \p original_str
 */
template <typename... Args>
std::string& replace_by_parameter(std::string& original_str, size_t index,
                                  const std::string& parameter,
                                  const Args&... args) {
    index = original_str.find("{}", index);
    if (index == std::string::npos) {
        return original_str;
    }
    original_str.replace(index, 2, parameter);
    return replace_by_parameter(original_str, index + parameter.length(),
                                args...);
}

//! Build the mge:* attribute list of an operator box. \p name is escaped;
//! \p info is the template to fill (taken by value: it is the working copy).
std::string set_opr_info(const std::string& id, const std::string& size,
                         const std::string& name, std::string info = opr_info) {
    return replace_by_parameter(info, 0, id, size, escape_xml_attr(name));
}

//! Build the mge:* attribute list of a memory chunk box. \p owner_var_name is
//! escaped; \p info is the template to fill.
std::string set_chunk_info(const std::string& id, const std::string& time,
                           const std::string& addr, const std::string& size,
                           const std::string& owner_var_name,
                           std::string info = chunk_info) {
    return replace_by_parameter(info, 0, id, time, addr, size,
                                escape_xml_attr(owner_var_name));
}

//! Build a <rect> element. \p info is inserted verbatim as extra attributes,
//! so it must already be valid attribute markup.
std::string draw_rect(const std::string& x, const std::string& y,
                      const std::string& width, const std::string& height,
                      const std::string& color, const std::string& info,
                      std::string r = rect) {
    return replace_by_parameter(r, 0, x, y, width, height, color, info);
}

//! Build a <text> element showing \p txt (inserted verbatim) at (\p x, \p y).
std::string draw_text(const std::string& x, const std::string& y,
                      const std::string& font_size, const std::string& txt,
                      std::string t = text) {
    return replace_by_parameter(t, 0, x, y, font_size, txt);
}

//! Build an unfilled <polyline> through \p point_seq ("x,y x,y ...").
std::string draw_polyline(const std::string& point_seq,
                          const std::string& color, const std::string& width,
                          std::string p = polyline) {
    return replace_by_parameter(p, 0, point_seq, color, width);
}
}  // namespace
void StaticMemRecorder::dump_svg(std::string svg_name) {
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
float address_scale = 1;
size_t opr_nr = m_opr_seq_recorder.size();
if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) {
svg_width = SVG_WIDTH;
opr_rect_width = svg_width / opr_nr;
opr_rect_height = opr_rect_width / 2;
} else {
opr_rect_width = OPR_RECT_WIDTH;
svg_width = opr_nr * opr_rect_width;
}
if (m_sum_mem_size > SVG_HEIGHT) {
svg_height = SVG_HEIGHT;
address_scale = svg_height / m_sum_mem_size;
} else {
svg_height = m_sum_mem_size;
}
// Rescale
float aspect_ratio = SVG_WIDTH / SVG_HEIGHT;
if (svg_width / svg_height < 1) {
svg_width = svg_height * aspect_ratio;
opr_rect_width = svg_width / opr_nr;
opr_rect_height = opr_rect_width / 2;
} else if (svg_width / svg_height > aspect_ratio) {
svg_height = svg_width / aspect_ratio;
address_scale = svg_height / m_sum_mem_size;
}
svg_height = svg_height + opr_rect_height * 2;
std::ofstream outfile;
outfile.open(svg_name);
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl;
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" "
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">"
<< std::endl;
outfile << "<svg width=\"" + std::to_string(svg_width) + "\" height=\"" +
std::to_string(svg_height) +
"\" version=\"1.1\" "
"xmlns=\"http://www.w3.org/2000/svg\">"
<< std::endl;
float base_height = svg_height - opr_rect_height;
std::string peak_mem_polyline =
"0," +
std::to_string(base_height - m_peak_mem_size * address_scale) +
" " + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) +
"," + std::to_string(base_height - m_peak_mem_size * address_scale);
std::string sum_mem_polyline =
"0," +
std::to_string(base_height - m_sum_mem_size * address_scale) + " " +
std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + "," +
std::to_string(base_height - m_sum_mem_size * address_scale);
std::string memory_polyline = "";
for (size_t i = 0; i < m_opr_seq_recorder.size(); i++) {
auto&& opr = m_opr_seq_recorder.at(i);
memory_polyline +=
std::to_string((i + 0.5) * opr_rect_width) + "," +
std::to_string(base_height - opr.size * address_scale) + " ";
outfile << draw_text(std::to_string(i * opr_rect_width),
std::to_string(svg_height - opr_rect_height * 0.5),
std::to_string(opr_rect_height * 0.5),
"opr" + std::to_string(i))
<< std::endl;
std::string opr_info =
set_opr_info(
std::to_string(opr.id),
std::to_string(opr.size) + "B(" +
std::to_string(opr.size / 1024.0 / 1024.0) +
"MiB)",
opr.name) +
" opacity=\"0\"";
outfile << draw_rect(std::to_string(i * opr_rect_width),
std::to_string(base_height),
std::to_string(opr_rect_width),
std::to_string(opr_rect_height), "white", opr_info)
<< std::endl;
}
for (size_t i = 0; i < m_memory_chunk_recorder.size(); i++) {
auto&& chunk = m_memory_chunk_recorder.at(i);
std::string chunk_info = set_chunk_info(
std::to_string(chunk.id),
"[" + std::to_string(chunk.time_begin) + "," +
std::to_string(chunk.time_end) + ")",
"[" + std::to_string(chunk.addr_begin) + "," +
std::to_string(chunk.addr_end) + ")",
std::to_string(chunk.addr_end - chunk.addr_begin) + "B(" +
std::to_string((chunk.addr_end - chunk.addr_begin) /
1024.0 / 1024.0) +
"MiB)",
chunk.owner_var_name);
outfile << draw_rect(
std::to_string(chunk.time_begin * opr_rect_width),
std::to_string(base_height -
chunk.addr_end * address_scale),
std::to_string((chunk.time_end - chunk.time_begin) *
opr_rect_width),
std::to_string((chunk.addr_end - chunk.addr_begin) *
address_scale),
"gray", chunk_info)
<< std::endl;
outfile << draw_text(std::to_string(chunk.time_begin * opr_rect_width),
std::to_string(base_height -
chunk.addr_end * address_scale + 9),
std::to_string(9),
"chunk" + std::to_string(chunk.id))
<< std::endl;
}
outfile << draw_text("0",
std::to_string(base_height -
m_peak_mem_size * address_scale +
opr_rect_height * 0.5),
std::to_string(opr_rect_height * 0.5),
"peak_memory_size:" + std::to_string(m_peak_mem_size) +
"B(" +
std::to_string(m_peak_mem_size / 1024.0 /
1024.0) +
"MiB)")
<< std::endl;
outfile << draw_text("0",
std::to_string(base_height -
m_sum_mem_size * address_scale +
opr_rect_height * 0.5),
std::to_string(opr_rect_height * 0.5),
"sum_memory_size:" + std::to_string(m_sum_mem_size) +
"B(" +
std::to_string(m_sum_mem_size / 1024.0 /
1024.0) +
"MiB)")
<< std::endl;
outfile << draw_polyline(memory_polyline, "blue",
std::to_string(opr_rect_height * 0.1))
<< std::endl;
outfile << draw_polyline(peak_mem_polyline, "green",
std::to_string(opr_rect_height * 0.1))
<< std::endl;
outfile << draw_polyline(sum_mem_polyline, "red",
std::to_string(opr_rect_height * 0.1))
<< std::endl;
outfile << "<text svg:info=\"The abscissa represents the opr sequence, the "
"ordinate represents the logical address.\" "
"svg:chunk_time=\"[opra,oprb) means the chunk is created when "
"opra execute and is freed before oprb\" "
"svg:chunk_oner_var_name=\"var that first creates this "
"chunk\"></text>"
<< std::endl;
outfile << "</svg>" << std::endl;
outfile.close();
}
/*!
 * \brief Log where the peak memory is reached and dump the SVG.
 *
 * Steps:
 *  1. add size_orig of every chunk recorded before m_weight_chunk_id to the
 *     size of each operator in the chunk's life interval (an overwriting
 *     chunk is counted from the operator after its time_begin);
 *  2. collect the ids of the operators with the largest accumulated size;
 *  3. log those operators and the chunks get_chunk_construct() reports for
 *     them (consecutive operators with identical chunk lists share one
 *     listing);
 *  4. call dump_svg(svg_name).
 *
 * \param svg_name file name forwarded to dump_svg()
 */
void StaticMemRecorder::show(std::string svg_name) {
    // 1. per-operator accumulated size
    for (auto&& chunk : m_memory_chunk_recorder) {
        if (chunk.id >= m_weight_chunk_id) {
            break;
        }
        const size_t life_end = chunk.time_end;
        size_t step = chunk.time_begin;
        if (chunk.is_overwrite) {
            ++step;
        }
        while (step < life_end) {
            m_opr_seq_recorder.at(step).size += chunk.size_orig;
            ++step;
        }
    }

    // log peak memory size, where it is reached and which chunks constitute it.
    mgb_log("peak_mem_size = %zu\n", m_peak_mem_size);

    // 2. operators whose accumulated size is the maximum
    size_t largest = 0;
    std::vector<size_t> opr_ids;
    for (auto&& opr : m_opr_seq_recorder) {
        if (opr.size > largest) {
            // new maximum: forget the operators collected so far
            largest = opr.size;
            opr_ids.clear();
        }
        if (opr.size == largest) {
            opr_ids.push_back(opr.id);
        }
    }

    // 3. report
    auto opr2chunk = get_chunk_construct(opr_ids);
    mgb_log("oprs reach the peak memory:\n");
    for (size_t peak_opr : opr_ids) {
        mgb_log("opr id = %zu\n", peak_opr);
    }
    mgb_log("More details:\n");
    const size_t nr_peak_opr = opr2chunk.size();
    for (size_t idx = 0; idx < nr_peak_opr; ++idx) {
        mgb_log("opr id = %zu\n", opr_ids.at(idx));
        const bool has_next = idx + 1 < nr_peak_opr;
        if (has_next && opr2chunk.at(idx) == opr2chunk.at(idx + 1)) {
            // same chunk list as the next operator: listed once, there
            continue;
        }
        for (size_t pos = 0; pos < opr2chunk.at(idx).size(); ++pos) {
            // NOTE(review): the chunk id is used as an index into
            // m_memory_chunk_recorder, which assumes id == position.
            auto&& chunk =
                    m_memory_chunk_recorder.at(opr2chunk.at(idx).at(pos));
            mgb_log("[memory_chunk_id=%zu, size=%zu B, "
                    "[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n",
                    chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end,
                    m_opr_seq_recorder.at(chunk.time_begin).name.c_str());
        }
    }

    // 4. picture
    dump_svg(svg_name);
}
/*!
 * \brief For each operator id in \p opr_ids, list the ids of the recorded
 *      chunks that are alive at that operator.
 *
 * Only chunks recorded before m_weight_chunk_id are considered. A chunk is
 * alive at operator t when begin <= t < time_end, where begin is time_begin,
 * or time_begin + 1 for an overwriting chunk.
 *
 * \param opr_ids operator ids to query; the early exits below rely on them
 *      being in ascending order, as produced by show()
 * \return one chunk-id list per entry of \p opr_ids, in the same order;
 *      empty when \p opr_ids is empty
 */
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct(
        std::vector<size_t> opr_ids) {
    std::vector<std::vector<size_t>> chunk_ids;
    chunk_ids.resize(opr_ids.size());
    if (opr_ids.empty()) {
        // front()/back() below must not be called on an empty vector
        return chunk_ids;
    }
    for (auto&& i : m_memory_chunk_recorder) {
        if (i.id >= m_weight_chunk_id) {
            break;
        }
        size_t begin = i.time_begin, end = i.time_end;
        if (i.is_overwrite) {
            begin = begin + 1;
        }
        // the chunk's life interval misses the whole queried range
        if (opr_ids.front() >= end || opr_ids.back() < begin) {
            continue;
        }
        for (size_t k = 0; k < opr_ids.size(); k++) {
            if (opr_ids.at(k) >= end) {
                break;
            } else if (opr_ids.at(k) >= begin) {
                chunk_ids.at(k).push_back(i.id);
            }
        }
    }
    return chunk_ids;
}
\ No newline at end of file
/**
* \file src/plugin/include/megbrain/plugin/static_mem_record.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/utils/metahelper.h"
namespace mgb {
namespace cg {
/*!
 * \brief Collects the operator sequence and the memory chunks of a static
 *      memory allocation so they can be logged and dumped as an SVG.
 *
 * All users share the instance returned by Instance(). Nothing is meant to
 * be recorded until active() has been called; callers check valid() first.
 */
class StaticMemRecorder : public NonCopyableObj {
public:
    //! the single shared recorder (a function-local static)
    static StaticMemRecorder& Instance() {
        static StaticMemRecorder instance;
        return instance;
    }

    //! one operator of the recorded sequence
    struct opr_record {
        //! id: operator id; size: bytes accumulated for it by show()
        size_t id, size;
        std::string name;
    };

    //! one recorded memory chunk
    struct memory_chunk_record {
        //! life interval is [time_begin, time_end), address range is
        //! [addr_begin, addr_end); overwrite_dest_id is only meaningful
        //! when is_overwrite is true
        size_t id, size_orig, time_begin, time_end, addr_begin,
                addr_end, overwrite_dest_id;
        bool is_overwrite;
        std::string owner_var_name;
    };

    //! switch recording on
    void active() { m_is_record = true; }
    //! whether recording has been switched on
    bool valid() const { return m_is_record; }

    void clear_opr_seq() { m_opr_seq_recorder.clear(); }
    //! append an operator record
    void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); }

    void clear_memory_chunk() { m_memory_chunk_recorder.clear(); }
    //! append a chunk record
    void regist_memory_chunk(memory_chunk_record mcr) {
        m_memory_chunk_recorder.push_back(mcr);
    }
    //! set the owner var name of the chunk record stored at position \p id;
    //! throws std::out_of_range if there is no such record
    void regist_memory_chunk_owner_var_name(size_t id, std::string name) {
        m_memory_chunk_recorder.at(id).owner_var_name = name;
    }

    void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; }
    const size_t& peak_mem_size() const { return m_peak_mem_size; }

    void set_sum_mem_size(size_t size) { m_sum_mem_size = size; }
    const size_t& sum_mem_size() const { return m_sum_mem_size; }

    //! mark the current end of the chunk list as the first weight chunk and
    //! return that position
    const size_t& set_weight_chunk_id() {
        m_weight_chunk_id = m_memory_chunk_recorder.size();
        return m_weight_chunk_id;
    }
    const size_t& weight_chunk_id() const { return m_weight_chunk_id; }

    //! write the recorded data to \p svg_name as an SVG image
    void dump_svg(std::string svg_name);
    //! log the peak memory information, then call dump_svg()
    void show(std::string svg_name);

private:
    bool m_is_record = false;
    // All chunks from m_memory_chunk_recorder.at(m_weight_chunk_id) onwards
    // are weights memory chunks.
    // Zero-initialized so the getters never read indeterminate values when
    // nothing has been recorded yet.
    size_t m_peak_mem_size = 0, m_sum_mem_size = 0, m_weight_chunk_id = 0;
    std::vector<opr_record> m_opr_seq_recorder;
    std::vector<memory_chunk_record> m_memory_chunk_recorder;
    //! for every id in \p opr_ids, the ids of the chunks alive at it
    std::vector<std::vector<size_t>> get_chunk_construct(
            std::vector<size_t> opr_ids);
};
} // namespace cg
} // namespace mgb
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册