MegEngine 天元 / MegEngine

Commit 50db9b84

fix(gopt): fix paramfuse if the endpoint is const

GitOrigin-RevId: f666f6d70037debbff34551149d04b0bd8c256f4

Authored by Megvii Engine Team on May 27, 2020
Committed by Xu Xinran on Jun 19, 2020
Parent: 35bc0e1f
Showing 4 changed files with 100 additions and 74 deletions (+100 −74)
src/gopt/impl/framework.cpp                  +2  −3
src/gopt/impl/inference.cpp                  +34 −49
src/gopt/include/megbrain/gopt/framework.h   +17 −21
src/gopt/test/inference.cpp                  +47 −1
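In outline, judging from the hunks below: ParamFusePass fuses chains of constant operators into precomputed tensors. Its ConstVarPropogateWithSizeCheck helper rewrote a const var from inside the virtual on_midconst_opr hook, so an output that was also a graph endpoint could be rewritten before auto_replace_outputs had processed its operator, causing a double replace. The fix deletes the helper and the hook, renames ConstVarPropogateBase to ConstVarPropogate, precomputes const-var information in a first traversal of the graph, and handles const endpoints in a second traversal after auto_replace_outputs; a regression test (ParamFuseConstEndPoint) covers a graph whose endpoint is const.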
src/gopt/impl/framework.cpp

@@ -74,7 +74,7 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
         auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
         mgb_assert(ins.second || ins.first->second.first,
                    "opr output already replaced");
         // handle repeated call on the same opr
         ins.first->second.second = out1[i];
         on_var_replaced(out0[i], out1[i], nullptr);
@@ -771,7 +771,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(

 /* ================ ConstVarPropogateBase ================ */
-ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
+ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
         OperatorNodeBase* opr) {
     using ProfFlag = OperatorNodeBase::NodeProp::Flag;
     auto&& info = m_oprinfo[opr];
@@ -834,7 +834,6 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(

 #endif
         info.max_size = max_input_size;
         info.is_const = true;
-        on_midconst_opr(opr, max_input_size);
     }
     return make_ret();
 }
src/gopt/impl/inference.cpp

@@ -442,50 +442,6 @@ void ParamRedistributePass::apply(OptState &state) const {
 /* ================ ParamFusePass ================ */
-class ParamFusePass::ConstVarPropogateWithSizeCheck final
-        : public ConstVarPropogateBase {
-public:
-    //! rewrite a var; reader == nullptr means needed by endpoint
-    using VarRewriter =
-            std::function<void(VarNode* var, OperatorNodeBase* reader)>;
-
-    ConstVarPropogateWithSizeCheck(const ParamFusePass& pf,
-                                   OptState& opt_state,
-                                   const VarRewriter& rewriter)
-            : ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
-              m_owner{pf},
-              m_opt_state{opt_state},
-              m_rewriter{rewriter} {}
-
-private:
-    const ParamFusePass& m_owner;
-    OptState& m_opt_state;
-    VarRewriter m_rewriter;
-
-    void on_midconst_opr(OperatorNodeBase* opr, size_t max_src_size) override {
-        for (auto var : opr->output()) {
-            if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
-                continue;
-
-            auto osize = var_mem_size(var);
-            if (osize >= max_src_size &&
-                osize - max_src_size > m_owner.m_param_grow_limit) {
-                return;
-            }
-
-            // const oprs should be evaluated when output is used by another
-            // non-const opr or output is needed by the user
-            if (m_opt_state.graph().endpoint_contain(var)) {
-                m_rewriter(var, nullptr);
-            }
-        }
-    }
-};
-
 /*!
  * \brief get name for new param
  */
@@ -565,9 +521,15 @@ const char* ParamFusePass::name() const {

 void ParamFusePass::apply(OptState& state) const {
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
+
+    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+    state.graph().iter([&cvprop](OperatorNodeBase* opr) {
+        cvprop.add_opr(opr);
+    });
+
     ThinHashSet<VarNode*> processed_var;
     VarNamer var_namer;

     // reader: null if used as endvar
     auto replace_single_var = [&](VarNode* var, OperatorNodeBase* reader) {
         if (!processed_var.insert(var).second)
@@ -619,9 +581,8 @@ void ParamFusePass::apply(OptState &state) const {

         rewriter.replace_var(var, new_var.node(), log.c_str());
     };

-    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
-    auto on_opr = [&](OperatorNodeBase* opr) {
-        auto add_ret = cvprop.add_opr(opr);
+    auto replace_opr = [&](OperatorNodeBase* opr) {
+        auto add_ret = cvprop.opr_rst(opr);
         if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
             for (auto i : opr->input()) {
                 if (cvprop.is_midconst(i)) {
@@ -631,9 +592,33 @@ void ParamFusePass::apply(OptState &state) const {

             }
         }
         rewriter.auto_replace_outputs(opr);
+
+        //! we should deal with midconst var after auto_replace_outputs, as
+        //! on_midconst_opr will replace the endpoint output which may cause
+        //! double replace.
+        if (add_ret.all_const_inp) {
+            for (auto var : opr->output()) {
+                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
+                    continue;
+
+                auto osize = ConstVarPropogate::var_mem_size(var);
+                if (osize >= cvprop.max_size(opr) &&
+                    osize - cvprop.max_size(opr) > m_param_grow_limit) {
+                    return;
+                }
+
+                // const oprs should be evaluated when output is used by another
+                // non-const opr or output is needed by the user
+                if (state.graph().endpoint_contain(var)) {
+                    replace_single_var(var, nullptr);
+                }
+            }
+        }
     };

-    state.graph().iter(on_opr);
+    state.graph().iter(replace_opr);
     rewriter.apply_inplace();
 }
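Note how this ties into the framework.cpp hunk above: the removed on_midconst_opr hook replaced an endpoint output before rewriter.auto_replace_outputs(opr) had run for that operator, which, per the new comment, could replace the same var twice — the situation the "opr output already replaced" assertion in SubGraph::Rewriter::auto_replace_outputs guards against. Precomputing const-ness up front and rewriting const endpoints only after auto_replace_outputs closes that window.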
src/gopt/include/megbrain/gopt/framework.h

@@ -490,28 +490,17 @@ namespace gopt {
      * Usually you would want to use ConstVarPropogate, and this base class
      * exists to avoid virtual dtor while allowing polymorphism.
      */
-    class ConstVarPropogateBase {
-    protected:
-        ~ConstVarPropogateBase() = default;
-
-        //! memory usage of a var
-        static size_t var_mem_size(VarNode* var) {
-            return var->dtype().size(var->shape().total_nr_elems());
-        }
-
-        //! called after a const but non-source opr is visited
-        virtual void on_midconst_opr(OperatorNodeBase* opr,
-                                     size_t max_src_size) {
-            MGB_MARK_USED_VAR(opr);
-            MGB_MARK_USED_VAR(max_src_size);
-        }
-
+    class ConstVarPropogate {
     public:
-        explicit ConstVarPropogateBase(ConstVarType const_var_type):
+        explicit ConstVarPropogate(ConstVarType const_var_type):
             m_const_var_type{const_var_type}
         {
         }

+        ConstVarPropogate() = default;
+        ~ConstVarPropogate() = default;
+
         //! note that both attrs would be false if opr is impure or it is
         //! not allowed to be replaced
         struct AddOprResult {
@@ -527,12 +516,19 @@ namespace gopt {

         AddOprResult add_opr(OperatorNodeBase* opr);

+        const AddOprResult& opr_rst(OperatorNodeBase* opr) const {
+            return m_oprinfo.at(opr).result;
+        }
+
         bool is_const(OperatorNodeBase* opr) const {
             return m_oprinfo.at(opr).is_const;
         }

         bool is_const(VarNode* var) const {
             return is_const(var->owner_opr());
         }

+        size_t max_size(OperatorNodeBase* opr) const {
+            return m_oprinfo.at(opr).max_size;
+        }
+
         //! whether a var is produced by non-source const opr
         bool is_midconst(OperatorNodeBase* opr) const {
@@ -543,6 +539,11 @@ namespace gopt {

             return is_midconst(var->owner_opr());
         }

+        //! memory usage of a var
+        static size_t var_mem_size(VarNode* var) {
+            return var->dtype().size(var->shape().total_nr_elems());
+        }
+
     private:
         struct OprInfo {
             bool processed = false, is_const = false;
@@ -556,11 +557,6 @@ namespace gopt {

         };

-    class ConstVarPropogate final : public ConstVarPropogateBase {
-    public:
-        using ConstVarPropogateBase::ConstVarPropogateBase;
-    };
-
 }  // namespace gopt
 }  // namespace mgb
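Taken together, the header change turns the hook-based base class into a plain query cache: call add_opr once per operator, then read back the memoized results. A minimal sketch of the resulting two-phase pattern, mirroring ParamFusePass::apply above (a hypothetical free-standing snippet; `graph` stands in for the OptState sub-graph):

    // Sketch only: `graph` is assumed to expose iter() the way the
    // OptState sub-graph does in ParamFusePass::apply.
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};

    // Phase 1: visit every operator once to populate the per-opr cache.
    graph.iter([&cvprop](OperatorNodeBase* opr) { cvprop.add_opr(opr); });

    // Phase 2: read back memoized results; nothing is rewritten during the
    // scan, so the caller can interleave its own rewrites (e.g. via
    // auto_replace_outputs) without a virtual hook firing underneath it.
    graph.iter([&cvprop](OperatorNodeBase* opr) {
        auto&& rst = cvprop.opr_rst(opr);      // cached AddOprResult
        if (rst.all_const_inp) {
            auto limit = cvprop.max_size(opr); // largest const input, bytes
            MGB_MARK_USED_VAR(limit);          // ... apply a grow-limit policy
        }
    });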
src/gopt/test/inference.cpp

@@ -112,6 +112,52 @@ void warp_perspective_mat_gen(HostTensorND& mat, size_t N, size_t INP_H,
 #endif
 }  // namespace

+TEST(TestGoptInference, ParamFuseConstEndPoint) {
+    constexpr size_t SIZE = 23;
+    HostTensorGenerator<> gen;
+    auto host_x = gen({SIZE}), host_y = gen({1}), host_p = gen({1});
+
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
+         y = opr::SharedDeviceTensor::make(*graph, *host_y),
+         p = opr::Host2DeviceCopy::make(*graph, host_p),
+         q = p + x,
+         a = y + 3,
+         z0 = a + q,
+         z1 = a + 4;
+
+    HostTensorND host_z0, host_z1;
+    SymbolVar z0_1, z1_1;
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{z1, z0}})
+                          .endpoint_vars(),
+                  z1_1, z0_1);
+
+    auto func = graph->compile({make_callback_copy(z0_1, host_z0),
+                                make_callback_copy(z1_1, host_z1)});
+    func->to_json()->writeto_fpath(
+            output_file("TestGoptInference.ParamFuseEndPoint.json"));
+    func->execute();
+
+    int nr_opr = 0;
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {
+        ++nr_opr;
+        return true;
+    });
+    ASSERT_EQ(8, nr_opr);
+
+    auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();
+    auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
+         pz1 = host_z1.ptr<float>()[0];
+    for (size_t i = 0; i < SIZE; ++i) {
+        MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
+    }
+    MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
+}
+
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
@@ -144,7 +190,7 @@ TEST(TestGoptInference, ParamFuse) {

     func->execute();

     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase* op) {
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {
         ++nr_opr;
         return true;
     });
     ASSERT_EQ(6, nr_opr);
     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),
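For reference, the expected values in ParamFuseConstEndPoint follow directly from the graph: z0 = a + q = (y + 3) + (p + x), hence the check against px[i] + yv + 3 + pv; and z1 = a + 4 = y + 7 depends only on constant inputs, so it is exactly the const endpoint the fix targets — the pass must fold it into a single precomputed tensor without replacing the var twice.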