Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f3378100
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f3378100
编写于
1月 14, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb): enable output dynamic memory alloc
GitOrigin-RevId: c809629034f87dfeacd6586ca96aa9c110e0a3c9
上级
e82fa4ec
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
90 additions
and
4 deletions
+90
-4
src/core/impl/graph/bases.cpp
src/core/impl/graph/bases.cpp
+2
-0
src/core/impl/graph/cg_impl.cpp
src/core/impl/graph/cg_impl.cpp
+18
-1
src/core/impl/graph/cg_impl.h
src/core/impl/graph/cg_impl.h
+1
-0
src/core/include/megbrain/graph/bases.h
src/core/include/megbrain/graph/bases.h
+26
-1
src/core/include/megbrain/graph/cg.h
src/core/include/megbrain/graph/cg.h
+7
-0
src/core/include/megbrain/graph/var_node.h
src/core/include/megbrain/graph/var_node.h
+1
-2
src/core/test/graph/misc.cpp
src/core/test/graph/misc.cpp
+35
-0
未找到文件。
src/core/impl/graph/bases.cpp
浏览文件 @
f3378100
...
...
@@ -14,6 +14,8 @@
// Bring the mgb::cg names into scope for the definitions in this file.
using
namespace
mgb
::
cg
;
// NOTE(review): presumably emits the out-of-line type-info definitions that
// pair with the MGB_TYPEINFO_OBJ_DECL inside OutputVarsUserData (bases.h) --
// confirm against the macro definition.
MGB_TYPEINFO_OBJ_IMPL
(
OutputVarsUserData
);
GraphNodeBase
::
GraphNodeBase
(
ComputingGraph
*
owner_graph
)
:
m_owner_graph
{
owner_graph
}
{
...
...
src/core/impl/graph/cg_impl.cpp
浏览文件 @
f3378100
...
...
@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
std
::
unordered_map
<
CallbackCallerKey
,
CallbackCallerVal
,
CallbackCallerKey
::
Hash
>
opr2vars
;
using
F
=
VarNode
::
Flag
;
if
(
dest_vars
[
0
]
->
owner_graph
()
->
options
().
force_output_dynamic_alloc
)
{
for
(
auto
&&
i
:
dest_vars
)
{
if
(
!
i
->
contain_flag
(
F
::
NO_SYS_MEM_ALLOC
|
F
::
NO_SYS_STATIC_MEM_ALLOC
))
{
mgb_assert
(
!
i
->
contain_flag
(
F
::
DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC
),
"Can not force graph output dynamic alloc with "
"DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s"
,
i
->
cname
());
i
->
add_flag
(
F
::
NO_SYS_STATIC_MEM_ALLOC
);
}
i
->
add_flag
(
F
::
NO_MEM_RECLAIM
);
}
}
for
(
size_t
i
=
0
;
i
<
out_spec
.
size
();
++
i
)
{
auto
&&
cb
=
out_spec
[
i
].
second
;
if
(
cb
)
{
...
...
@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
init_opr_seq
();
#endif // MGB_ENABLE_SUBLINEAR
return
{
std
::
move
(
extra_info
),
opr_seq
};
return
{
std
::
move
(
extra_info
),
opr_seq
,
std
::
move
(
dest_vars
)
};
}
std
::
unique_ptr
<
AsyncExecutable
>
ComputingGraphImpl
::
compile_commit
(
CompileState
state
)
{
auto
comp_seq
=
std
::
make_unique
<
ComputingSequence
>
(
shared_from_this
());
comp_seq
->
extra_info
=
std
::
move
(
state
.
extra_info
);
comp_seq
->
set_output_vars
(
state
.
dest_vars
);
auto
opr_seq
=
state
.
opr_seq
;
auto
&&
cmpnt
=
components
();
...
...
src/core/impl/graph/cg_impl.h
浏览文件 @
f3378100
...
...
@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
//! extra info that must be set in the ComputingSequence
CompSeqExtraInfo
extra_info
;
const
OprNodeArray
*
opr_seq
=
nullptr
;
VarNodeArray
dest_vars
;
};
struct
CallbackCallerKey
{
...
...
src/core/include/megbrain/graph/bases.h
浏览文件 @
f3378100
...
...
@@ -67,9 +67,10 @@ namespace static_infer {
};
// Make mgb::GraphError reachable through this namespace as well.
using
GraphError
=
mgb
::
GraphError
;
// Forward declarations: these graph types are only named (not defined) in
// this header.
class
VarNode
;
class
OperatorNodeBase
;
class
ComputingGraph
;
// Array of VarNode pointers. It only needs the forward declaration of
// VarNode above because it stores pointers.
using
VarNodeArray
=
mgb
::
SmallVector
<
VarNode
*>
;
/*!
* \brief Base class for a node in the graph.
*
...
...
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
}
};
/*!
 * \brief user data object that stores an array of output vars
 *
 * AsyncExecutable::set_output_vars() wraps its argument in an instance of
 * this class and registers it in the executable's UserDataContainer;
 * AsyncExecutable::get_output_vars() reads it back.
 */
class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    //! replace the stored vars with \p vars (taken by value, then moved in)
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }

    //! read-only access to the stored vars
    const VarNodeArray& get_output_vars() const { return m_output_vars; }

private:
    VarNodeArray m_output_vars;
};
/*!
* \brief an object that executes asynchronously
*/
...
...
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
//! mutable access to the user data container attached to this executable
UserDataContainer& user_data() {
    return this->m_user_data;
}
void
set_output_vars
(
const
VarNodeArray
&
vars
)
{
std
::
shared_ptr
<
OutputVarsUserData
>
ud
=
std
::
make_shared
<
OutputVarsUserData
>
();
ud
->
set_output_vars
(
vars
);
m_user_data
.
add_user_data
(
ud
);
}
/*!
 * \brief get the output vars recorded by set_output_vars()
 *
 * \return the vars stored in the first registered OutputVarsUserData, or
 *      an empty array if no such user data has been registered (i.e.
 *      set_output_vars() was never called on this executable). The
 *      previous implementation dereferenced the lookup result
 *      unconditionally, which is undefined behavior in that case.
 */
const VarNodeArray& get_output_vars() const {
    const auto output_vars_pair =
            m_user_data.get_user_data<OutputVarsUserData>();
    // NOTE(review): assumes get_user_data() yields (pointer to first entry,
    // entry count) and reports "not found" via a null pointer and/or a zero
    // count -- both are checked here; confirm against UserDataContainer.
    if (!output_vars_pair.first || !output_vars_pair.second) {
        static const VarNodeArray no_output_vars{};
        return no_output_vars;
    }
    return (*output_vars_pair.first)->get_output_vars();
}
};
...
...
src/core/include/megbrain/graph/cg.h
浏览文件 @
f3378100
...
...
@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! force dynamic memory alloc for all vars
bool
force_dynamic_alloc
=
false
;
/*!
 * Force dynamic memory allocation for the output vars passed to
 * compile(), i.e. the vars that are used as CallbackCaller inputs.
 *
 * When set, compile_prepare() adds NO_MEM_RECLAIM to every output var,
 * and adds NO_SYS_STATIC_MEM_ALLOC to each output var that carries
 * neither NO_SYS_MEM_ALLOC nor NO_SYS_STATIC_MEM_ALLOC. Such a var
 * must not carry DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC, otherwise an
 * assertion fails.
 */
