From 07de15713c1687fceeed70237703d26736937181 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 6 Jul 2021 15:43:35 +0800 Subject: [PATCH] fix(mgb): remove static mem record from tee GitOrigin-RevId: ac61b2a5ebb813ed183048a5b5d88dc56bd589d6 --- sdk/load-and-run/src/mgblar.cpp | 10 ++++++++++ src/core/impl/graph/cg_impl_seq.cpp | 4 ++-- src/core/impl/graph/cg_impl_seq.h | 3 ++- src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp | 9 ++++++--- .../graph/var_node_mem_mgr/static_mem_alloc/impl.cpp | 4 ++-- src/core/include/megbrain/graph/bases.h | 3 ++- src/plugin/impl/static_mem_record.cpp | 8 +------- src/plugin/include/megbrain/plugin/static_mem_record.h | 3 ++- 8 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index d4138dcc2..015f891cc 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -141,9 +141,13 @@ R"__usage__( level 2 the computing graph can be destructed to reduce memory usage. Read the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more details. +)__usage__" +#ifndef __IN_TEE_ENV__ +R"__usage__( --get-static-mem-info Record the static graph's static memory info. )__usage__" +#endif #if MGB_ENABLE_FASTRUN R"__usage__( --full-run @@ -538,7 +542,9 @@ struct Args { #endif bool reproducible = false; std::string fast_run_cache_path; +#ifndef __IN_TEE_ENV__ std::string static_mem_svg_path; +#endif bool copy_to_host = false; int nr_run = 10; int nr_warmup = 1; @@ -797,9 +803,11 @@ void run_test_st(Args &env) { } auto func = env.load_ret.graph_compile(out_spec); +#ifndef __IN_TEE_ENV__ if (!env.static_mem_svg_path.empty()) { func->get_static_memory_alloc_info(env.static_mem_svg_path); } +#endif auto warmup = [&]() { printf("=== prepare: %.3fms; going to warmup\n", timer.get_msecs_reset()); @@ -1383,6 +1391,7 @@ Args Args::from_argv(int argc, char **argv) { graph_opt.comp_node_seq_record_level = 2; continue; } +#ifndef __IN_TEE_ENV__ if (!strcmp(argv[i], "--get-static-mem-info")) { ++i; mgb_assert(i < argc, "value not given for --get-static-mem-info"); @@ -1393,6 +1402,7 @@ Args Args::from_argv(int argc, char **argv) { ret.static_mem_svg_path.c_str()); continue; } +#endif #if MGB_ENABLE_FASTRUN if (!strcmp(argv[i], "--fast-run")) { ret.use_fast_run = true; diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index 926bf8212..c59604f41 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -491,7 +491,7 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { do_execute(nullptr); return *this; } - +#ifndef __IN_TEE_ENV__ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( const std::string& svg_name) { check_not_finalized(); @@ -523,7 +523,7 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( "svg_name must be end with \".svg\"\n"); recorder.show(svg_name); } - +#endif AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { do_wait(true); return *this; diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h index 8cdc3bee7..47818a6f3 100644 --- a/src/core/impl/graph/cg_impl_seq.h +++ b/src/core/impl/graph/cg_impl_seq.h @@ -170,9 +170,10 @@ public: } std::unique_ptr as_recorded_seq(); - +#ifndef __IN_TEE_ENV__ void get_static_memory_alloc_info( const std::string& svg_name = "static_mem_record.svg") override; +#endif }; class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { diff --git a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp index c5f41656a..05d33681b 100644 --- a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp +++ b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp @@ -178,18 +178,19 @@ bool SeqMemOptimizer::run_static_mem_alloc() { ThinHashMap chk2interval; // get all memory chunks +#ifndef __IN_TEE_ENV__ if (StaticMemRecorder::Instance().valid()) { StaticMemRecorder::Instance().clear_opr_seq(); } - +#endif for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { OperatorNodeBase *opr = m_cur_seq_full->at(idx); - +#ifndef __IN_TEE_ENV__ if (StaticMemRecorder::Instance().valid()) { StaticMemRecorder::Instance().regist_opr_seq( {idx, 0, opr->name()}); } - +#endif auto &&dep_map = opr->node_prop().dep_map(); if (in_sys_alloc(opr)) { @@ -358,6 +359,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( chk.chunk->mem_alloc_status.set_static_offset( allocator->get_start_addr(&chk)); } +#ifndef __IN_TEE_ENV__ auto& recorder = StaticMemRecorder::Instance(); if (recorder.valid()) { for (size_t i = 0; i < chunks.size(); i++) { @@ -366,6 +368,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( } recorder.regist_peak_mem_size(size); } +#endif } return should_realloc; diff --git a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp index 693080084..06ea42d7e 100644 --- a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp +++ b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp @@ -119,7 +119,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { do_solve(); check_result_and_calc_lower_bound(); - +#ifndef __IN_TEE_ENV__ if (StaticMemRecorder::Instance().valid()) { StaticMemRecorder::Instance().clear_memory_chunk(); for (auto&& i : m_interval) { @@ -135,7 +135,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { is_overwrite, ""}); } } - +#endif return *this; } diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index 6b79d385c..7d13b527b 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -194,11 +194,12 @@ class AsyncExecutable : public json::Serializable, m_user_data.get_user_data(); return (*(output_vars_pair.first))->get_output_vars(); } - +#ifndef __IN_TEE_ENV__ virtual void get_static_memory_alloc_info(const std::string& svg_name) { mgb_assert(svg_name.length() < 0, "can't call this function directly\n"); } +#endif }; diff --git a/src/plugin/impl/static_mem_record.cpp b/src/plugin/impl/static_mem_record.cpp index 25a57bc2a..14e86ad9b 100644 --- a/src/plugin/impl/static_mem_record.cpp +++ b/src/plugin/impl/static_mem_record.cpp @@ -14,13 +14,11 @@ #ifndef __IN_TEE_ENV__ #include #include -#endif using namespace mgb; using namespace cg; namespace { -#ifndef __IN_TEE_ENV__ #define SVG_WIDTH 20000.0 #define SVG_HEIGHT 15000.0 #define OPR_RECT_WIDTH 40.0 @@ -86,13 +84,9 @@ std::string draw_polyline(std::string point_seq, std::string color, std::string width, std::string p = polyline) { return replace_by_parameter(p, 0, point_seq, color, width); } -#endif } // namespace void StaticMemRecorder::dump_svg(std::string svg_name) { -#ifdef __IN_TEE_ENV__ - MGB_MARK_USED_VAR(svg_name); -#else float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; float address_scale = 1; @@ -247,7 +241,6 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { << std::endl; outfile << "" << std::endl; outfile.close(); -#endif } void StaticMemRecorder::show(std::string svg_name) { @@ -326,3 +319,4 @@ std::vector> StaticMemRecorder::get_chunk_construct( } return chunk_ids; } +#endif \ No newline at end of file diff --git a/src/plugin/include/megbrain/plugin/static_mem_record.h b/src/plugin/include/megbrain/plugin/static_mem_record.h index 6276227d7..78e583594 100644 --- a/src/plugin/include/megbrain/plugin/static_mem_record.h +++ b/src/plugin/include/megbrain/plugin/static_mem_record.h @@ -12,7 +12,7 @@ #pragma once #include "megbrain/utils/metahelper.h" - +#ifndef __IN_TEE_ENV__ namespace mgb { namespace cg { @@ -83,3 +83,4 @@ private: }; } // namespace cg } // namespace mgb +#endif -- GitLab