提交 07de1571 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mgb): remove static mem record from tee

GitOrigin-RevId: ac61b2a5ebb813ed183048a5b5d88dc56bd589d6
上级 d7b6bfd5
......@@ -141,9 +141,13 @@ R"__usage__(
level 2 the computing graph can be destructed to reduce memory usage. Read
the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more
details.
)__usage__"
#ifndef __IN_TEE_ENV__
R"__usage__(
--get-static-mem-info <svgname>
Record the static graph's static memory info.
)__usage__"
#endif
#if MGB_ENABLE_FASTRUN
R"__usage__(
--full-run
......@@ -538,7 +542,9 @@ struct Args {
#endif
bool reproducible = false;
std::string fast_run_cache_path;
#ifndef __IN_TEE_ENV__
std::string static_mem_svg_path;
#endif
bool copy_to_host = false;
int nr_run = 10;
int nr_warmup = 1;
......@@ -797,9 +803,11 @@ void run_test_st(Args &env) {
}
auto func = env.load_ret.graph_compile(out_spec);
#ifndef __IN_TEE_ENV__
if (!env.static_mem_svg_path.empty()) {
func->get_static_memory_alloc_info(env.static_mem_svg_path);
}
#endif
auto warmup = [&]() {
printf("=== prepare: %.3fms; going to warmup\n",
timer.get_msecs_reset());
......@@ -1383,6 +1391,7 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.comp_node_seq_record_level = 2;
continue;
}
#ifndef __IN_TEE_ENV__
if (!strcmp(argv[i], "--get-static-mem-info")) {
++i;
mgb_assert(i < argc, "value not given for --get-static-mem-info");
......@@ -1393,6 +1402,7 @@ Args Args::from_argv(int argc, char **argv) {
ret.static_mem_svg_path.c_str());
continue;
}
#endif
#if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) {
ret.use_fast_run = true;
......
......@@ -491,7 +491,7 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
do_execute(nullptr);
return *this;
}
#ifndef __IN_TEE_ENV__
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
const std::string& svg_name) {
check_not_finalized();
......@@ -523,7 +523,7 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
"svg_name must be end with \".svg\"\n");
recorder.show(svg_name);
}
#endif
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {
do_wait(true);
return *this;
......
......@@ -170,9 +170,10 @@ public:
}
std::unique_ptr<RecordedComputingSequence> as_recorded_seq();
#ifndef __IN_TEE_ENV__
void get_static_memory_alloc_info(
const std::string& svg_name = "static_mem_record.svg") override;
#endif
};
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj {
......
......@@ -178,18 +178,19 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval;
// get all memory chunks
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_opr_seq();
}
#endif
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) {
OperatorNodeBase *opr = m_cur_seq_full->at(idx);
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().regist_opr_seq(
{idx, 0, opr->name()});
}
#endif
auto &&dep_map = opr->node_prop().dep_map();
if (in_sys_alloc(opr)) {
......@@ -358,6 +359,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
chk.chunk->mem_alloc_status.set_static_offset(
allocator->get_start_addr(&chk));
}
#ifndef __IN_TEE_ENV__
auto& recorder = StaticMemRecorder::Instance();
if (recorder.valid()) {
for (size_t i = 0; i < chunks.size(); i++) {
......@@ -366,6 +368,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
}
recorder.regist_peak_mem_size(size);
}
#endif
}
return should_realloc;
......
......@@ -119,7 +119,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
do_solve();
check_result_and_calc_lower_bound();
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_memory_chunk();
for (auto&& i : m_interval) {
......@@ -135,7 +135,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
is_overwrite, ""});
}
}
#endif
return *this;
}
......
......@@ -194,11 +194,12 @@ class AsyncExecutable : public json::Serializable,
m_user_data.get_user_data<OutputVarsUserData>();
return (*(output_vars_pair.first))->get_output_vars();
}
#ifndef __IN_TEE_ENV__
virtual void get_static_memory_alloc_info(const std::string& svg_name) {
mgb_assert(svg_name.length() < 0,
"can't call this function directly\n");
}
#endif
};
......
......@@ -14,13 +14,11 @@
#ifndef __IN_TEE_ENV__
#include <fstream>
#include <iostream>
#endif
using namespace mgb;
using namespace cg;
namespace {
#ifndef __IN_TEE_ENV__
#define SVG_WIDTH 20000.0
#define SVG_HEIGHT 15000.0
#define OPR_RECT_WIDTH 40.0
......@@ -86,13 +84,9 @@ std::string draw_polyline(std::string point_seq, std::string color,
std::string width, std::string p = polyline) {
return replace_by_parameter(p, 0, point_seq, color, width);
}
#endif
} // namespace
void StaticMemRecorder::dump_svg(std::string svg_name) {
#ifdef __IN_TEE_ENV__
MGB_MARK_USED_VAR(svg_name);
#else
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
float address_scale = 1;
......@@ -247,7 +241,6 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
<< std::endl;
outfile << "</svg>" << std::endl;
outfile.close();
#endif
}
void StaticMemRecorder::show(std::string svg_name) {
......@@ -326,3 +319,4 @@ std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct(
}
return chunk_ids;
}
#endif
\ No newline at end of file
......@@ -12,7 +12,7 @@
#pragma once
#include "megbrain/utils/metahelper.h"
#ifndef __IN_TEE_ENV__
namespace mgb {
namespace cg {
......@@ -83,3 +83,4 @@ private:
};
} // namespace cg
} // namespace mgb
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册