Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
90d5895e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
90d5895e
编写于
6月 29, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
7月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/gopt): remove redundant reshape in nchw->nchw4 pass
GitOrigin-RevId: 0f5c7c3e485b4da0cdfe9b0db3e23945ac43ee16
上级
0d12ae80
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
8 addition
and
7 deletion
+8
-7
src/gopt/impl/tensor_reformat.cpp
src/gopt/impl/tensor_reformat.cpp
+6
-7
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+2
-0
未找到文件。
src/gopt/impl/tensor_reformat.cpp
浏览文件 @
90d5895e
...
@@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
...
@@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
return
opr
::
IndexAt
::
make
(
xshp
,
{{
0
,
cv
(
idx
)}});
return
opr
::
IndexAt
::
make
(
xshp
,
{{
0
,
cv
(
idx
)}});
};
};
auto
tshp0
=
opr
::
Concat
::
make
(
auto
tshp0
=
opr
::
Concat
::
make
(
{
sub
(
0
),
sub
(
1
)
/
4
,
cv
(
4
),
sub
(
2
),
sub
(
3
)},
0
),
{
sub
(
0
),
sub
(
1
)
/
4
,
cv
(
4
),
sub
(
2
),
sub
(
3
)},
0
);
tshp1
=
opr
::
Concat
::
make
(
{
sub
(
0
),
sub
(
1
)
/
4
,
sub
(
2
),
sub
(
3
),
cv
(
4
)},
0
);
auto
y0
=
opr
::
Reshape
::
make
(
x
,
tshp0
);
auto
y0
=
opr
::
Reshape
::
make
(
x
,
tshp0
);
auto
y1
=
opr
::
Dimshuffle
::
make
(
y0
,
{
0
,
1
,
3
,
4
,
2
});
auto
y1
=
opr
::
Dimshuffle
::
make
(
y0
,
{
0
,
1
,
3
,
4
,
2
});
auto
y2
=
opr
::
Reshape
::
make
(
y1
,
tshp1
);
return
y1
.
node
();
return
y2
.
node
();
};
};
reformat
[
LayoutType
::
NCHW4_TO_NCHW
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
reformat
[
LayoutType
::
NCHW4_TO_NCHW
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
auto
x
=
SymbolVar
(
inp
);
auto
x
=
SymbolVar
(
inp
);
...
@@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
...
@@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto
y1
=
opr
::
Reshape
::
make
(
y0
,
tshp0
);
auto
y1
=
opr
::
Reshape
::
make
(
y0
,
tshp0
);
return
y1
.
node
();
return
y1
.
node
();
};
};
reformat
[
LayoutType
::
WEIGHT_NCHW_TO_NCHW4_DENSE
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
reformat
[
LayoutType
::
WEIGHT_NCHW_TO_NCHW4_DENSE
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
auto
x
=
SymbolVar
(
inp
);
auto
x
=
SymbolVar
(
inp
);
auto
xshp
=
opr
::
GetVarShape
::
make
(
x
);
auto
xshp
=
opr
::
GetVarShape
::
make
(
x
);
auto
cv
=
[
&
x
](
int
v
)
{
return
x
.
make_scalar
(
v
);
};
auto
cv
=
[
&
x
](
int
v
)
{
return
x
.
make_scalar
(
v
);
};
...
@@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
...
@@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto
y2
=
opr
::
Reshape
::
make
(
y1
,
tshp1
);
auto
y2
=
opr
::
Reshape
::
make
(
y1
,
tshp1
);
return
y2
.
node
();
return
y2
.
node
();
};
};
reformat
[
LayoutType
::
WEIGHT_NCHW_TO_NCHW4_GROUP
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
reformat
[
LayoutType
::
WEIGHT_NCHW_TO_NCHW4_GROUP
]
=
[](
VarNode
*
inp
)
->
VarNode
*
{
auto
x
=
SymbolVar
(
inp
);
auto
x
=
SymbolVar
(
inp
);
auto
xshp
=
opr
::
GetVarShape
::
make
(
x
);
auto
xshp
=
opr
::
GetVarShape
::
make
(
x
);
auto
cv
=
[
&
x
](
int
v
)
{
return
x
.
make_scalar
(
v
);
};
auto
cv
=
[
&
x
](
int
v
)
{
return
x
.
make_scalar
(
v
);
};
...
...
src/gopt/test/inference.cpp
浏览文件 @
90d5895e
...
@@ -2450,6 +2450,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
...
@@ -2450,6 +2450,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
ASSERT_EQ
(
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
,
ASSERT_EQ
(
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
,
find_opr
<
opr
::
ConvBias
>
(
y_opt
).
param
().
format
);
find_opr
<
opr
::
ConvBias
>
(
y_opt
).
param
().
format
);
auto
nr_reshape
=
find_opr_num
<
mgb
::
opr
::
Reshape
>
(
y_opt
);
ASSERT_EQ
(
2u
,
nr_reshape
);
graph
->
compile
({{
y_opt
,
{}}})
graph
->
compile
({{
y_opt
,
{}}})
->
to_json
()
->
to_json
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录