提交 6070f127 编写于 作者: M Megvii Engine Team

fix(mgb): fix getting static memory alloc info

GitOrigin-RevId: dfc69c3b3f95b11d708ada0891526db50e4b382c
上级 e8a5932d
......@@ -12,6 +12,7 @@
#include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/arith_helper.h"
using namespace mgb;
using namespace cg;
......@@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute(
}
exec_ctx.perform(&m_exec_env);
#ifndef __IN_TEE_ENV__
do_regist();
#endif
}
void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) {
......@@ -511,12 +515,16 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
}
#ifndef __IN_TEE_ENV__
// Prepares the StaticMemRecorder singleton for this sequence: obtains it via
// StaticMemRecorder::Instance(), calls active() on it and passes svg_name on
// through set_svg_name().
// NOTE(review): this listing is a rendered diff, so the removed and the added
// form of the signature/body appear side by side. Presumably the
// const-qualified signature is the post-change one and the
// check_not_finalized() / ExecContext lines belong to the pre-change body --
// confirm against the repository.
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
const std::string& svg_name) {
check_not_finalized();
const std::string& svg_name) const {
auto& recorder = StaticMemRecorder::Instance();
recorder.active();
ExecContext exec_ctx{this};
// Hand the requested output file name to the recorder.
recorder.set_svg_name(svg_name);
}
// Registers memory chunks whose allocation status reports
// is_from_owner_var() with the StaticMemRecorder singleton, stores the final
// address as the sum memory size and then calls recorder.show().
// The whole body is skipped unless recorder.valid() is true.
// NOTE(review): this listing is a rendered diff with a few lines elided in the
// middle (the `......@@` marker below); removed and added lines appear side
// by side.
void ComputingGraphImpl::ComputingSequence::do_regist() const {
// Register weight chunks with the recorder
auto& recorder = StaticMemRecorder::Instance();
if (recorder.valid()) {
// Addresses of the chunks registered here start at the recorded peak size.
size_t addr_base = recorder.peak_mem_size();
// The value returned by set_weight_chunk_id() seeds the ids handed out below.
size_t chunk_id = recorder.set_weight_chunk_id();
for (auto&& i : *(this->m_opr_seq)) {
......@@ -526,20 +534,23 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
if (mp.valid()) {
auto& mc = mp.chunk();
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) {
// Chunk size passed through get_aligned_power2 together with the comp
// node's memory address alignment. `j` is declared in the elided lines.
auto size = mgb::get_aligned_power2(
mc.size(),
j->comp_node().get_mem_addr_alignment());
recorder.regist_memory_chunk(
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(),
addr_base, addr_base + mc.size(), 0, false,
{chunk_id++, size, 0, this->m_opr_seq->size(),
addr_base, addr_base + size, 0, false,
mc.owner_var->name()});
addr_base += mc.size();
addr_base += size;
}
}
}
}
// addr_base has been advanced past every chunk registered above.
recorder.set_sum_mem_size(addr_base);
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n");
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
"svg_name must be end with \".svg\"\n");
recorder.show(svg_name);
recorder.show();
}
}
#endif
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {
......
......@@ -174,7 +174,10 @@ public:
std::unique_ptr<RecordedComputingSequence> as_recorded_seq();
#ifndef __IN_TEE_ENV__
void get_static_memory_alloc_info(
const std::string& svg_name = "static_mem_record.svg") override;
const std::string& svg_name =
"static_mem_record.svg") const override;
void do_regist() const;
#endif
};
......
......@@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable,
return (*(output_vars_pair.first))->get_output_vars();
}
#ifndef __IN_TEE_ENV__
virtual void get_static_memory_alloc_info(const std::string& svg_name) {
// Base-class placeholder that is not meant to be called directly.
// svg_name.length() is an unsigned value, so the `< 0` condition can never
// hold; presumably mgb_assert therefore always reports the message below
// whenever this non-overridden version is reached -- confirm against the
// mgb_assert definition.
virtual void get_static_memory_alloc_info(
const std::string& svg_name) const {
mgb_assert(svg_name.length() < 0,
"can't call this function directly\n");
}
......
......@@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color,
}
} // namespace
void StaticMemRecorder::dump_svg(std::string svg_name) {
void StaticMemRecorder::dump_svg() {
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
float address_scale = 1;
......@@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
svg_height = svg_height + opr_rect_height * 2;
std::ofstream outfile;
outfile.open(svg_name);
outfile.open(m_svg_name);
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl;
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" "
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">"
......@@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
outfile.close();
}
void StaticMemRecorder::show(std::string svg_name) {
void StaticMemRecorder::show() {
for (auto&& i : m_memory_chunk_recorder) {
if (i.id >= m_weight_chunk_id) {
break;
......@@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) {
m_opr_seq_recorder.at(chunk.time_begin).name.c_str());
}
}
dump_svg(svg_name);
dump_svg();
}
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct(
......
......@@ -54,25 +54,38 @@ public:
// Stores the given value as the recorder's peak memory size.
void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; }
// Read access to the stored peak memory size; a non-const and a
// const-qualified form both appear in this listing.
const size_t& peak_mem_size() { return m_peak_mem_size; }
const size_t& peak_mem_size() const { return m_peak_mem_size; }
// Stores the given value as the recorder's sum memory size.
void set_sum_mem_size(size_t size) { m_sum_mem_size = size; }
// Read access to the stored sum memory size; a non-const and a
// const-qualified form both appear in this listing.
const size_t& sum_mem_size() { return m_sum_mem_size; }
const size_t& sum_mem_size() const { return m_sum_mem_size; }
// Takes no argument: records the current number of entries in
// m_memory_chunk_recorder as m_weight_chunk_id and returns a reference to
// that stored value.
const size_t& set_weight_chunk_id() {
    return m_weight_chunk_id = m_memory_chunk_recorder.size();
}
// Read access to the stored weight-chunk id; a non-const and a
// const-qualified form both appear in this listing.
const size_t& weight_chunk_id() { return m_weight_chunk_id; }
const size_t& weight_chunk_id() const { return m_weight_chunk_id; }
// dump_svg / show declarations: both the form taking an svg_name argument and
// the argument-less form appear in this listing.
void dump_svg(std::string svg_name);
void dump_svg();
void show(std::string svg_name);
void show();
// Validates svg_name and stores it in m_svg_name.
// The name must be longer than 4 characters and its last 4 characters must
// be ".svg"; both conditions are checked with mgb_assert before the
// assignment.
void set_svg_name(const std::string& svg_name) {
// Checked first so that `length() - 4` in the next check cannot wrap around.
mgb_assert(svg_name.length() > 4,
"svg_name must be end with \".svg\"\n");
// Compare the trailing 4 characters against the required extension.
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
"svg_name must be end with \".svg\"\n");
m_svg_name = svg_name;
}
// Returns a read-only reference to the stored SVG file name (m_svg_name).
const std::string& get_svg_name() const { return m_svg_name; }
private:
bool m_is_record = false;
std::string m_svg_name;
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are
// weights memory chunks
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册