Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
58b8b145
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
58b8b145
编写于
7月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/gopt): add checker for reformat emitter
GitOrigin-RevId: 53a8c128f57e05147a0acaffbf52fe55bcbad281
上级
55efc8e1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
116 additions
and
65 deletions
+116
-65
src/gopt/impl/reformat_emitter.cpp
src/gopt/impl/reformat_emitter.cpp
+80
-45
src/gopt/include/megbrain/gopt/reformat_emitter.h
src/gopt/include/megbrain/gopt/reformat_emitter.h
+24
-18
src/gopt/test/reformat_emitter.cpp
src/gopt/test/reformat_emitter.cpp
+12
-2
未找到文件。
src/gopt/impl/reformat_emitter.cpp
浏览文件 @
58b8b145
...
...
@@ -10,8 +10,8 @@
* implied.
*/
#include <numeric>
#include "megbrain/gopt/reformat_emitter.h"
#include <numeric>
#include "megbrain/opr/tensor_manip.h"
using
namespace
mgb
;
...
...
@@ -19,34 +19,7 @@ using namespace gopt;
using
Dimension
=
megdnn
::
Dimension
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
ReshapeEmitter::Operator ReshapeEmitter::emit() const {
    auto pattern = analyze();
    // The returned operator builds a symbolic target shape from the input
    // var's shape and reshapes the input accordingly.
    auto op = [pattern](VarNode* var) {
        auto sym_var = SymbolVar(var);
        auto shp = opr::GetVarShape::make(sym_var);
        // scalar immediate constant on the same graph/device as the input
        auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
        // symbolic extent of input axis `ax`
        auto sub = [&shp, &cv](int ax) {
            return opr::IndexAt::make(shp, {{0, cv(ax)}});
        };
        SymbolVarArray axs;
        for (auto i : pattern) {
            if (std::get<0>(i) >= 0) {
                // derived from a source axis: multiply or divide its extent
                if (std::get<2>(i))
                    axs.emplace_back(sub(std::get<0>(i)) * std::get<1>(i));
                else
                    axs.emplace_back(sub(std::get<0>(i)) / std::get<1>(i));
            } else {
                // no source axis: the extent is the constant factor itself
                axs.emplace_back(cv(std::get<1>(i)));
            }
        }
        auto tshp = opr::Concat::make(axs, 0);
        auto ovar = opr::Reshape::make(sym_var, tshp);
        return ovar.node();
    };
    return op;
}
SmallVector
<
std
::
tuple
<
int
,
int
,
bool
>>
ReshapeEmitter
::
analyze
()
const
{
ModifyShapeMixin
::
Pattern
ModifyShapeMixin
::
mixin_analyze
()
const
{
static
constexpr
uint32_t
UNDETERMINED_EXTENT
=
Dimension
::
UNDETERMINED_EXTENT
;
ThinHashMap
<
Dimension
::
Name
,
int
>
name2dominant
;
...
...
@@ -58,7 +31,7 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
}
}
SmallVector
<
std
::
tuple
<
int
,
int
,
bool
>>
pattern
(
m_dest
.
ndim
);
Pattern
pattern
(
m_dest
.
ndim
);
for
(
size_t
i
=
0
;
i
<
m_dest
.
ndim
;
++
i
)
{
auto
name
=
m_dest
[
i
].
name
();
if
(
m_dest
[
i
].
extent
()
==
UNDETERMINED_EXTENT
)
{
...
...
@@ -74,28 +47,90 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const {
return
pattern
;
}
DimshuffleEmitter
::
Operator
DimshuffleEmitter
::
emit
()
const
{
auto
pattern
=
m_pattern
;
auto
op
=
[
pattern
](
VarNode
*
var
)
{
ModifyShapeMixin
::
Checker
ModifyShapeMixin
::
mixin_emit_checker
(
const
Pattern
&
pattern
)
const
{
auto
src
=
m_src
;
auto
checker
=
[
src
,
pattern
](
VarNode
*
var
)
{
const
auto
&
shp
=
var
->
shape
();
if
(
shp
.
ndim
!=
src
.
ndim
)
return
false
;
bool
available
=
true
;
for
(
size_t
i
=
0
;
i
<
shp
.
ndim
;
++
i
)
{
if
(
src
[
i
].
extent
()
!=
Dimension
::
UNDETERMINED_EXTENT
)
{
available
&=
(
shp
[
i
]
==
src
[
i
].
extent
());
}
}
for
(
auto
&&
i
:
pattern
)
{
int
axis
,
factor
;
bool
mul
;
std
::
tie
(
axis
,
factor
,
mul
)
=
i
;
if
(
axis
>=
0
&&
!
mul
)
{
available
&=
(
shp
[
axis
]
%
factor
==
0
);
}
}
return
available
;
};
return
checker
;
}
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
    auto pattern = mixin_analyze();
    // Builder: derives the target shape symbolically from the input var's
    // shape and emits a Reshape opr.
    auto builder = [pattern](VarNode* var) {
        auto sym_var = SymbolVar(var);
        auto shp = opr::GetVarShape::make(sym_var);
        // scalar immediate constant on the same graph/device as the input
        auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
        // symbolic extent of input axis `ax`
        auto sub = [&shp, &cv](int ax) {
            return opr::IndexAt::make(shp, {{0, cv(ax)}});
        };
        SymbolVarArray axs;
        for (auto&& i : pattern) {
            int axis, factor;
            bool mul;
            std::tie(axis, factor, mul) = i;
            if (axis >= 0) {
                // derived from a source axis: scale its symbolic extent
                if (mul)
                    axs.emplace_back(sub(axis) * factor);
                else
                    axs.emplace_back(sub(axis) / factor);
            } else {
                // no source axis: the extent is the constant factor itself
                axs.emplace_back(cv(factor));
            }
        }
        auto tshp = opr::Concat::make(axs, 0);
        auto ovar = opr::Reshape::make(sym_var, tshp);
        return ovar.node();
    };
    auto checker = mixin_emit_checker(pattern);
    return std::make_tuple(builder, checker);
}
DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
    auto&& pattern = m_pattern;
    // Builder: applies the stored dimshuffle pattern to the input var.
    auto builder = [pattern](VarNode* var) {
        auto sym_var = SymbolVar(var);
        return opr::Dimshuffle::make(sym_var, pattern).node();
    };
    // Checker: a dimshuffle is applicable iff the input rank equals the
    // pattern length.
    auto checker = [pattern](VarNode* var) {
        return var->shape().ndim == pattern.size();
    };
    return std::make_tuple(builder, checker);
}
ReformatEmitter::EmitResult ReformatEmitter::emit() const {
    auto ops = analyze();
    // Builder: chains the sub-builders produced by analyze()
    // (reshape -> dimshuffle -> reshape) on the input var.
    auto builder = [ops](VarNode* var) {
        VarNode* ovar = var;
        for (const auto& i : ops) {
            ovar = i(ovar);
        }
        return ovar;
    };
    auto pattern = mixin_analyze();
    auto checker = mixin_emit_checker(pattern);
    return std::make_tuple(builder, checker);
}
SmallVector
<
ReformatEmitter
::
Operato
r
>
ReformatEmitter
::
analyze
()
const
{
SmallVector
<
ReformatEmitter
::
Builde
r
>
ReformatEmitter
::
analyze
()
const
{
struct
Dim
{
Dimension
dim
;
int
index
;
...
...
@@ -161,12 +196,12 @@ SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const {
i1
[
i
]
=
src_dims
[
src_perm
[
i
]].
dim
;
i2
[
i
]
=
src_dims
[
src_perm
[
permute
[
i
]]].
dim
;
}
SmallVector
<
Operato
r
>
ops
;
SmallVector
<
Builde
r
>
ops
;
if
(
!
m_src
.
eq_shape
(
i1
))
ops
.
emplace_back
(
ReshapeEmitter
(
m_src
,
i1
).
emit
(
));
ops
.
emplace_back
(
DimshuffleEmitter
(
permute
).
emit
(
));
ops
.
emplace_back
(
std
::
get
<
0
>
(
ReshapeEmitter
(
m_src
,
i1
).
emit
()
));
ops
.
emplace_back
(
std
::
get
<
0
>
(
DimshuffleEmitter
(
permute
).
emit
()
));
if
(
!
m_dest
.
eq_shape
(
i2
))
ops
.
emplace_back
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
(
));
ops
.
emplace_back
(
std
::
get
<
0
>
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
()
));
return
ops
;
}
// vim: syntax=cpp.doxygen
src/gopt/include/megbrain/gopt/reformat_emitter.h
浏览文件 @
58b8b145
...
...
@@ -20,45 +20,51 @@ namespace gopt {
class Emitter {
public:
    //! builds the var subgraph realizing the reformat on a given input var
    using Builder = thin_function<VarNode*(VarNode*)>;
    //! checks whether a concrete input var is compatible with the reformat
    using Checker = thin_function<bool(VarNode*)>;
    using EmitResult = std::tuple<Builder, Checker>;
    virtual ~Emitter() = default;
    virtual EmitResult emit() const = 0;
};
class ModifyShapeMixin {
protected:
    // One entry per destination axis: (axis, factor, is_mul).
    // axis >= 0: the extent is the source axis' symbolic extent multiplied
    // (is_mul) or divided (!is_mul) by factor; axis < 0: the extent is the
    // constant `factor` itself.
    using Pattern = SmallVector<std::tuple<int, int, bool>>;
    using Checker = Emitter::Checker;
    ModifyShapeMixin(const megdnn::NamedTensorShape& src,
                     const megdnn::NamedTensorShape& dest)
            : m_src(src), m_dest(dest) {}
    Pattern mixin_analyze() const;
    Checker mixin_emit_checker(const Pattern& pattern) const;
    megdnn::NamedTensorShape m_src, m_dest;
};
// Emits a pure reshape between two named tensor shapes.
class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
public:
    ReshapeEmitter(const megdnn::NamedTensorShape& src,
                   const megdnn::NamedTensorShape& dest)
            : ModifyShapeMixin(src, dest) {}
    EmitResult emit() const override;
};
// Emits a dimshuffle (axis permutation/broadcast) given a fixed pattern.
class DimshuffleEmitter final : public Emitter {
public:
    DimshuffleEmitter(const std::vector<int>& pattern) : m_pattern{pattern} {}
    EmitResult emit() const override;

private:
    std::vector<int> m_pattern;
};
class
ReformatEmitter
final
:
public
Emitter
{
class
ReformatEmitter
final
:
public
Emitter
,
ModifyShapeMixin
{
public:
using
Operator
=
typename
Emitter
::
Operator
;
ReformatEmitter
(
const
megdnn
::
NamedTensorShape
&
src
,
const
megdnn
::
NamedTensorShape
&
dest
)
:
m_src
{
src
},
m_dest
{
dest
}
{}
Operator
emit
()
const
override
;
:
ModifyShapeMixin
(
src
,
dest
)
{}
EmitResult
emit
()
const
override
;
private:
SmallVector
<
Operator
>
analyze
()
const
;
megdnn
::
NamedTensorShape
m_src
,
m_dest
;
SmallVector
<
Builder
>
analyze
()
const
;
};
}
// namespace gopt
}
// namespace mgb
...
...
src/gopt/test/reformat_emitter.cpp
浏览文件 @
58b8b145
...
...
@@ -25,7 +25,9 @@ TEST(TestReformatEmitter, Basic) {
NamedTensorShape
::
Format
::
NCHW4
);
auto
src
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW32
);
auto
reformat
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
&&
tuple
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
reformat
=
std
::
get
<
0
>
(
tuple
);
auto
checker
=
std
::
get
<
1
>
(
tuple
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
...
...
@@ -51,6 +53,9 @@ TEST(TestReformatEmitter, Basic) {
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
/
32
,
H
,
W
,
32
});
EXPECT_TRUE
(
checker
(
x
.
node
()));
auto
x_
=
mkvar
(
"x"
,
{
N
,
H
,
W
,
C
});
EXPECT_FALSE
(
checker
(
x_
.
node
()));
auto
y1
=
SymbolVar
(
reformat
(
x
.
node
()));
auto
y2
=
SymbolVar
(
nchw32_to_nchw4
(
x
.
node
()));
HostTensorND
t1
,
t2
;
...
...
@@ -69,7 +74,9 @@ TEST(TestReformatEmitter, MoreComplicated) {
NamedTensorShape
::
Format
::
NCHW64
);
auto
dest
=
NamedTensorShape
::
make_named_tensor_shape
(
NamedTensorShape
::
Format
::
NCHW88
);
auto
reformat
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
&&
tuple
=
gopt
::
ReformatEmitter
(
src
,
dest
).
emit
();
auto
reformat
=
std
::
get
<
0
>
(
tuple
);
auto
checker
=
std
::
get
<
1
>
(
tuple
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
...
...
@@ -77,6 +84,9 @@ TEST(TestReformatEmitter, MoreComplicated) {
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
};
auto
x
=
mkvar
(
"x"
,
{
N
,
C
/
64
,
H
,
W
,
64
});
EXPECT_TRUE
(
checker
(
x
.
node
()));
auto
x_
=
mkvar
(
"x"
,
{
N
,
H
,
W
,
C
});
EXPECT_FALSE
(
checker
(
x_
.
node
()));
auto
y
=
SymbolVar
(
reformat
(
x
.
node
()));
HostTensorND
t
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
t
)});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录