Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b06b5899
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b06b5899
编写于
4月 23, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb): get static graph memory info
GitOrigin-RevId: f31745f8df67e6f239aa66f18dd12546081cd3e5
上级
0cf4ff70
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
633 addition
and
0 deletion
+633
-0
imperative/python/megengine/tools/svg_viewer.html
imperative/python/megengine/tools/svg_viewer.html
+154
-0
src/core/impl/graph/cg_impl_seq.cpp
src/core/impl/graph/cg_impl_seq.cpp
+32
-0
src/core/impl/graph/cg_impl_seq.h
src/core/impl/graph/cg_impl_seq.h
+4
-0
src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
+17
-0
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
+1
-0
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
...ore/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
+16
-0
src/core/include/megbrain/graph/bases.h
src/core/include/megbrain/graph/bases.h
+5
-0
src/plugin/impl/static_mem_record.cpp
src/plugin/impl/static_mem_record.cpp
+319
-0
src/plugin/include/megbrain/plugin/static_mem_record.h
src/plugin/include/megbrain/plugin/static_mem_record.h
+85
-0
未找到文件。
imperative/python/megengine/tools/svg_viewer.html
0 → 100644
浏览文件 @
b06b5899
<html>
<title>
Visualizer
</title>
<head>
<meta
name=
"viewport"
content=
"width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no"
/>
</head>
<script>
window
.
onload
=
()
=>
{
var
board
=
document
.
getElementById
(
'
board
'
);
var
fileInput
=
document
.
getElementById
(
'
fileInput
'
);
var
desc
=
document
.
getElementById
(
'
desc
'
);
var
hRange
=
document
.
getElementById
(
'
hRange
'
);
var
vRange
=
document
.
getElementById
(
'
vRange
'
);
var
lastColor
=
undefined
;
var
lastElem
=
undefined
;
var
scale
=
1
;
var
svg
=
undefined
;
var
svgWidth
=
undefined
;
var
svgHeight
=
undefined
;
var
loadDesc
=
(
svgElem
)
=>
{
var
mgeType
=
svgElem
.
attributes
[
'
mge:type
'
];
if
(
mgeType
===
undefined
)
{
return
;
}
var
elemList
=
[];
for
(
attrName
of
svgElem
.
getAttributeNames
())
{
var
prefix
=
'
mge:
'
;
if
(
!
attrName
.
startsWith
(
prefix
))
{
continue
;
}
var
elem
=
'
<p>
'
+
attrName
.
substr
(
prefix
.
length
)
+
'
:
'
+
svgElem
.
attributes
[
attrName
].
value
+
'
</p>
'
elemList
.
push
(
elem
);
}
desc
.
innerHTML
=
elemList
.
join
(
''
);
};
var
selectElem
=
svgElem
=>
{
loadDesc
(
svgElem
);
lastColor
=
svgElem
.
attributes
[
'
fill
'
].
value
;
lastElem
=
svgElem
;
svgElem
.
attributes
[
'
fill
'
].
value
=
'
green
'
;
};
var
unselectLast
=
svgElem
=>
{
if
(
lastElem
)
{
lastElem
.
attributes
[
'
fill
'
].
value
=
lastColor
;
}
lastElem
=
undefined
;
lastColor
=
undefined
;
};
function
recLoadSVG
(
svgElem
)
{
if
(
svgElem
.
children
===
undefined
)
{
return
;
}
svgElem
.
onmousedown
=
e
=>
{
var
mgeType
=
svgElem
.
attributes
[
'
mge:type
'
];
if
(
mgeType
===
undefined
)
{
return
;
}
unselectLast
();
selectElem
(
svgElem
);
e
.
stopPropagation
();
};
for
(
child
of
svgElem
.
children
)
{
recLoadSVG
(
child
);
}
}
function
loadSVG
()
{
var
file
=
fileInput
.
files
[
0
];
var
reader
=
new
FileReader
();
reader
.
readAsText
(
file
,
"
UTF-8
"
);
reader
.
onload
=
e
=>
{
board
.
innerHTML
=
'
<p style="margin: 0;">
'
+
e
.
target
.
result
+
'
</p>
'
;
svg
=
board
.
children
[
0
].
children
[
0
];
svgWidth
=
svg
.
attributes
[
'
width
'
].
value
;
svgHeight
=
svg
.
attributes
[
'
height
'
].
value
;
for
(
child
of
board
.
children
)
{
recLoadSVG
(
child
);
var
svgInfo
=
child
.
attributes
[
'
svg:info
'
];
if
(
svgInfo
!==
undefined
)
{
var
elemList
=
[];
for
(
attrName
of
child
.
getAttributeNames
())
{
var
prefix
=
'
svg:
'
;
if
(
!
attrName
.
startsWith
(
prefix
))
{
continue
;
}
var
elem
=
'
<p>
'
+
attrName
.
substr
(
prefix
.
length
)
+
'
:
'
+
child
.
attributes
[
attrName
].
value
+
'
</p>
'
elemList
.
push
(
elem
);
}
info
.
innerHTML
=
elemList
.
join
(
''
);
}
}
};
}
function
scaleBoard
(
x
,
y
)
{
var
transform
=
'
scale(
'
+
x
+
'
,
'
+
y
+
'
)
'
;
svg
.
setAttribute
(
'
transform
'
,
transform
);
board
.
style
[
'
width
'
]
=
svgWidth
*
x
;
board
.
style
[
'
height
'
]
=
svgHeight
*
y
;
}
function
autoScaleBoard
()
{
var
hRangeValue
=
Math
.
sqrt
(
Number
(
hRange
.
value
)
/
10
);
var
vRangeValue
=
Math
.
sqrt
(
Number
(
vRange
.
value
)
/
10
);
scaleBoard
(
Number
(
hRangeValue
),
Number
(
vRangeValue
));
}
fileInput
.
onchange
=
loadSVG
;
var
zoomBoard
=
dScale
=>
{
scale
*=
dScale
;
scaleBoard
(
scale
,
scale
);
};
window
.
addEventListener
(
'
wheel
'
,
e
=>
{
console
.
log
(
e
);
if
(
e
.
ctrlKey
)
{
e
.
preventDefault
();
e
.
stopPropagation
();
var
factor
=
1
;
if
(
e
.
deltaY
<
0
)
{
factor
=
1.1
;
}
else
if
(
e
.
deltaY
>
0
)
{
factor
=
1
/
1.1
;
}
zoomBoard
(
factor
);
var
newPageX
=
e
.
pageX
*
factor
;
var
newPageY
=
e
.
pageY
*
factor
;
x
=
newPageX
-
e
.
x
;
y
=
newPageY
-
e
.
y
;
window
.
scrollTo
({
top
:
y
,
left
:
x
,
});
console
.
log
(
'
scroll
'
,
[
x
,
y
]);
}
},
{
'
passive
'
:
false
});
};
</script>
<body>
<p
id=
"desc"
style=
"position: fixed;bottom: 0; background-color: white;"
>
desc
</p>
<p
id=
"info"
style=
"position: fixed;top: 0; right: 0; background-color: white;"
>
info
</p>
<p
id=
"board"
style=
"white-space: nowrap; display: flex; justify-content: center; align-content: center; align-items: center; margin: 0;opacity: 0.7;"
>
</p>
<input
type=
'file'
id=
'fileInput'
style=
"position: fixed; top: 0; background-color: white;"
></input>
</body>
</html>
\ No newline at end of file
src/core/impl/graph/cg_impl_seq.cpp
浏览文件 @
b06b5899
...
@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
...
@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
return
*
this
;
return
*
this
;
}
}
void
ComputingGraphImpl
::
ComputingSequence
::
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
)
{
check_not_finalized
();
auto
&
recorder
=
StaticMemRecorder
::
Instance
();
recorder
.
active
();
ExecContext
exec_ctx
{
this
};
// regist weights
size_t
addr_base
=
recorder
.
peak_mem_size
();
size_t
chunk_id
=
recorder
.
set_weight_chunk_id
();
for
(
auto
&&
i
:
*
(
this
->
m_opr_seq
))
{
auto
op
=
i
->
output
();
for
(
auto
&&
j
:
op
)
{
auto
&
mp
=
j
->
mem_plan
();
if
(
mp
.
valid
())
{
auto
&
mc
=
mp
.
chunk
();
if
(
mp
.
valid
()
&&
mc
.
mem_alloc_status
.
is_from_owner_var
())
{
recorder
.
regist_memory_chunk
(
{
chunk_id
++
,
mc
.
size
(),
0
,
this
->
m_opr_seq
->
size
(),
addr_base
,
addr_base
+
mc
.
size
(),
0
,
false
,
mc
.
owner_var
->
name
()});
addr_base
+=
mc
.
size
();
}
}
}
}
recorder
.
set_sum_mem_size
(
addr_base
);
mgb_assert
(
svg_name
.
length
()
>
4
,
"svg_name must be end with
\"
.svg
\"\n
"
);
mgb_assert
(
svg_name
.
compare
(
svg_name
.
length
()
-
4
,
4
,
".svg"
)
==
0
,
"svg_name must be end with
\"
.svg
\"\n
"
);
recorder
.
show
(
svg_name
);
}
AsyncExecutable
&
ComputingGraphImpl
::
ComputingSequence
::
wait
()
{
AsyncExecutable
&
ComputingGraphImpl
::
ComputingSequence
::
wait
()
{
do_wait
(
true
);
do_wait
(
true
);
return
*
this
;
return
*
this
;
...
...
src/core/impl/graph/cg_impl_seq.h
浏览文件 @
b06b5899
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "megbrain/comp_node_env.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/plugin/static_mem_record.h"
namespace
mgb
{
namespace
mgb
{
namespace
cg
{
namespace
cg
{
...
@@ -169,6 +170,9 @@ public:
...
@@ -169,6 +170,9 @@ public:
}
}
std
::
unique_ptr
<
RecordedComputingSequence
>
as_recorded_seq
();
std
::
unique_ptr
<
RecordedComputingSequence
>
as_recorded_seq
();
void
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
=
"static_mem_record.svg"
)
override
;
};
};
class
ComputingGraphImpl
::
MegDNNDtorCheck
:
public
NonCopyableObj
{
class
ComputingGraphImpl
::
MegDNNDtorCheck
:
public
NonCopyableObj
{
...
...
src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
浏览文件 @
b06b5899
...
@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
...
@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
ThinHashMap
<
MemAllocPlan
::
Chunk
*
,
MemChunkLifeInterval
>
chk2interval
;
ThinHashMap
<
MemAllocPlan
::
Chunk
*
,
MemChunkLifeInterval
>
chk2interval
;
// get all memory chunks
// get all memory chunks
if
(
StaticMemRecorder
::
Instance
().
valid
())
{
StaticMemRecorder
::
Instance
().
clear_opr_seq
();
}
for
(
size_t
idx
=
0
;
idx
<
m_cur_seq_full
->
size
();
++
idx
)
{
for
(
size_t
idx
=
0
;
idx
<
m_cur_seq_full
->
size
();
++
idx
)
{
OperatorNodeBase
*
opr
=
m_cur_seq_full
->
at
(
idx
);
OperatorNodeBase
*
opr
=
m_cur_seq_full
->
at
(
idx
);
if
(
StaticMemRecorder
::
Instance
().
valid
())
{
StaticMemRecorder
::
Instance
().
regist_opr_seq
(
{
idx
,
0
,
opr
->
name
()});
}
auto
&&
dep_map
=
opr
->
node_prop
().
dep_map
();
auto
&&
dep_map
=
opr
->
node_prop
().
dep_map
();
if
(
in_sys_alloc
(
opr
))
{
if
(
in_sys_alloc
(
opr
))
{
...
@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
...
@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
chk
.
chunk
->
mem_alloc_status
.
set_static_offset
(
chk
.
chunk
->
mem_alloc_status
.
set_static_offset
(
allocator
->
get_start_addr
(
&
chk
));
allocator
->
get_start_addr
(
&
chk
));
}
}
auto
&
recorder
=
StaticMemRecorder
::
Instance
();
if
(
recorder
.
valid
())
{
for
(
size_t
i
=
0
;
i
<
chunks
.
size
();
i
++
)
{
recorder
.
regist_memory_chunk_owner_var_name
(
i
,
chunks
.
at
(
i
).
chunk
->
owner_var
->
name
());
}
recorder
.
regist_peak_mem_size
(
size
);
}
}
}
return
should_realloc
;
return
should_realloc
;
...
...
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
浏览文件 @
b06b5899
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#pragma once
#pragma once
#include "megbrain/plugin/static_mem_record.h"
#include "megbrain_build_config.h"
#include "megbrain_build_config.h"
#include <cstddef>
#include <cstddef>
...
...
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
浏览文件 @
b06b5899
...
@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
...
@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
check_result_and_calc_lower_bound
();
check_result_and_calc_lower_bound
();
if
(
StaticMemRecorder
::
Instance
().
valid
())
{
StaticMemRecorder
::
Instance
().
clear_memory_chunk
();
for
(
auto
&&
i
:
m_interval
)
{
size_t
overwrite_dest_id
=
0
;
bool
is_overwrite
=
!
i
->
is_overwrite_root
();
if
(
is_overwrite
)
{
overwrite_dest_id
=
i
->
overwrite_dest_root
()
->
id
;
}
StaticMemRecorder
::
Instance
().
regist_memory_chunk
(
{
i
->
id
,
i
->
size_orig
,
i
->
time_begin
,
i
->
time_end
,
i
->
addr_begin
,
i
->
addr_end
(),
overwrite_dest_id
,
is_overwrite
,
""
});
}
}
return
*
this
;
return
*
this
;
}
}
...
...
src/core/include/megbrain/graph/bases.h
浏览文件 @
b06b5899
...
@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable,
...
@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable,
m_user_data
.
get_user_data
<
OutputVarsUserData
>
();
m_user_data
.
get_user_data
<
OutputVarsUserData
>
();
return
(
*
(
output_vars_pair
.
first
))
->
get_output_vars
();
return
(
*
(
output_vars_pair
.
first
))
->
get_output_vars
();
}
}
virtual
void
get_static_memory_alloc_info
(
const
std
::
string
&
svg_name
)
{
mgb_assert
(
svg_name
.
length
()
<
0
,
"can't call this function directly
\n
"
);
}
};
};
...
...
src/plugin/impl/static_mem_record.cpp
0 → 100644
浏览文件 @
b06b5899
/**
* \file src/plugin/impl/static_mem_record.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/plugin/static_mem_record.h"
#include <fstream>
#include <iostream>
using
namespace
mgb
;
using
namespace
cg
;
namespace
{
#define SVG_WIDTH 20000.0
#define SVG_HEIGHT 15000.0
#define OPR_RECT_WIDTH 40.0
#define OPR_RECT_HEIGHT 20.0
const
std
::
string
rect
=
"<rect x=
\"
{}
\"
y=
\"
{}
\"
width=
\"
{}
\"
height=
\"
{}
\"
fill=
\"
{}
\"
"
" {}></rect>"
;
const
std
::
string
text
=
"<text x=
\"
{}
\"
y=
\"
{}
\"
font-size=
\"
{}
\"
>{}</text>"
;
const
std
::
string
polyline
=
"<polyline points=
\"
{}
\"
style=
\"
fill:none;stroke:{};stroke-width:{}
\"
"
"/>"
;
const
std
::
string
opr_info
=
"mge:type=
\"
opr
\"
mge:id=
\"
{}
\"
mge:size=
\"
{}
\"
mge:name=
\"
{}
\"
"
;
const
std
::
string
chunk_info
=
"mge:type=
\"
chunk
\"
mge:id=
\"
{}
\"
mge:time=
\"
{}
\"
mge:addr=
\"
{}
\"
"
"mge:size=
\"
{}
\"
mge:owner_var_name=
\"
{}
\"
"
;
const
std
::
string
animate
=
"<animate attributeName=
\"
opacity
\"
from=
\"
0
\"
to=
\"
1
\"
"
"begin=
\"
{}.mouseover
\"
fill=
\"
freeze
\"
dur=
\"
1s
\"
/>
\n
<animate "
"attributeName=
\"
opacity
\"
from=
\"
1
\"
to=
\"
0
\"
begin=
\"
{}.mouseout
\"
"
"fill=
\"
freeze
\"
dur=
\"
1s
\"
/>"
;
std
::
string
&
replace_by_parameter
(
std
::
string
&
original_str
,
size_t
index
)
{
return
original_str
;
}
template
<
typename
...
Args
>
std
::
string
&
replace_by_parameter
(
std
::
string
&
original_str
,
size_t
index
,
const
std
::
string
&
parameter
,
const
Args
&
...
args
)
{
index
=
original_str
.
find
(
"{}"
,
index
);
original_str
.
replace
(
index
,
2
,
parameter
);
index
+=
parameter
.
length
();
replace_by_parameter
(
original_str
,
index
,
args
...);
return
original_str
;
}
std
::
string
set_opr_info
(
std
::
string
id
,
std
::
string
size
,
std
::
string
name
,
std
::
string
info
=
opr_info
)
{
return
replace_by_parameter
(
info
,
0
,
id
,
size
,
name
);
}
std
::
string
set_chunk_info
(
std
::
string
id
,
std
::
string
time
,
std
::
string
addr
,
std
::
string
size
,
std
::
string
owner_var_name
,
std
::
string
info
=
chunk_info
)
{
return
replace_by_parameter
(
info
,
0
,
id
,
time
,
addr
,
size
,
owner_var_name
);
}
std
::
string
draw_rect
(
std
::
string
x
,
std
::
string
y
,
std
::
string
widith
,
std
::
string
height
,
std
::
string
color
,
std
::
string
info
,
std
::
string
r
=
rect
)
{
return
replace_by_parameter
(
r
,
0
,
x
,
y
,
widith
,
height
,
color
,
info
);
}
std
::
string
draw_text
(
std
::
string
x
,
std
::
string
y
,
std
::
string
font_size
,
std
::
string
txt
,
std
::
string
t
=
text
)
{
return
replace_by_parameter
(
t
,
0
,
x
,
y
,
font_size
,
txt
);
}
std
::
string
draw_polyline
(
std
::
string
point_seq
,
std
::
string
color
,
std
::
string
width
,
std
::
string
p
=
polyline
)
{
return
replace_by_parameter
(
p
,
0
,
point_seq
,
color
,
width
);
}
}
// namespace
void
StaticMemRecorder
::
dump_svg
(
std
::
string
svg_name
)
{
float
svg_width
=
SVG_WIDTH
,
svg_height
=
SVG_HEIGHT
,
opr_rect_width
=
OPR_RECT_WIDTH
,
opr_rect_height
=
OPR_RECT_HEIGHT
;
float
address_scale
=
1
;
size_t
opr_nr
=
m_opr_seq_recorder
.
size
();
if
(
opr_nr
*
OPR_RECT_WIDTH
>
SVG_WIDTH
)
{
svg_width
=
SVG_WIDTH
;
opr_rect_width
=
svg_width
/
opr_nr
;
opr_rect_height
=
opr_rect_width
/
2
;
}
else
{
opr_rect_width
=
OPR_RECT_WIDTH
;
svg_width
=
opr_nr
*
opr_rect_width
;
}
if
(
m_sum_mem_size
>
SVG_HEIGHT
)
{
svg_height
=
SVG_HEIGHT
;
address_scale
=
svg_height
/
m_sum_mem_size
;
}
else
{
svg_height
=
m_sum_mem_size
;
}
// Rescale
float
aspect_ratio
=
SVG_WIDTH
/
SVG_HEIGHT
;
if
(
svg_width
/
svg_height
<
1
)
{
svg_width
=
svg_height
*
aspect_ratio
;
opr_rect_width
=
svg_width
/
opr_nr
;
opr_rect_height
=
opr_rect_width
/
2
;
}
else
if
(
svg_width
/
svg_height
>
aspect_ratio
)
{
svg_height
=
svg_width
/
aspect_ratio
;
address_scale
=
svg_height
/
m_sum_mem_size
;
}
svg_height
=
svg_height
+
opr_rect_height
*
2
;
std
::
ofstream
outfile
;
outfile
.
open
(
svg_name
);
outfile
<<
"<?xml version=
\"
1.0
\"
standalone=
\"
no
\"
?>"
<<
std
::
endl
;
outfile
<<
"<!DOCTYPE svg PUBLIC
\"
-//W3C//DTD SVG 1.1//EN/
\"
"
"
\"
http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd
\"
>"
<<
std
::
endl
;
outfile
<<
"<svg width=
\"
"
+
std
::
to_string
(
svg_width
)
+
"
\"
height=
\"
"
+
std
::
to_string
(
svg_height
)
+
"
\"
version=
\"
1.1
\"
"
"xmlns=
\"
http://www.w3.org/2000/svg
\"
>"
<<
std
::
endl
;
float
base_height
=
svg_height
-
opr_rect_height
;
std
::
string
peak_mem_polyline
=
"0,"
+
std
::
to_string
(
base_height
-
m_peak_mem_size
*
address_scale
)
+
" "
+
std
::
to_string
(
m_opr_seq_recorder
.
size
()
*
opr_rect_width
)
+
","
+
std
::
to_string
(
base_height
-
m_peak_mem_size
*
address_scale
);
std
::
string
sum_mem_polyline
=
"0,"
+
std
::
to_string
(
base_height
-
m_sum_mem_size
*
address_scale
)
+
" "
+
std
::
to_string
(
m_opr_seq_recorder
.
size
()
*
opr_rect_width
)
+
","
+
std
::
to_string
(
base_height
-
m_sum_mem_size
*
address_scale
);
std
::
string
memory_polyline
=
""
;
for
(
size_t
i
=
0
;
i
<
m_opr_seq_recorder
.
size
();
i
++
)
{
auto
&&
opr
=
m_opr_seq_recorder
.
at
(
i
);
memory_polyline
+=
std
::
to_string
((
i
+
0.5
)
*
opr_rect_width
)
+
","
+
std
::
to_string
(
base_height
-
opr
.
size
*
address_scale
)
+
" "
;
outfile
<<
draw_text
(
std
::
to_string
(
i
*
opr_rect_width
),
std
::
to_string
(
svg_height
-
opr_rect_height
*
0.5
),
std
::
to_string
(
opr_rect_height
*
0.5
),
"opr"
+
std
::
to_string
(
i
))
<<
std
::
endl
;
std
::
string
opr_info
=
set_opr_info
(
std
::
to_string
(
opr
.
id
),
std
::
to_string
(
opr
.
size
)
+
"B("
+
std
::
to_string
(
opr
.
size
/
1024.0
/
1024.0
)
+
"MiB)"
,
opr
.
name
)
+
" opacity=
\"
0
\"
"
;
outfile
<<
draw_rect
(
std
::
to_string
(
i
*
opr_rect_width
),
std
::
to_string
(
base_height
),
std
::
to_string
(
opr_rect_width
),
std
::
to_string
(
opr_rect_height
),
"white"
,
opr_info
)
<<
std
::
endl
;
}
for
(
size_t
i
=
0
;
i
<
m_memory_chunk_recorder
.
size
();
i
++
)
{
auto
&&
chunk
=
m_memory_chunk_recorder
.
at
(
i
);
std
::
string
chunk_info
=
set_chunk_info
(
std
::
to_string
(
chunk
.
id
),
"["
+
std
::
to_string
(
chunk
.
time_begin
)
+
","
+
std
::
to_string
(
chunk
.
time_end
)
+
")"
,
"["
+
std
::
to_string
(
chunk
.
addr_begin
)
+
","
+
std
::
to_string
(
chunk
.
addr_end
)
+
")"
,
std
::
to_string
(
chunk
.
addr_end
-
chunk
.
addr_begin
)
+
"B("
+
std
::
to_string
((
chunk
.
addr_end
-
chunk
.
addr_begin
)
/
1024.0
/
1024.0
)
+
"MiB)"
,
chunk
.
owner_var_name
);
outfile
<<
draw_rect
(
std
::
to_string
(
chunk
.
time_begin
*
opr_rect_width
),
std
::
to_string
(
base_height
-
chunk
.
addr_end
*
address_scale
),
std
::
to_string
((
chunk
.
time_end
-
chunk
.
time_begin
)
*
opr_rect_width
),
std
::
to_string
((
chunk
.
addr_end
-
chunk
.
addr_begin
)
*
address_scale
),
"gray"
,
chunk_info
)
<<
std
::
endl
;
outfile
<<
draw_text
(
std
::
to_string
(
chunk
.
time_begin
*
opr_rect_width
),
std
::
to_string
(
base_height
-
chunk
.
addr_end
*
address_scale
+
9
),
std
::
to_string
(
9
),
"chunk"
+
std
::
to_string
(
chunk
.
id
))
<<
std
::
endl
;
}
outfile
<<
draw_text
(
"0"
,
std
::
to_string
(
base_height
-
m_peak_mem_size
*
address_scale
+
opr_rect_height
*
0.5
),
std
::
to_string
(
opr_rect_height
*
0.5
),
"peak_memory_size:"
+
std
::
to_string
(
m_peak_mem_size
)
+
"B("
+
std
::
to_string
(
m_peak_mem_size
/
1024.0
/
1024.0
)
+
"MiB)"
)
<<
std
::
endl
;
outfile
<<
draw_text
(
"0"
,
std
::
to_string
(
base_height
-
m_sum_mem_size
*
address_scale
+
opr_rect_height
*
0.5
),
std
::
to_string
(
opr_rect_height
*
0.5
),
"sum_memory_size:"
+
std
::
to_string
(
m_sum_mem_size
)
+
"B("
+
std
::
to_string
(
m_sum_mem_size
/
1024.0
/
1024.0
)
+
"MiB)"
)
<<
std
::
endl
;
outfile
<<
draw_polyline
(
memory_polyline
,
"blue"
,
std
::
to_string
(
opr_rect_height
*
0.1
))
<<
std
::
endl
;
outfile
<<
draw_polyline
(
peak_mem_polyline
,
"green"
,
std
::
to_string
(
opr_rect_height
*
0.1
))
<<
std
::
endl
;
outfile
<<
draw_polyline
(
sum_mem_polyline
,
"red"
,
std
::
to_string
(
opr_rect_height
*
0.1
))
<<
std
::
endl
;
outfile
<<
"<text svg:info=
\"
The abscissa represents the opr sequence, the "
"ordinate represents the logical address.
\"
"
"svg:chunk_time=
\"
[opra,oprb) means the chunk is created when "
"opra execute and is freed before oprb
\"
"
"svg:chunk_oner_var_name=
\"
var that first creates this "
"chunk
\"
></text>"
<<
std
::
endl
;
outfile
<<
"</svg>"
<<
std
::
endl
;
outfile
.
close
();
}
void
StaticMemRecorder
::
show
(
std
::
string
svg_name
)
{
for
(
auto
&&
i
:
m_memory_chunk_recorder
)
{
if
(
i
.
id
>=
m_weight_chunk_id
)
{
break
;
}
size_t
begin
=
i
.
time_begin
,
end
=
i
.
time_end
;
if
(
i
.
is_overwrite
)
{
begin
++
;
}
for
(
size_t
j
=
begin
;
j
<
end
;
j
++
)
{
m_opr_seq_recorder
.
at
(
j
).
size
+=
i
.
size_orig
;
}
}
// log peak memory size, where it is reached and which chunks constitute it.
mgb_log
(
"peak_mem_size = %zu
\n
"
,
m_peak_mem_size
);
size_t
max_size
=
0
;
std
::
vector
<
size_t
>
opr_ids
;
for
(
auto
&&
i
:
m_opr_seq_recorder
)
{
if
(
i
.
size
==
max_size
)
{
opr_ids
.
push_back
(
i
.
id
);
}
else
if
(
i
.
size
>
max_size
)
{
max_size
=
i
.
size
;
opr_ids
.
clear
();
opr_ids
.
push_back
(
i
.
id
);
}
}
auto
opr2chunk
=
get_chunk_construct
(
opr_ids
);
mgb_log
(
"oprs reach the peak memory:
\n
"
);
for
(
auto
&&
i
:
opr_ids
)
{
mgb_log
(
"opr id = %zu
\n
"
,
i
);
}
mgb_log
(
"More details:
\n
"
);
for
(
size_t
i
=
0
;
i
<
opr2chunk
.
size
();
i
++
)
{
mgb_log
(
"opr id = %zu
\n
"
,
opr_ids
.
at
(
i
));
if
(
i
+
1
<
opr2chunk
.
size
()
&&
opr2chunk
.
at
(
i
)
==
opr2chunk
.
at
(
i
+
1
))
{
continue
;
}
for
(
size_t
j
=
0
;
j
<
opr2chunk
.
at
(
i
).
size
();
j
++
)
{
auto
&&
chunk
=
m_memory_chunk_recorder
.
at
(
opr2chunk
.
at
(
i
).
at
(
j
));
mgb_log
(
"[memory_chunk_id=%zu, size=%zu B, "
"[life_begin=%zu,life_end=%zu), owner_opr_name=%s]
\n
"
,
chunk
.
id
,
chunk
.
size_orig
,
chunk
.
time_begin
,
chunk
.
time_end
,
m_opr_seq_recorder
.
at
(
chunk
.
time_begin
).
name
.
c_str
());
}
}
dump_svg
(
svg_name
);
}
std
::
vector
<
std
::
vector
<
size_t
>>
StaticMemRecorder
::
get_chunk_construct
(
std
::
vector
<
size_t
>
opr_ids
)
{
std
::
vector
<
std
::
vector
<
size_t
>>
chunk_ids
;
chunk_ids
.
resize
(
opr_ids
.
size
());
for
(
auto
&&
i
:
m_memory_chunk_recorder
)
{
if
(
i
.
id
>=
m_weight_chunk_id
)
{
break
;
}
size_t
begin
=
i
.
time_begin
,
end
=
i
.
time_end
;
if
(
i
.
is_overwrite
)
{
begin
=
begin
+
1
;
}
if
(
opr_ids
.
front
()
>=
end
||
opr_ids
.
back
()
<
begin
)
{
continue
;
}
for
(
size_t
k
=
0
;
k
<
opr_ids
.
size
();
k
++
)
{
if
(
opr_ids
.
at
(
k
)
>=
end
)
{
break
;
}
else
if
(
opr_ids
.
at
(
k
)
>=
begin
)
{
chunk_ids
.
at
(
k
).
push_back
(
i
.
id
);
}
}
}
return
chunk_ids
;
}
\ No newline at end of file
src/plugin/include/megbrain/plugin/static_mem_record.h
0 → 100644
浏览文件 @
b06b5899
/**
* \file src/plugin/include/megbrain/plugin/static_mem_record.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/utils/metahelper.h"
namespace
mgb
{
namespace
cg
{
class
StaticMemRecorder
:
public
NonCopyableObj
{
public:
static
StaticMemRecorder
&
Instance
()
{
static
StaticMemRecorder
StaticMemRecorder
;
return
StaticMemRecorder
;
}
struct
opr_record
{
size_t
id
,
size
;
std
::
string
name
;
};
struct
memory_chunk_record
{
size_t
id
,
size_orig
,
time_begin
,
time_end
,
addr_begin
,
addr_end
,
overwrite_dest_id
;
bool
is_overwrite
;
std
::
string
owner_var_name
;
};
void
active
()
{
m_is_record
=
true
;
}
bool
valid
()
{
return
m_is_record
;
}
void
clear_opr_seq
()
{
m_opr_seq_recorder
.
clear
();
}
void
regist_opr_seq
(
opr_record
opr
)
{
m_opr_seq_recorder
.
push_back
(
opr
);
}
void
clear_memory_chunk
()
{
m_memory_chunk_recorder
.
clear
();
}
void
regist_memory_chunk
(
memory_chunk_record
mcr
)
{
m_memory_chunk_recorder
.
push_back
(
mcr
);
}
void
regist_memory_chunk_owner_var_name
(
size_t
id
,
std
::
string
name
)
{
m_memory_chunk_recorder
.
at
(
id
).
owner_var_name
=
name
;
}
void
regist_peak_mem_size
(
size_t
size
)
{
m_peak_mem_size
=
size
;
}
const
size_t
&
peak_mem_size
()
{
return
m_peak_mem_size
;
}
void
set_sum_mem_size
(
size_t
size
)
{
m_sum_mem_size
=
size
;
}
const
size_t
&
sum_mem_size
()
{
return
m_sum_mem_size
;
}
const
size_t
&
set_weight_chunk_id
()
{
m_weight_chunk_id
=
m_memory_chunk_recorder
.
size
();
return
m_weight_chunk_id
;
}
const
size_t
&
weight_chunk_id
()
{
return
m_weight_chunk_id
;
}
void
dump_svg
(
std
::
string
svg_name
);
void
show
(
std
::
string
svg_name
);
private:
bool
m_is_record
=
false
;
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are
// weights memory chunks
size_t
m_peak_mem_size
,
m_sum_mem_size
,
m_weight_chunk_id
;
std
::
vector
<
opr_record
>
m_opr_seq_recorder
;
std
::
vector
<
memory_chunk_record
>
m_memory_chunk_recorder
;
std
::
vector
<
std
::
vector
<
size_t
>>
get_chunk_construct
(
std
::
vector
<
size_t
>
opr_ids
);
};
}
// namespace cg
}
// namespace mgb
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录