Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6070f127
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6070f127
编写于
6月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): fix getting static memory alloc info
GitOrigin-RevId: dfc69c3b3f95b11d708ada0891526db50e4b382c
上级
e8a5932d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
61 addition
and
33 deletion
+61
-33
src/core/impl/graph/cg_impl_seq.cpp
src/core/impl/graph/cg_impl_seq.cpp
+33
-22
src/core/impl/graph/cg_impl_seq.h
src/core/impl/graph/cg_impl_seq.h
+4
-1
src/core/include/megbrain/graph/bases.h
src/core/include/megbrain/graph/bases.h
+2
-1
src/plugin/impl/static_mem_record.cpp
src/plugin/impl/static_mem_record.cpp
+4
-4
src/plugin/include/megbrain/plugin/static_mem_record.h
src/plugin/include/megbrain/plugin/static_mem_record.h
+18
-5
未找到文件。
src/core/impl/graph/cg_impl_seq.cpp
浏览文件 @
6070f127
...
...
@@ -12,6 +12,7 @@
#include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/arith_helper.h"
using
namespace
mgb
;
using
namespace
cg
;
...
...
@@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute(
}
exec_ctx
.
perform
(
&
m_exec_env
);
#ifndef __IN_TEE_ENV__
do_regist
();
#endif
}
void
ComputingGraphImpl
::
ComputingSequence
::
preprocess
(
ExecContext
*
ctx
)
{
...
...
@@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
}
#ifndef __IN_TEE_ENV__
void
ComputingGraphImpl
::
ComputingSequence
::
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
)
{
check_not_finalized
();
const
std
::
string
&
svg_name
)
const
{
auto
&
recorder
=
StaticMemRecorder
::
Instance
();
recorder
.
active
();
ExecContext
exec_ctx
{
this
};
recorder
.
set_svg_name
(
svg_name
);
}
void
ComputingGraphImpl
::
ComputingSequence
::
do_regist
()
const
{
// regist weights
size_t
addr_base
=
recorder
.
peak_mem_size
();
size_t
chunk_id
=
recorder
.
set_weight_chunk_id
();
for
(
auto
&&
i
:
*
(
this
->
m_opr_seq
))
{
auto
op
=
i
->
output
();
for
(
auto
&&
j
:
op
)
{
auto
&
mp
=
j
->
mem_plan
();
if
(
mp
.
valid
())
{
auto
&
mc
=
mp
.
chunk
();
if
(
mp
.
valid
()
&&
mc
.
mem_alloc_status
.
is_from_owner_var
())
{
recorder
.
regist_memory_chunk
(
{
chunk_id
++
,
mc
.
size
(),
0
,
this
->
m_opr_seq
->
size
(),
addr_base
,
addr_base
+
mc
.
size
(),
0
,
false
,
mc
.
owner_var
->
name
()});
addr_base
+=
mc
.
size
();
auto
&
recorder
=
StaticMemRecorder
::
Instance
();
if
(
recorder
.
valid
())
{
size_t
addr_base
=
recorder
.
peak_mem_size
();
size_t
chunk_id
=
recorder
.
set_weight_chunk_id
();
for
(
auto
&&
i
:
*
(
this
->
m_opr_seq
))
{
auto
op
=
i
->
output
();
for
(
auto
&&
j
:
op
)
{
auto
&
mp
=
j
->
mem_plan
();
if
(
mp
.
valid
())
{
auto
&
mc
=
mp
.
chunk
();
if
(
mp
.
valid
()
&&
mc
.
mem_alloc_status
.
is_from_owner_var
())
{
auto
size
=
mgb
::
get_aligned_power2
(
mc
.
size
(),
j
->
comp_node
().
get_mem_addr_alignment
());
recorder
.
regist_memory_chunk
(
{
chunk_id
++
,
size
,
0
,
this
->
m_opr_seq
->
size
(),
addr_base
,
addr_base
+
size
,
0
,
false
,
mc
.
owner_var
->
name
()});
addr_base
+=
size
;
}
}
}
}
recorder
.
set_sum_mem_size
(
addr_base
);
recorder
.
show
();
}
recorder
.
set_sum_mem_size
(
addr_base
);
mgb_assert
(
svg_name
.
length
()
>
4
,
"svg_name must be end with
\"
.svg
\"\n
"
);
mgb_assert
(
svg_name
.
compare
(
svg_name
.
length
()
-
4
,
4
,
".svg"
)
==
0
,
"svg_name must be end with
\"
.svg
\"\n
"
);
recorder
.
show
(
svg_name
);
}
#endif
AsyncExecutable
&
ComputingGraphImpl
::
ComputingSequence
::
wait
()
{
...
...
src/core/impl/graph/cg_impl_seq.h
浏览文件 @
6070f127
...
...
@@ -174,7 +174,10 @@ public:
std
::
unique_ptr
<
RecordedComputingSequence
>
as_recorded_seq
();
#ifndef __IN_TEE_ENV__
void
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
=
"static_mem_record.svg"
)
override
;
const
std
::
string
&
svg_name
=
"static_mem_record.svg"
)
const
override
;
void
do_regist
()
const
;
#endif
};
...
...
src/core/include/megbrain/graph/bases.h
浏览文件 @
6070f127
...
...
@@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable,
return
(
*
(
output_vars_pair
.
first
))
->
get_output_vars
();
}
#ifndef __IN_TEE_ENV__
virtual
void
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
)
{
virtual
void
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
)
const
{
mgb_assert
(
svg_name
.
length
()
<
0
,
"can't call this function directly
\n
"
);
}
...
...
src/plugin/impl/static_mem_record.cpp
浏览文件 @
6070f127
...
...
@@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color,
}
}
// namespace
void
StaticMemRecorder
::
dump_svg
(
std
::
string
svg_name
)
{
void
StaticMemRecorder
::
dump_svg
()
{
float
svg_width
=
SVG_WIDTH
,
svg_height
=
SVG_HEIGHT
,
opr_rect_width
=
OPR_RECT_WIDTH
,
opr_rect_height
=
OPR_RECT_HEIGHT
;
float
address_scale
=
1
;
...
...
@@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
svg_height
=
svg_height
+
opr_rect_height
*
2
;
std
::
ofstream
outfile
;
outfile
.
open
(
svg_name
);
outfile
.
open
(
m_
svg_name
);
outfile
<<
"<?xml version=
\"
1.0
\"
standalone=
\"
no
\"
?>"
<<
std
::
endl
;
outfile
<<
"<!DOCTYPE svg PUBLIC
\"
-//W3C//DTD SVG 1.1//EN/
\"
"
"
\"
http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd
\"
>"
...
...
@@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
outfile
.
close
();
}
void
StaticMemRecorder
::
show
(
std
::
string
svg_name
)
{
void
StaticMemRecorder
::
show
()
{
for
(
auto
&&
i
:
m_memory_chunk_recorder
)
{
if
(
i
.
id
>=
m_weight_chunk_id
)
{
break
;
...
...
@@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) {
m_opr_seq_recorder
.
at
(
chunk
.
time_begin
).
name
.
c_str
());
}
}
dump_svg
(
svg_name
);
dump_svg
();
}
std
::
vector
<
std
::
vector
<
size_t
>>
StaticMemRecorder
::
get_chunk_construct
(
...
...
src/plugin/include/megbrain/plugin/static_mem_record.h
浏览文件 @
6070f127
...
...
@@ -54,25 +54,38 @@ public:
void
regist_peak_mem_size
(
size_t
size
)
{
m_peak_mem_size
=
size
;
}
const
size_t
&
peak_mem_size
()
{
return
m_peak_mem_size
;
}
const
size_t
&
peak_mem_size
()
const
{
return
m_peak_mem_size
;
}
void
set_sum_mem_size
(
size_t
size
)
{
m_sum_mem_size
=
size
;
}
const
size_t
&
sum_mem_size
()
{
return
m_sum_mem_size
;
}
const
size_t
&
sum_mem_size
()
const
{
return
m_sum_mem_size
;
}
const
size_t
&
set_weight_chunk_id
()
{
m_weight_chunk_id
=
m_memory_chunk_recorder
.
size
();
return
m_weight_chunk_id
;
}
const
size_t
&
weight_chunk_id
()
{
return
m_weight_chunk_id
;
}
const
size_t
&
weight_chunk_id
()
const
{
return
m_weight_chunk_id
;
}
void
dump_svg
(
std
::
string
svg_name
);
void
dump_svg
();
void
show
(
std
::
string
svg_name
);
void
show
();
void
set_svg_name
(
const
std
::
string
&
svg_name
)
{
mgb_assert
(
svg_name
.
length
()
>
4
,
"svg_name must be end with
\"
.svg
\"\n
"
);
mgb_assert
(
svg_name
.
compare
(
svg_name
.
length
()
-
4
,
4
,
".svg"
)
==
0
,
"svg_name must be end with
\"
.svg
\"\n
"
);
m_svg_name
=
svg_name
;
}
const
std
::
string
&
get_svg_name
()
const
{
return
m_svg_name
;
}
private:
bool
m_is_record
=
false
;
std
::
string
m_svg_name
;
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are
// weights memory chunks
size_t
m_peak_mem_size
,
m_sum_mem_size
,
m_weight_chunk_id
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录