Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ab5c53f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2ab5c53f
编写于
6月 09, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): support nhwc conv in tensor reformat pass
GitOrigin-RevId: 43e78d758ab352c9e47d9ca1bb5fe868d4443458
上级
009c90a2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
134 addition
and
16 deletion
+134
-16
src/gopt/impl/tensor_reformat.cpp
src/gopt/impl/tensor_reformat.cpp
+134
-16
未找到文件。
src/gopt/impl/tensor_reformat.cpp
浏览文件 @
2ab5c53f
...
...
@@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_NCHW
);
return
ovar
.
node
();
}
else
if
(
fmt
==
Format
::
NHWC
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW
);
return
ovar
.
node
();
}
else
{
mgb_assert
(
fmt
==
Format
::
NCHW64
);
auto
ovar
=
RelayoutPlaceholder
::
make
(
...
...
@@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW32_TO_NCHW4
);
return
ovar
.
node
();
}
else
if
(
fmt
==
Format
::
NHWC
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW4
);
return
ovar
.
node
();
}
else
{
mgb_assert
(
fmt
==
Format
::
NCHW64
);
auto
ovar
=
RelayoutPlaceholder
::
make
(
...
...
@@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_NCHW32
);
return
ovar
.
node
();
}
else
if
(
fmt
==
Format
::
NHWC
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW32
);
return
ovar
.
node
();
}
else
{
mgb_assert
(
fmt
==
Format
::
NCHW64
);
auto
ovar
=
RelayoutPlaceholder
::
make
(
...
...
@@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_NCHW64
);
return
ovar
.
node
();
}
else
if
(
fmt
==
Format
::
NHWC
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW64
);
return
ovar
.
node
();
}
else
{
mgb_assert
(
fmt
==
Format
::
NCHW32
);
auto
ovar
=
RelayoutPlaceholder
::
make
(
...
...
@@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() {
return
ret
;
};
auto
try_transform_to_nhwc
=
[
make_new_conv
,
&
format_map
](
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
->
VarNode
*
{
// fint4XWint4 and fuint4XWint4
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
bool
check_dtype
=
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS4
||
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
new_inp
[
1
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS4
;
if
(
opr
->
input
().
size
()
>=
3
)
check_dtype
&=
new_inp
[
2
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS32
;
if
(
opr
->
input
().
size
()
>=
4
)
check_dtype
&=
new_inp
[
3
]
->
dtype
().
enumv
()
==
new_inp
[
0
]
->
dtype
().
enumv
();
if
(
!
check_dtype
)
return
nullptr
;
size_t
out_channels
=
opr
->
input
(
1
)
->
shape
()[
0
];
size_t
in_channels
=
opr
->
input
(
1
)
->
shape
()[
1
];
bool
check_channels
=
out_channels
%
8
==
0
&&
in_channels
%
8
==
0
;
if
(
!
check_channels
)
return
nullptr
;
auto
inps
=
new_inp
;
auto
process
=
[
&
](
size_t
i
)
->
VarNode
*
{
auto
iter
=
format_map
.
find
(
new_inp
[
i
]
->
owner_opr
());
if
(
iter
==
format_map
.
end
())
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW_TO_NHWC
);
return
ovar
.
node
();
}
else
{
const
auto
&
fmt
=
iter
->
second
;
if
(
fmt
==
Format
::
NHWC
)
{
return
inps
[
i
];
}
else
if
(
fmt
==
Format
::
NCHW4
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_NHWC
);
return
ovar
.
node
();
}
else
if
(
fmt
==
Format
::
NCHW32
)
{
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW32_TO_NHWC
);
return
ovar
.
node
();
}
else
{
mgb_assert
(
fmt
==
Format
::
NCHW64
);
auto
ovar
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NCHW64_TO_NHWC
);
return
ovar
.
node
();
}
}
};
for
(
size_t
i
=
0
;
i
<
inps
.
size
();
++
i
)
{
inps
[
i
]
=
process
(
i
);
}
auto
&
conv_bias
=
opr
->
cast_final_safe
<
opr
::
ConvBiasForward
>
();
auto
ret
=
make_new_conv
(
inps
,
&
conv_bias
,
Format
::
NHWC
);
format_map
.
insert
(
std
::
make_pair
(
ret
->
owner_opr
(),
Format
::
NHWC
));
return
ret
;
};
// replace rule for conv bias opr
auto
replace_conv_bias_opr
=
[
&
format_map
,
try_transform_to_nchw4
,
try_transform_to_nchw32
,
try_transform_to_nchw64
,
try_transform_to_nchw
](
try_transform_to_nchw64
,
try_transform_to_nhwc
,
try_transform_to_nchw
](
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
{
using
Param
=
megdnn
::
param
::
ConvBias
;
...
...
@@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
VarNode
*
new_var
=
nullptr
;
if
((
new_var
=
try_transform_to_nchw32
(
opr
,
new_inp
))
||
(
new_var
=
try_transform_to_nchw4
(
opr
,
new_inp
))
||
(
new_var
=
try_transform_to_nchw64
(
opr
,
new_inp
))
||
(
new_var
=
try_transform_to_nchw64
(
opr
,
new_inp
))
||
(
new_var
=
try_transform_to_nhwc
(
opr
,
new_inp
))
||
(
new_var
=
try_transform_to_nchw
(
opr
,
new_inp
)))
{
return
new_var
->
owner_opr
();
}
else
{
...
...
@@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4
)
.
node
();
break
;
case
Format
::
NHWC
:
inps
[
1
]
=
RelayoutPlaceholder
::
make
(
inps
[
1
],
RelayoutPlaceholder
::
LayoutType
::
NCHW_TO_NHWC
)
.
node
();
break
;
case
Format
::
NCHW32
:
inps
[
1
]
=
RelayoutPlaceholder
::
make
(
inps
[
1
],
RelayoutPlaceholder
::
LayoutType
::
...
...
@@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
cb
(
NCHW4
,
NCHW
),
cb
(
NCHW4
,
NCHW32
),
cb
(
NCHW4
,
NCHW64
),
cb
(
NCHW32
,
NCHW
),
cb
(
NCHW32
,
NCHW4
),
cb
(
NCHW32
,
NCHW64
),
cb
(
NCHW32
,
NCHW
),
cb
(
NCHW32
,
NCHW4
),
cb
(
NCHW32
,
NCHW64
),
cb
(
NCHW
,
NHWC
),
cb
(
NCHW4
,
NHWC
),
cb
(
NCHW32
,
NHWC
),
cb
(
NCHW64
,
NHWC
),
cb
(
NHWC
,
NCHW
),
cb
(
NHWC
,
NCHW4
),
cb
(
NHWC
,
NCHW32
),
cb
(
NHWC
,
NCHW64
),
#undef cb
};
auto
inps
=
new_inp
;
...
...
@@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() {
case
Format
::
NCHW
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW_TO_N
CHW64
)
NCHW_TO_N
HWC
)
.
node
();
break
;
case
Format
::
NCHW4
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_N
CHW64
)
NCHW4_TO_N
HWC
)
.
node
();
break
;
case
Format
::
NCHW32
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW32_TO_N
CHW64
)
NCHW32_TO_N
HWC
)
.
node
();
break
;
default:
mgb_assert
(
cur
==
Format
::
NCHW64
);
mgb_assert
(
cur
==
Format
::
NCHW64
||
cur
==
Format
::
NHWC
);
}
auto
target_format
=
cur
==
Format
::
NCHW64
?
cur
:
Format
::
NHWC
;
auto
param
=
warp
.
param
();
param
.
format
=
Format
::
NCHW64
;
param
.
format
=
target_format
;
SymbolVar
new_warp
;
if
(
inps
.
size
()
==
3
)
{
new_warp
=
opr
::
WarpPerspectiveForward
::
make
(
...
...
@@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
warp
.
config
());
}
auto
ret
=
new_warp
.
node
()
->
owner_opr
();
format_map
.
insert
(
std
::
make_pair
(
ret
,
Format
::
NCHW64
));
format_map
.
insert
(
std
::
make_pair
(
ret
,
target_format
));
return
ret
;
}
else
if
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
Format
cur
;
...
...
@@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4
)
.
node
();
break
;
case
Format
::
NHWC
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW4
)
.
node
();
break
;
case
Format
::
NCHW32
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
...
...
@@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() {
case
Format
::
NCHW
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW_TO_N
CHW64
)
NCHW_TO_N
HWC
)
.
node
();
break
;
case
Format
::
NCHW4
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW4_TO_N
CHW64
)
NCHW4_TO_N
HWC
)
.
node
();
break
;
case
Format
::
NCHW32
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
NCHW32_TO_N
CHW64
)
NCHW32_TO_N
HWC
)
.
node
();
break
;
default:
mgb_assert
(
cur
==
Format
::
NCHW64
);
mgb_assert
(
cur
==
Format
::
NCHW64
||
cur
==
Format
::
NHWC
);
}
auto
target_format
=
cur
==
Format
::
NCHW64
?
cur
:
Format
::
NHWC
;
auto
param
=
pooling
.
param
();
param
.
format
=
Format
::
NCHW64
;
param
.
format
=
target_format
;
auto
new_pool
=
opr
::
PoolingForward
::
make
(
inps
[
0
],
param
,
pooling
.
config
());
auto
ret
=
new_pool
.
node
()
->
owner_opr
();
format_map
.
insert
(
std
::
make_pair
(
ret
,
Format
::
NCHW64
));
format_map
.
insert
(
std
::
make_pair
(
ret
,
target_format
));
return
ret
;
}
else
if
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
Format
cur
;
...
...
@@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
}
else
{
cur
=
iter
->
second
;
}
size_t
in_channels
=
new_inp
[
0
]
->
shape
()[
1
];
bool
use_nchw32
=
false
;
auto
inps
=
new_inp
;
using
LayoutType
=
RelayoutPlaceholder
::
LayoutType
;
switch
(
cur
)
{
case
Format
::
NCHW
:
{
size_t
in_channels
=
new_inp
[
0
]
->
shape
()[
1
];
use_nchw32
=
in_channels
%
32
==
0
;
auto
layout_type
=
use_nchw32
?
LayoutType
::
NCHW_TO_NCHW32
:
LayoutType
::
NCHW_TO_NCHW4
;
...
...
@@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
.
node
();
break
;
}
case
Format
::
NHWC
:
{
size_t
in_channels
=
new_inp
[
0
]
->
shape
()[
3
];
use_nchw32
=
in_channels
%
32
==
0
;
auto
layout_type
=
use_nchw32
?
LayoutType
::
NHWC_TO_NCHW32
:
LayoutType
::
NHWC_TO_NCHW4
;
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
layout_type
)
.
node
();
break
;
}
case
Format
::
NCHW64
:
inps
[
0
]
=
RelayoutPlaceholder
::
make
(
inps
[
0
],
RelayoutPlaceholder
::
LayoutType
::
...
...
@@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() {
auto
fmt
=
iter
!=
format_map
.
end
()
?
iter
->
second
:
Format
::
NCHW
;
if
(
iter
!=
format_map
.
end
())
{
switch
(
fmt
)
{
case
Format
::
NHWC
:
inps
[
i
]
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
RelayoutPlaceholder
::
LayoutType
::
NHWC_TO_NCHW
)
.
node
();
break
;
case
Format
::
NCHW4
:
inps
[
i
]
=
RelayoutPlaceholder
::
make
(
inps
[
i
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录