Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e8a5932d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e8a5932d
编写于
7月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mgb/gopt): optimize impl of reformat builders
GitOrigin-RevId: 844b7e8d393290a6235e70d8455ea2d6d0e124cd
上级
58b8b145
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
203 addition
and
33 deletion
+203
-33
src/gopt/impl/reformat_emitter.cpp
src/gopt/impl/reformat_emitter.cpp
+66
-22
src/gopt/include/megbrain/gopt/reformat_emitter.h
src/gopt/include/megbrain/gopt/reformat_emitter.h
+14
-3
src/gopt/test/reformat_emitter.cpp
src/gopt/test/reformat_emitter.cpp
+123
-8
未找到文件。
src/gopt/impl/reformat_emitter.cpp
浏览文件 @
e8a5932d
...
...
@@ -19,6 +19,7 @@ using namespace gopt;
using
Dimension
=
megdnn
::
Dimension
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
// =================== ModifyShapeMixin ====================*/
ModifyShapeMixin
::
Pattern
ModifyShapeMixin
::
mixin_analyze
()
const
{
static
constexpr
uint32_t
UNDETERMINED_EXTENT
=
Dimension
::
UNDETERMINED_EXTENT
;
...
...
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
ModifyShapeMixin
::
Checker
ModifyShapeMixin
::
mixin_emit_checker
(
const
Pattern
&
pattern
)
const
{
auto
src
=
m_src
;
auto
checker
=
[
src
,
pattern
](
VarNode
*
var
)
{
auto
checker
=
[
src
,
pattern
](
const
VarNodeArray
&
input
)
{
mgb_assert
(
input
.
size
()
>=
1
);
const
auto
&
var
=
input
.
front
();
const
auto
&
shp
=
var
->
shape
();
if
(
shp
.
ndim
!=
src
.
ndim
)
return
false
;
...
...
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
return
checker
;
}
ReshapeEmitter
::
EmitResult
ReshapeEmitter
::
emit
()
const
{
// =================== MakeShapeEmitter ====================*/
MakeShapeEmitter
::
EmitResult
MakeShapeEmitter
::
emit
()
const
{
auto
pattern
=
mixin_analyze
();
auto
builder
=
[
pattern
](
VarNode
*
var
)
{
auto
sym_var
=
SymbolVar
(
var
);
auto
builder
=
[
pattern
](
const
VarNodeArray
&
input
)
{
mgb_assert
(
input
.
size
()
==
1
,
"number of input of MakeShapeBuilder should be 1(got:%zu)"
,
input
.
size
());
auto
sym_var
=
SymbolVar
(
input
.
front
());
auto
shp
=
opr
::
GetVarShape
::
make
(
sym_var
);
auto
cv
=
[
&
sym_var
](
int
c
)
{
return
sym_var
.
make_scalar
(
c
);
};
auto
sub
=
[
&
shp
,
&
cv
](
int
ax
)
{
...
...
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
}
}
auto
tshp
=
opr
::
Concat
::
make
(
axs
,
0
);
auto
ovar
=
opr
::
Reshape
::
make
(
sym_var
,
tshp
);
return
tshp
.
node
();
};
auto
checker
=
mixin_emit_checker
(
pattern
);
return
std
::
make_tuple
(
builder
,
checker
);
}
// =================== ReshapeEmitter ====================*/
ReshapeEmitter
::
EmitResult
ReshapeEmitter
::
emit
()
const
{
auto
pattern
=
mixin_analyze
();
auto
builder
=
[
pattern
](
const
VarNodeArray
&
input
)
{
mgb_assert
(
input
.
size
()
==
2
,
"number of input of Reshape should be 2(got:%zu)"
,
input
.
size
());
auto
ovar
=
opr
::
Reshape
::
make
(
input
[
0
],
input
[
1
]);
return
ovar
.
node
();
};
auto
checker
=
mixin_emit_checker
(
pattern
);
return
std
::
make_tuple
(
builder
,
checker
);
}
// =================== DimshuffleEmitter ====================*/
DimshuffleEmitter
::
EmitResult
DimshuffleEmitter
::
emit
()
const
{
auto
&&
pattern
=
m_pattern
;
auto
builder
=
[
pattern
](
VarNode
*
var
)
{
auto
sym_var
=
SymbolVar
(
var
);
auto
builder
=
[
pattern
](
const
VarNodeArray
&
input
)
{
mgb_assert
(
input
.
size
()
==
1
,
"number of input of Dimshuffle should be 1(got:%zu)"
,
input
.
size
());
auto
sym_var
=
SymbolVar
(
input
.
front
());
return
opr
::
Dimshuffle
::
make
(
sym_var
,
pattern
).
node
();
};
auto
checker
=
[
pattern
](
VarNode
*
var
)
{
return
var
->
shape
().
ndim
==
pattern
.
size
();
auto
checker
=
[
pattern
](
const
VarNodeArray
&
input
)
{
mgb_assert
(
input
.
size
()
==
1
,
"number of input of Dimshuffle should be 1(got:%zu)"
,
input
.
size
());
return
input
.
front
()
->
shape
().
ndim
==
pattern
.
size
();
};
return
std
::
make_tuple
(
builder
,
checker
);
}
// =================== ReformatEmitter ====================*/
ReformatEmitter
::
EmitResult
ReformatEmitter
::
emit
()
const
{
auto
ops
=
analyze
();
auto
builder
=
[
ops
](
VarNode
*
var
)
{
VarNode
*
ovar
=
var
;
for
(
const
auto
&
i
:
ops
)
{
ovar
=
i
(
ovar
);
auto
builders
=
analyze
();
auto
builder
=
[
builders
](
const
VarNodeArray
&
input
)
{
VarNode
*
var
,
*
ovar
;
var
=
ovar
=
input
.
front
();
if
(
builders
.
make_shape1
)
{
auto
shp1
=
builders
.
make_shape1
({
var
});
ovar
=
builders
.
reshape1
({
ovar
,
shp1
});
}
ovar
=
builders
.
dimshuffle
({
ovar
});
if
(
builders
.
make_shape2
)
{
auto
shp2
=
builders
.
make_shape2
({
var
});
ovar
=
builders
.
reshape2
({
ovar
,
shp2
});
}
return
ovar
;
};
...
...
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
return
std
::
make_tuple
(
builder
,
checker
);
}
SmallVector
<
ReformatEmitter
::
Builder
>
ReformatEmitter
::
analyze
()
const
{
ReformatEmitter
::
UnderlyingBuilders
ReformatEmitter
::
analyze
()
const
{
struct
Dim
{
Dimension
dim
;
int
index
;
...
...
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
i1
[
i
]
=
src_dims
[
src_perm
[
i
]].
dim
;
i2
[
i
]
=
src_dims
[
src_perm
[
permute
[
i
]]].
dim
;
}
SmallVector
<
Builder
>
ops
;
if
(
!
m_src
.
eq_shape
(
i1
))
ops
.
emplace_back
(
std
::
get
<
0
>
(
ReshapeEmitter
(
m_src
,
i1
).
emit
()));
ops
.
emplace_back
(
std
::
get
<
0
>
(
DimshuffleEmitter
(
permute
).
emit
()));
if
(
!
m_dest
.
eq_shape
(
i2
))
ops
.
emplace_back
(
std
::
get
<
0
>
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
()));
return
ops
;
UnderlyingBuilders
builders
;
if
(
!
m_src
.
eq_shape
(
i1
))
{
builders
.
make_shape1
=
std
::
move
(
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
i1
).
emit
()));
builders
.
reshape1
=
std
::
move
(
std
::
get
<
0
>
(
ReshapeEmitter
(
m_src
,
i1
).
emit
()));
}
builders
.
dimshuffle
=
std
::
move
(
std
::
get
<
0
>
(
DimshuffleEmitter
(
permute
).
emit
()));
if
(
!
m_dest
.
eq_shape
(
i2
))
{
builders
.
make_shape2
=
std
::
move
(
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
m_dest
).
emit
()));
builders
.
reshape2
=
std
::
move
(
std
::
get
<
0
>
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
()));
}
return
builders
;
}
// vim: syntax=cpp.doxygen
src/gopt/include/megbrain/gopt/reformat_emitter.h
浏览文件 @
e8a5932d
...
...
@@ -20,8 +20,8 @@ namespace gopt {
class
Emitter
{
public:
using
Builder
=
thin_function
<
VarNode
*
(
VarNode
*
)
>
;
using
Checker
=
thin_function
<
bool
(
VarNode
*
)
>
;
using
Builder
=
thin_function
<
VarNode
*
(
const
VarNodeArray
&
)
>
;
using
Checker
=
thin_function
<
bool
(
const
VarNodeArray
&
)
>
;
using
EmitResult
=
std
::
tuple
<
Builder
,
Checker
>
;
virtual
~
Emitter
()
=
default
;
virtual
EmitResult
emit
()
const
=
0
;
...
...
@@ -39,6 +39,14 @@ protected:
megdnn
::
NamedTensorShape
m_src
,
m_dest
;
};
class
MakeShapeEmitter
final
:
public
Emitter
,
ModifyShapeMixin
{
public:
MakeShapeEmitter
(
const
megdnn
::
NamedTensorShape
&
src
,
const
megdnn
::
NamedTensorShape
&
dest
)
:
ModifyShapeMixin
(
src
,
dest
)
{}
EmitResult
emit
()
const
override
;
};
class
ReshapeEmitter
final
:
public
Emitter
,
ModifyShapeMixin
{
public:
ReshapeEmitter
(
const
megdnn
::
NamedTensorShape
&
src
,
...
...
@@ -64,7 +72,10 @@ public:
EmitResult
emit
()
const
override
;
private:
SmallVector
<
Builder
>
analyze
()
const
;
struct
UnderlyingBuilders
{
Builder
make_shape1
,
make_shape2
,
reshape1
,
reshape2
,
dimshuffle
;
};
UnderlyingBuilders
analyze
()
const
;
};
}
// namespace gopt
}
// namespace mgb
...
...
src/gopt/test/reformat_emitter.cpp
浏览文件 @
e8a5932d
...
...
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
constexpr
size_t
N
=
12
,
C
=
64
,
H
=
7
,
W
=
7
;
HostTensorGenerator
<>
gen
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
auto
dest
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW4
);
auto
src
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW32
);
auto
dest
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW4
);
auto
&&
tuple
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
reformat
=
std
::
get
<
0
>
(
tuple
);
auto
checker
=
std
::
get
<
1
>
(
tuple
);
...
...
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
/
32
,
H
,
W
,
32
});
EXPECT_TRUE
(
checker
(
x
.
node
()
));
EXPECT_TRUE
(
checker
(
{
x
.
node
()}
));
auto
x_
=
mkvar
(
"x"
,
{
N
,
H
,
W
,
C
});
EXPECT_FALSE
(
checker
(
x_
.
node
()));
auto
y1
=
SymbolVar
(
reformat
(
x
.
node
()));
EXPECT_FALSE
(
checker
({
x_
.
node
()}));
auto
y1
=
SymbolVar
(
reformat
({
x
.
node
()}));
size_t
nr_shapeof
=
0
;
size_t
nr_reshape
=
0
;
cg
::
DepOprIter
{[
&
nr_shapeof
,
&
nr_reshape
](
cg
::
OperatorNodeBase
*
o
)
{
if
(
o
->
same_type
<
opr
::
GetVarShape
>
())
nr_shapeof
++
;
if
(
o
->
same_type
<
opr
::
Reshape
>
())
nr_reshape
++
;
}}
.
add
(
y1
.
node
()
->
owner_opr
());
ASSERT_EQ
(
nr_shapeof
,
1
);
ASSERT_EQ
(
nr_reshape
,
2
);
auto
y2
=
SymbolVar
(
nchw32_to_nchw4
(
x
.
node
()));
HostTensorND
t1
,
t2
;
auto
func1
=
graph
->
compile
({
make_callback_copy
(
y1
,
t1
)});
...
...
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
/
64
,
H
,
W
,
64
});
EXPECT_TRUE
(
checker
(
x
.
node
()
));
EXPECT_TRUE
(
checker
(
{
x
.
node
()}
));
auto
x_
=
mkvar
(
"x"
,
{
N
,
H
,
W
,
C
});
EXPECT_FALSE
(
checker
(
x_
.
node
()
));
auto
y
=
SymbolVar
(
reformat
(
x
.
node
()
));
EXPECT_FALSE
(
checker
(
{
x_
.
node
()}
));
auto
y
=
SymbolVar
(
reformat
(
{
x
.
node
()}
));
HostTensorND
t
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
t
)});
func
->
execute
();
}
TEST
(
TestReformatEmitter
,
EliminateRedudantReshape
)
{
constexpr
size_t
N
=
16
,
C
=
64
,
H
=
7
,
W
=
7
;
HostTensorGenerator
<>
gen
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
auto
src
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW
);
auto
dest
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NHWC
);
auto
&&
tuple
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
reformat
=
std
::
get
<
0
>
(
tuple
);
auto
checker
=
std
::
get
<
1
>
(
tuple
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
nchw_to_nhwc
=
[](
VarNode
*
in
)
{
auto
x
=
SymbolVar
(
in
);
auto
y
=
opr
::
Dimshuffle
::
make
(
x
,
{
0
,
2
,
3
,
1
});
return
y
.
node
();
};
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
,
H
,
W
});
EXPECT_TRUE
(
checker
({
x
.
node
()}));
auto
y1
=
SymbolVar
(
reformat
({
x
.
node
()}));
size_t
nr_reshape
=
0
;
cg
::
DepOprIter
{[
&
nr_reshape
](
cg
::
OperatorNodeBase
*
o
)
{
if
(
o
->
same_type
<
opr
::
Reshape
>
())
nr_reshape
++
;
}}
.
add
(
y1
.
node
()
->
owner_opr
());
ASSERT_EQ
(
nr_reshape
,
0
);
HostTensorND
t1
,
t2
;
auto
func1
=
graph
->
compile
({
make_callback_copy
(
y1
,
t1
)});
func1
->
execute
();
auto
y2
=
SymbolVar
(
nchw_to_nhwc
(
x
.
node
()));
auto
func2
=
graph
->
compile
({
make_callback_copy
(
y2
,
t2
)});
func2
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
t1
,
t2
);
}
TEST
(
TestReformatEmitter
,
Nchw4ToNchw
)
{
constexpr
size_t
N
=
12
,
C
=
64
,
H
=
7
,
W
=
7
;
HostTensorGenerator
<>
gen
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
auto
src
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW4
);
auto
dest
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW
);
auto
&&
tuple
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
reformat
=
std
::
get
<
0
>
(
tuple
);
auto
checker
=
std
::
get
<
1
>
(
tuple
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
nchw4_to_nchw
=
[](
VarNode
*
in
)
{
auto
x
=
SymbolVar
(
in
);
auto
xshp
=
opr
::
GetVarShape
::
make
(
x
);
auto
cv
=
[
&
x
](
int
v
)
{
return
x
.
make_scalar
(
v
);
};
auto
sub
=
[
&
xshp
,
&
cv
](
int
idx
)
{
return
opr
::
IndexAt
::
make
(
xshp
,
{{
0
,
cv
(
idx
)}});
};
auto
tshp
=
opr
::
Concat
::
make
({
sub
(
0
),
sub
(
1
)
*
4
,
sub
(
2
),
sub
(
3
)},
0
);
auto
y0
=
opr
::
Dimshuffle
::
make
(
x
,
{
0
,
1
,
4
,
2
,
3
});
auto
y1
=
opr
::
Reshape
::
make
(
y0
,
tshp
);
return
y1
.
node
();
};
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
)
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
/
4
,
H
,
W
,
4
});
EXPECT_TRUE
(
checker
({
x
.
node
()}));
auto
y1
=
SymbolVar
(
reformat
({
x
.
node
()}));
SmallVector
<
VarNode
*>
reshapes
;
VarNode
*
dimshuffle
;
cg
::
DepOprIter
{[
&
dimshuffle
,
&
reshapes
](
cg
::
OperatorNodeBase
*
o
)
{
if
(
o
->
same_type
<
opr
::
Reshape
>
())
{
reshapes
.
push_back
(
o
->
output
(
0
));
}
if
(
o
->
same_type
<
opr
::
Dimshuffle
>
())
dimshuffle
=
o
->
output
(
0
);
}}
.
add
(
y1
.
node
()
->
owner_opr
());
ASSERT_EQ
(
reshapes
.
size
(),
1
);
{
gopt
::
SubGraph
graph
({
y1
});
gopt
::
UniqReaderCheck
check
(
graph
);
EXPECT_TRUE
(
check
(
reshapes
[
0
]));
EXPECT_TRUE
(
dimshuffle
);
}
auto
y2
=
SymbolVar
(
nchw4_to_nchw
(
x
.
node
()));
HostTensorND
t1
,
t2
;
auto
func1
=
graph
->
compile
({
make_callback_copy
(
y1
,
t1
)});
func1
->
execute
();
auto
func2
=
graph
->
compile
({
make_callback_copy
(
y2
,
t2
)});
func2
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
t1
,
t2
);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录