Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7f3f9a94
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7f3f9a94
编写于
12月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/core): add shape hint for graph optimization
GitOrigin-RevId: eaad25a7efe61c388f0e45fa780fdbbb12402ae7
上级
d1fbec4f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
208 addition
and
0 deletion
+208
-0
src/core/impl/graph/cg_impl.cpp
src/core/impl/graph/cg_impl.cpp
+10
-0
src/core/impl/graph/helper.cpp
src/core/impl/graph/helper.cpp
+3
-0
src/core/include/megbrain/graph/helper.h
src/core/include/megbrain/graph/helper.h
+1
-0
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+5
-0
src/gopt/impl/misc.cpp
src/gopt/impl/misc.cpp
+26
-0
src/gopt/include/megbrain/gopt/misc.h
src/gopt/include/megbrain/gopt/misc.h
+6
-0
src/opr/impl/utility.cpp
src/opr/impl/utility.cpp
+53
-0
src/opr/impl/utility.oprdecl
src/opr/impl/utility.oprdecl
+11
-0
src/opr/impl/utility.sereg.h
src/opr/impl/utility.sereg.h
+11
-0
src/opr/include/megbrain/opr/utility.h
src/opr/include/megbrain/opr/utility.h
+21
-0
src/opr/test/utility.cpp
src/opr/test/utility.cpp
+61
-0
未找到文件。
src/core/impl/graph/cg_impl.cpp
浏览文件 @
7f3f9a94
...
...
@@ -514,6 +514,16 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
optimizer
.
add_passes_for_optimize_options
(
options
().
graph_opt
,
true
);
optimizer
.
apply_inplace
(
dest_vars
);
if
(
sopr_stat
.
has_shape_hint
)
{
// FIXME(zhangxuanrun): strictly speaking, it could and has to remove
// ShapeHints even they were occured in subgraph
mgb_assert
(
!
m_parent_graph
,
"can not use ShapeHint in subgraph"
);
// always need remove shape hint
gopt
::
GraphOptimizer
opt
;
opt
.
add_pass
<
gopt
::
RemoveShapeHintPass
>
();
opt
.
apply_inplace
(
dest_vars
);
}
const
OprNodeArray
*
opr_seq
=
nullptr
;
CompSeqExtraInfo
extra_info
;
cmpnt
.
seq_comp_node_opt
.
optimize_comp_nodes
(
dest_vars
);
...
...
src/core/impl/graph/helper.cpp
浏览文件 @
7f3f9a94
...
...
@@ -564,6 +564,9 @@ void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
sopr_stat
->
has_virtual_grad
=
true
;
}
#endif
if
(
sopr_stat
&&
opr
->
same_type
<
opr
::
ShapeHint
>
())
{
sopr_stat
->
has_shape_hint
=
true
;
}
}
}
...
...
src/core/include/megbrain/graph/helper.h
浏览文件 @
7f3f9a94
...
...
@@ -149,6 +149,7 @@ SymbolVar current_grad_target(ComputingGraph &graph);
struct
SpecialOprStat
{
bool
has_virtual_grad
=
false
;
bool
has_shape_hint
=
false
;
};
/*!
...
...
src/gopt/impl/framework.cpp
浏览文件 @
7f3f9a94
...
...
@@ -678,6 +678,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_pass
<
ParamMergePass
>
();
add_pass
<
FuseDeconvCvtPass
>
();
}
if
(
inference_opt
)
{
// remove shape hint after inference optimization
add_pass
<
RemoveShapeHintPass
>
();
}
return
*
this
;
}
...
...
src/gopt/impl/misc.cpp
浏览文件 @
7f3f9a94
...
...
@@ -1055,4 +1055,30 @@ void PackAllReduceReplacePass::insert_packed_oprs(
#endif // MGB_ENABLE_OPR_MM
/* ======================= RemoveShapeHintPass ====================== */
const
char
*
RemoveShapeHintPass
::
name
()
const
{
return
"remove_shape_hint"
;
}
void
RemoveShapeHintPass
::
apply
(
OptState
&
opt
)
const
{
MIDOUT_B
(
"RemoveShapeHintPass::apply"
)
opt
.
set_var_replace_check_flag
(
VarReplaceCheckFlag
::
CHECK_DTYPE
);
auto
rewriter
=
opt
.
graph
().
make_rewriter
();
auto
on_opr
=
[
&
](
OperatorNodeBase
*
opr
)
{
if
(
auto
sh
=
try_cast_as_op
<
opr
::
ShapeHint
>
(
opr
))
{
auto
inp
=
rewriter
.
get_var
(
sh
->
input
(
0
));
rewriter
.
replace_var
(
sh
->
output
(
0
),
inp
,
mgb_cstr_log
(
"remove shape hint"
));
return
;
}
rewriter
.
auto_replace_outputs
(
opr
);
};
opt
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/include/megbrain/gopt/misc.h
浏览文件 @
7f3f9a94
...
...
@@ -141,6 +141,12 @@ namespace gopt {
ThinHashMap
<
VarNode
*
,
VarNode
*>&
replace_map
,
int
priority
);
};
class
RemoveShapeHintPass
final
:
public
Pass
{
public:
const
char
*
name
()
const
override
;
void
apply
(
OptState
&
opt
)
const
override
;
};
}
// namespace gopt
}
// namespace mgb
...
...
src/opr/impl/utility.cpp
浏览文件 @
7f3f9a94
...
...
@@ -840,4 +840,57 @@ SymbolVar RequireInputDynamicStorage::make(const SymbolVar input,
input
.
node
(),
config
);
}
/* ===================== ShapeHint ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ShapeHint
);
void
ShapeHint
::
scn_do_execute
()
{
mgb_assert
(
0
);
}
void
ShapeHint
::
init_output_static_infer_desc
()
{
using
namespace
cg
::
static_infer
;
auto
infer_shp
=
[
this
](
TensorShape
&
dest
,
const
InpVal
&
)
->
bool
{
const
TensorShape
*
inferred
=
nullptr
;
if
(
cg
::
is_static_var_shape
(
input
(
0
)))
{
inferred
=
owner_graph
()
->
static_infer_manager
().
infer_shape_fallible
(
input
(
0
));
}
if
(
inferred
)
{
dest
=
*
inferred
;
if
(
!
dest
.
eq_shape
(
m_shape
))
{
mgb_log_warn
(
"given shape hint on var %s is different from inferred shape, "
"hint %s vs inferred %s"
,
cg
::
dump_var_info
({
input
(
0
)}).
c_str
(),
m_shape
.
to_string
().
c_str
(),
dest
.
to_string
().
c_str
());
}
}
else
{
dest
=
m_shape
;
}
return
dest
.
ndim
;
};
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
output
(
0
),
{
m_is_const
?
SourceType
::
CONSTANT
:
SourceType
::
MUTABLE
,
{},
infer_shp
});
}
ShapeHint
::
ShapeHint
(
VarNode
*
inp
,
TensorShape
shape
,
bool
is_const
,
const
OperatorNodeConfig
&
config
)
:
Super
{
inp
->
owner_graph
(),
config
,
"shape_hint"
,
{
inp
}},
m_shape
(
shape
),
m_is_const
(
is_const
)
{
add_input
({
inp
});
add_output
(
None
);
}
SymbolVar
ShapeHint
::
make
(
SymbolVar
inp
,
TensorShape
shape
,
bool
is_const
,
const
OperatorNodeConfig
&
config
)
{
return
inp
.
insert_single_output_opr
<
ShapeHint
>
(
inp
.
node
(),
shape
,
is_const
,
config
);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
ShapeHint
)
{
// since the shape of output(0) could be inferred, no need to
// give hint on out_grad(0)
return
out_grad
.
at
(
0
);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/utility.oprdecl
浏览文件 @
7f3f9a94
...
...
@@ -90,4 +90,15 @@ decl_opr(
params
=
'Empty'
)
decl_raw_opr
(
'shape_hint'
,
desc
=
'a special op providing shape hint only used in graph compilation'
,
inputs
=
[
Doc
(
'input'
,
'input var the shape hint was on'
),
Doc
(
'shape'
,
'given hint shape'
,
'list of int'
),
Doc
(
'is_const'
,
'whether treat given shape as constant'
,
'bool'
,
'False'
)],
body
=
[
'output = _mgb._Opr.shape_hint(input, shape, is_const, config)'
]
)
# vim: ft=python
src/opr/impl/utility.sereg.h
浏览文件 @
7f3f9a94
...
...
@@ -153,6 +153,17 @@ namespace opr {
#endif
MGB_SEREG_OPR
(
PersistentOutputStorage
,
1
);
cg
::
OperatorNodeBase
*
opr_shallow_copy_shape_hint
(
const
serialization
::
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
auto
&&
opr
=
opr_
.
cast_final_safe
<
ShapeHint
>
();
mgb_assert
(
inputs
.
size
()
==
1
);
return
ShapeHint
::
make
(
inputs
[
0
],
opr
.
shape
(),
opr
.
is_const
(),
config
)
.
node
()
->
owner_opr
();
}
MGB_REG_OPR_SHALLOW_COPY
(
ShapeHint
,
opr_shallow_copy_shape_hint
);
}
// namespace opr
}
// namespace mgb
...
...
src/opr/include/megbrain/opr/utility.h
浏览文件 @
7f3f9a94
...
...
@@ -512,6 +512,27 @@ public:
const
OperatorNodeConfig
&
config
=
{});
}
;
/*
* \brief a special op providing shape hint only used in graph compilation (gopt)
*/
MGB_DEFINE_OPR_CLASS
(
ShapeHint
,
cg
::
SingleCNOperatorNodeBase
)
// {
TensorShape
m_shape
;
bool
m_is_const
;
void
scn_do_execute
()
override
;
void
init_output_static_infer_desc
()
override
;
public
:
ShapeHint
(
VarNode
*
inp
,
const
TensorShape
shape
,
bool
is_const
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
SymbolVar
inp
,
const
TensorShape
shape
,
bool
is_const
=
false
,
const
OperatorNodeConfig
&
config
=
{});
TensorShape
shape
()
const
{
return
m_shape
;
}
bool
is_const
()
const
{
return
m_is_const
;
}
}
;
}
// namespace opr
}
// namespace mgb
...
...
src/opr/test/utility.cpp
浏览文件 @
7f3f9a94
...
...
@@ -12,6 +12,7 @@
#include "megbrain/opr/utility.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/test/helper.h"
using
namespace
mgb
;
...
...
@@ -467,4 +468,64 @@ TEST(TestOprUtility, RequireInputDynamicStorage) {
ASSERT_LT
(
nr_opr
(
func
),
nr0
);
}
TEST
(
TestOprUtility
,
ShapeHint
)
{
HostTensorGenerator
<>
gen
;
HostTensorGenerator
<
dtype
::
Int32
>
gen_int
;
constexpr
size_t
length
=
233
;
{
// basic
for
(
bool
dynamic
:
{
false
,
true
})
{
auto
host_x
=
gen_int
({
length
});
auto
graph
=
ComputingGraph
::
make
();
SymbolVar
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
x_shape_hint
,
y
;
if
(
dynamic
)
{
x_shape_hint
=
opr
::
ShapeHint
::
make
(
opr
::
MarkDynamicVar
::
make
(
x
),
TensorShape
{
length
*
2
});
}
else
{
x_shape_hint
=
opr
::
ShapeHint
::
make
(
x
,
TensorShape
{
length
*
2
});
}
y
=
x_shape_hint
*
2
+
1
;
if
(
dynamic
)
{
ASSERT_TRUE
(
y
.
shape
().
eq_shape
({
length
*
2
}));
}
else
{
ASSERT_TRUE
(
y
.
shape
().
eq_shape
({
length
}));
}
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
();
ASSERT_TRUE
(
host_y
.
shape
().
eq_shape
({
length
}));
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
ASSERT_EQ
((
*
host_x
->
ptr
<
int32_t
>
())
*
2
+
1
,
*
host_y
.
ptr
<
int32_t
>
());
}
}
}
{
// shallow copy
auto
graph
=
ComputingGraph
::
make
();
auto
host_x
=
gen
({
length
});
SymbolVar
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
ShapeHint
::
make
(
x
,
TensorShape
{
length
*
2
}),
x_unknown
=
opr
::
MarkDynamicVar
::
make
(
x
),
y_copy
=
serialization
::
copy_opr_shallow
(
*
y
.
node
()
->
owner_opr
(),
{
x_unknown
.
node
()})
->
output
(
0
);
ASSERT_TRUE
(
y
.
shape
().
eq_shape
({
length
}));
ASSERT_TRUE
(
y_copy
.
shape
().
eq_shape
({
length
*
2
}));
}
{
// grad
auto
host_x
=
gen
({
1
}),
host_y
=
gen
({
1
});
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_y
),
x_shape_hint
=
opr
::
ShapeHint
::
make
(
opr
::
MarkDynamicVar
::
make
(
x
),
TensorShape
{
1
}),
y_shape_hint
=
opr
::
ShapeHint
::
make
(
y
,
TensorShape
{
1
}),
t
=
x_shape_hint
*
y_shape_hint
;
HostTensorND
host_gx
,
host_gy
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
cg
::
grad
(
t
,
x
),
host_gx
),
make_callback_copy
(
cg
::
grad
(
t
,
y
),
host_gy
)
});
func
->
execute
();
ASSERT_TRUE
(
host_gx
.
shape
().
is_scalar
());
ASSERT_TRUE
(
host_gy
.
shape
().
is_scalar
());
ASSERT_FLOAT_EQ
(
*
host_x
->
ptr
<
float
>
(),
*
host_gy
.
ptr
<
float
>
());
ASSERT_FLOAT_EQ
(
*
host_y
->
ptr
<
float
>
(),
*
host_gx
.
ptr
<
float
>
());
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录