bool
force_output_dynamic_alloc
=
false
;
//! whether to perform var sanity check on first run
bool
var_sanity_check_first_run
=
true
;
...
...
@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
std
::
make_unique
<
Node
>
(
std
::
forward
<
Args
>
(
args
)...))
->
output
(
0
);
}
}
// namespace cg
}
// namespace mgb
...
...
src/core/include/megbrain/graph/var_node.h
浏览文件 @
f3378100
...
...
@@ -34,7 +34,7 @@ namespace static_infer {
class
StaticInferManagerImpl
;
}
class
VarNode
;
class
VarDevMemDefragmenter
;
class
EagerEvalManager
;
...
...
@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
return
static_cast
<
bool
>
(
m_flag
&
flag
);
}
using
VarNodeArray
=
mgb
::
SmallVector
<
VarNode
*>
;
using
VarNodeSet
=
ThinHashSet
<
VarNode
*>
;
DType
MemAllocPlan
::
dtype
()
const
{
...
...
src/core/test/graph/misc.cpp
浏览文件 @
f3378100
...
...
@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
}
}
/*
 * With force_output_dynamic_alloc enabled, outputs compiled without a
 * callback must be readable through AsyncExecutable::get_output_vars() after
 * execution and match the values obtained through copy callbacks.
 */
TEST(TestGraph, DynamicOutput) {
    using namespace opr;
    REQUIRE_GPU(1);
    auto gpu = CompNode::load("gpu0");

    // input of C elements, split along axis 0 into parts of C1 and C2
    constexpr size_t C1 = 20;
    constexpr size_t C2 = 20;
    constexpr size_t C = C1 + C2;

    HostTensorGenerator<> gen;
    auto host_input = gen({C}, gpu);

    auto graph = ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;

    SymbolVar dev_input = opr::Host2DeviceCopy::make(*graph, host_input);
    auto parts = opr::Split::make(
            dev_input, Split::Options::make_partition(dev_input, 0, {C1, C2}));
    auto sum = opr::add(parts[1], parts[1]);

    HostTensorND expect_sum;
    HostTensorND expect_spl_0_0;
    HostTensorND result_sum;
    HostTensorND result_spl_0_0;

    // reference run: fetch both outputs through copy callbacks
    auto reference_func = graph->compile(
            {make_callback_copy(sum, expect_sum),
             make_callback_copy(parts[0], expect_spl_0_0)});
    reference_func->execute().wait();

    // run under test: no callbacks, outputs are read from the output vars
    auto output_func = graph->compile({{sum, nullptr}, {parts[0], nullptr}});
    const auto& outputs = output_func->get_output_vars();
    output_func->execute().wait();

    result_sum.copy_from(outputs[0]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_sum, result_sum, 1e-4);

    result_spl_0_0.copy_from(outputs[1]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录