Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2e70cf1d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2e70cf1d
编写于
6月 29, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opt): add nchw->nchw4 in tensorcore pass
GitOrigin-RevId: 755f8dfefe28bb14dba5d86a54a9bb725af285ed
上级
1e8337f1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
191 addition
and
60 deletion
+191
-60
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+2
-0
src/gopt/impl/tensor_reformat.cpp
src/gopt/impl/tensor_reformat.cpp
+25
-16
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+164
-44
未找到文件。
src/gopt/impl/framework.cpp
浏览文件 @
2e70cf1d
...
@@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
...
@@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb
(
nchw32
,
{
cb
(
nchw32
,
{
add_pass
<
FuseConvBiasNonlinPass
>
();
add_pass
<
FuseConvBiasNonlinPass
>
();
add_pass
<
FuseConvBiasZPass
>
();
add_pass
<
FuseConvBiasZPass
>
();
add_pass
(
EnableNCHW4Pass
::
make_nchw4_converter
());
add_pass
(
EnableTensorCorePass
::
make_tensorcore_converter
());
add_pass
(
EnableTensorCorePass
::
make_tensorcore_converter
());
add_pass
<
ShuffleShuffleRemovePass
>
();
add_pass
<
ShuffleShuffleRemovePass
>
();
add_pass
<
RemoveRedundantTypeCvtPass
>
();
add_pass
<
RemoveRedundantTypeCvtPass
>
();
...
@@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
...
@@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb
(
chwn4
,
{
cb
(
chwn4
,
{
add_pass
<
FuseConvBiasNonlinPass
>
();
add_pass
<
FuseConvBiasNonlinPass
>
();
add_pass
<
FuseConvBiasZPass
>
();
add_pass
<
FuseConvBiasZPass
>
();
add_pass
(
EnableNCHW4Pass
::
make_nchw4_converter
());
add_pass
(
EnableCHWN4Pass
::
make_chwn4_converter
());
add_pass
(
EnableCHWN4Pass
::
make_chwn4_converter
());
add_pass
<
ShuffleShuffleRemovePass
>
();
add_pass
<
ShuffleShuffleRemovePass
>
();
add_pass
<
RemoveRedundantTypeCvtPass
>
();
add_pass
<
RemoveRedundantTypeCvtPass
>
();
...
...
src/gopt/impl/tensor_reformat.cpp
浏览文件 @
2e70cf1d
...
@@ -1356,16 +1356,17 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1356,16 +1356,17 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using
RelayoutMode
=
RelayoutPlaceholder
::
LayoutType
;
using
RelayoutMode
=
RelayoutPlaceholder
::
LayoutType
;
megdnn
::
param
::
Convolution
::
Format
conv_format
=
megdnn
::
param
::
Convolution
::
Format
conv_format
=
megdnn
::
param
::
Convolution
::
Format
::
NCHW4
;
megdnn
::
param
::
Convolution
::
Format
::
NCHW4
;
megdnn
::
param
::
ConvBias
::
Format
conv_bias_format
=
megdnn
::
param
::
ConvBias
::
Format
conv_bias_format
=
megdnn
::
param
::
ConvBias
::
Format
::
NCHW4
;
megdnn
::
param
::
ConvBias
::
Format
::
NCHW4
;
megdnn
::
param
::
BatchConvBias
::
Format
batch_conv_bias_format
=
megdnn
::
param
::
BatchConvBias
::
Format
batch_conv_bias_format
=
megdnn
::
param
::
BatchConvBias
::
Format
::
NCHW4
;
megdnn
::
param
::
BatchConvBias
::
Format
::
NCHW4
;
RelayoutMode
src_to_nchw4_mode
=
RelayoutMode
::
NCHW_TO_NCHW4
;
RelayoutMode
src_to_nchw4_mode
=
RelayoutMode
::
NCHW_TO_NCHW4
;
RelayoutMode
src_to_nchw_mode
=
RelayoutMode
::
NCHW4_TO_NCHW
;
RelayoutMode
src_to_nchw_mode
=
RelayoutMode
::
NCHW4_TO_NCHW
;
RelayoutMode
weight_to_nchw4_mode_dense
=
RelayoutMode
weight_to_nchw4_mode_dense
=
RelayoutMode
::
WEIGHT_NCHW_TO_NCHW4_DENSE
;
RelayoutMode
::
WEIGHT_NCHW_TO_NCHW4_DENSE
;
RelayoutMode
weight_to_nchw4_mode_group
=
RelayoutMode
weight_to_nchw4_mode_group
=
RelayoutMode
::
WEIGHT_NCHW_TO_NCHW4_GROUP
;
RelayoutMode
::
WEIGHT_NCHW_TO_NCHW4_GROUP
;
auto
trans_nchw4
=
[
weight_to_nchw4_mode_dense
,
auto
trans_nchw4
=
[
weight_to_nchw4_mode_dense
,
weight_to_nchw4_mode_group
](
weight_to_nchw4_mode_group
](
const
megdnn
::
param
::
Convolution
::
Sparse
conv_mode
,
const
megdnn
::
param
::
Convolution
::
Sparse
conv_mode
,
...
@@ -1391,9 +1392,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1391,9 +1392,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
{
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
)
{
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
conv_opr
=
opr
->
cast_final_safe
<
opr
::
ConvolutionForward
>
();
auto
&
conv_opr
=
opr
->
cast_final_safe
<
opr
::
ConvolutionForward
>
();
mgb_assert
(
conv_opr
.
param
().
format
==
if
(
conv_opr
.
param
().
format
!=
megdnn
::
param
::
Convolution
::
Format
::
NCHW
,
megdnn
::
param
::
Convolution
::
Format
::
NCHW
)
{
"ConvertFormat Pass only support converting NCHW to NCHW4"
);
return
serialization
::
copy_opr_shallow
(
*
opr
,
new_inp
,
opr
->
config
());
}
VarNode
*
conv_src
=
new_inp
[
0
],
*
conv_filter
=
new_inp
[
1
];
VarNode
*
conv_src
=
new_inp
[
0
],
*
conv_filter
=
new_inp
[
1
];
// src: NCHW --> NCWH4
// src: NCHW --> NCWH4
if
(
new_inp
[
0
]
->
shape
().
ndim
!=
5
)
{
if
(
new_inp
[
0
]
->
shape
().
ndim
!=
5
)
{
...
@@ -1427,7 +1430,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1427,7 +1430,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
batch_conv_bias_opr
=
auto
&
batch_conv_bias_opr
=
opr
->
cast_final_safe
<
opr
::
BatchConvBiasForward
>
();
opr
->
cast_final_safe
<
opr
::
BatchConvBiasForward
>
();
mgb_assert
(
batch_conv_bias_opr
.
param
().
format
==
if
(
batch_conv_bias_opr
.
param
().
format
!=
megdnn
::
param
::
BatchConvBias
::
Format
::
NCHW
)
{
return
serialization
::
copy_opr_shallow
(
*
opr
,
new_inp
,
opr
->
config
());
}
mgb_assert
(
batch_conv_bias_opr
.
param
().
format
==
megdnn
::
param
::
BatchConvBias
::
Format
::
NCHW
,
megdnn
::
param
::
BatchConvBias
::
Format
::
NCHW
,
"ConvertFormat Pass only support converting NCHW to NCHW4"
);
"ConvertFormat Pass only support converting NCHW to NCHW4"
);
// what should be converted: src, weight
// what should be converted: src, weight
...
@@ -1494,9 +1503,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1494,9 +1503,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
const
VarNodeArray
&
new_inp
)
{
const
VarNodeArray
&
new_inp
)
{
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
conv_bias_opr
=
opr
->
cast_final_safe
<
opr
::
ConvBiasForward
>
();
auto
&
conv_bias_opr
=
opr
->
cast_final_safe
<
opr
::
ConvBiasForward
>
();
mgb_assert
(
conv_bias_opr
.
param
().
format
==
if
(
conv_bias_opr
.
param
().
format
!=
megdnn
::
param
::
ConvBias
::
Format
::
NCHW
,
megdnn
::
param
::
Convolution
::
Format
::
NCHW
)
{
"ConvertFormat Pass only support converting NCHW to NCHW4"
);
return
serialization
::
copy_opr_shallow
(
*
opr
,
new_inp
,
opr
->
config
());
}
// what should be converted: src, weight
// what should be converted: src, weight
VarNode
*
conv_bias_src
=
new_inp
[
0
],
*
conv_bias_filter
=
new_inp
[
1
];
VarNode
*
conv_bias_src
=
new_inp
[
0
],
*
conv_bias_filter
=
new_inp
[
1
];
// src: NCHW --> NCHW4
// src: NCHW --> NCHW4
...
@@ -1604,8 +1616,9 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1604,8 +1616,9 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using
Format
=
Param
::
Format
;
using
Format
=
Param
::
Format
;
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
pooling
=
opr
->
cast_final_safe
<
opr
::
PoolingForward
>
();
auto
&
pooling
=
opr
->
cast_final_safe
<
opr
::
PoolingForward
>
();
mgb_assert
(
pooling
.
param
().
format
==
Format
::
NCHW
,
if
(
pooling
.
param
().
format
!=
Format
::
NCHW
)
{
"ConvertFormat Pass only support converting NCHW to NCHW4."
);
return
opr
;
}
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
auto
new_param
=
pooling
.
param
();
auto
new_param
=
pooling
.
param
();
...
@@ -1628,8 +1641,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1628,8 +1641,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using
Format
=
Param
::
Format
;
using
Format
=
Param
::
Format
;
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
resize
=
opr
->
cast_final_safe
<
opr
::
ResizeForward
>
();
auto
&
resize
=
opr
->
cast_final_safe
<
opr
::
ResizeForward
>
();
mgb_assert
(
resize
.
param
().
format
==
Format
::
NCHW
,
"ConvertFormat Pass only support converting NCHW to NCHW4."
);
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
auto
new_param
=
resize
.
param
();
auto
new_param
=
resize
.
param
();
...
@@ -1652,8 +1663,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
...
@@ -1652,8 +1663,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using
Format
=
Param
::
Format
;
using
Format
=
Param
::
Format
;
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
mgb_assert
(
opr
->
input
().
size
()
==
new_inp
.
size
());
auto
&
warp
=
opr
->
cast_final_safe
<
opr
::
WarpPerspectiveForward
>
();
auto
&
warp
=
opr
->
cast_final_safe
<
opr
::
WarpPerspectiveForward
>
();
mgb_assert
(
warp
.
param
().
format
==
Format
::
NCHW
,
"ConvertFormat Pass only support converting NCHW to NCHW4."
);
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
if
(
new_inp
[
0
]
->
shape
().
ndim
==
5
)
{
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
mgb_assert
(
new_inp
[
0
]
->
dtype
().
enumv
()
==
DTypeEnum
::
QuantizedS8
);
auto
new_param
=
warp
.
param
();
auto
new_param
=
warp
.
param
();
...
...
src/gopt/test/inference.cpp
浏览文件 @
2e70cf1d
...
@@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) {
...
@@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) {
#if MGB_CUDA
#if MGB_CUDA
TEST
(
TestEnableTensorCore
,
SmallInputShape
)
{
TEST
(
TestEnableTensorCore
,
SmallInputShape
)
{
REQUIRE_GPU
(
1
);
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
...
@@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) {
...
@@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) {
MGB_ASSERT_TENSOR_EQ
(
host_y
,
host_y_opt
);
MGB_ASSERT_TENSOR_EQ
(
host_y
,
host_y_opt
);
}
}
TEST
(
TestEnableTensorCore
,
Nchw4Nchw
)
{
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
cn
.
activate
();
auto
&&
prop
=
CompNodeEnv
::
from_comp_node
(
cn
).
cuda_env
().
device_prop
;
auto
sm_ver
=
prop
.
major
*
10
+
prop
.
minor
;
if
(
sm_ver
<
75
)
{
printf
(
"This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)
\n
"
,
sm_ver
,
75
);
return
;
}
HostTensorGenerator
<
dtype
::
Int8
>
gen
;
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
,
cn
)).
rename
(
name
),
dtype
);
};
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
,
cn
))
.
rename
(
name
),
dtype
);
};
auto
mkshape
=
[](
opr
::
ConvBias
::
Param
::
Format
format
,
size_t
N
,
size_t
C
,
size_t
H
,
size_t
W
)
->
TensorShape
{
mgb_assert
(
C
%
4
==
0
);
if
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
)
{
return
{
N
,
C
/
4
,
H
,
W
,
4
};
}
else
{
mgb_assert
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW
);
return
{
N
,
C
,
H
,
W
};
}
};
for
(
auto
format
:
{
opr
::
ConvBias
::
Param
::
Format
::
NCHW
,
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
})
{
auto
x
=
mkvar
(
"x"
,
mkshape
(
format
,
32
,
64
,
16
,
16
),
dtype
::
QuantizedS8
(
2.5
f
)),
w
=
mkcvar
(
"w1"
,
mkshape
(
format
,
64
,
64
,
3
,
3
),
dtype
::
QuantizedS8
(
2.5
f
)),
b
=
mkcvar
(
"b"
,
mkshape
(
format
,
1
,
64
,
1
,
1
),
dtype
::
QuantizedS32
(
6.25
f
)),
z
=
mkcvar
(
"b1"
,
mkshape
(
format
,
32
,
64
,
8
,
8
),
dtype
::
QuantizedS8
(
2.5
f
));
opr
::
ConvBias
::
Param
param
;
param
.
format
=
format
;
param
.
nonlineMode
=
opr
::
ConvBias
::
Param
::
NonlineMode
::
RELU
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
y
=
opr
::
ConvBias
::
make
(
x
,
w
,
b
,
z
,
param
,
{},
OperatorNodeConfig
{
dtype
::
QuantizedS8
(
2.5
f
)});
y
=
opr
::
ConvBias
::
make
(
y
,
w
,
b
,
param
,
{},
OperatorNodeConfig
{
dtype
::
QuantizedS8
(
2.5
f
)});
y
=
opr
::
TypeCvt
::
make
(
y
,
dtype
::
Float32
());
SymbolVar
y_opt
;
SymbolVar
y_no_tc
;
{
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
options
.
enable_nchw32
().
enable_fuse_conv_bias_nonlinearity
();
unpack_vector
(
gopt
::
optimize_for_inference
({
y
},
options
),
y_opt
);
}
{
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
options
.
enable_fuse_conv_bias_nonlinearity
();
unpack_vector
(
gopt
::
optimize_for_inference
({
y
},
options
),
y_no_tc
);
}
auto
nr_dimshuffle
=
find_opr_num
<
mgb
::
opr
::
Dimshuffle
>
(
y_opt
);
std
::
string
json_name
;
ASSERT_EQ
(
2u
,
nr_dimshuffle
);
if
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
)
{
json_name
=
"TestGoptInference.Nchw4Nchw.NCHW4.json"
;
}
else
{
mgb_assert
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW
);
json_name
=
"TestGoptInference.Nchw4Nchw.NCHW.json"
;
}
graph
->
compile
({{
y_opt
,
{}}})
->
to_json
()
->
writeto_fpath
(
output_file
(
json_name
.
c_str
()));
HostTensorND
host_y
,
host_y_opt
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y_no_tc
,
host_y
),
make_callback_copy
(
y_opt
,
host_y_opt
)});
func
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
host_y
,
host_y_opt
);
}
}
TEST
(
TestEnableTensorCore
,
ConvBiasWithZ
)
{
TEST
(
TestEnableTensorCore
,
ConvBiasWithZ
)
{
REQUIRE_GPU
(
1
);
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
...
@@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) {
...
@@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) {
.
rename
(
name
),
.
rename
(
name
),
dtype
);
dtype
);
};
};
auto
mkshape
=
[](
opr
::
ConvBias
::
Param
::
Format
format
,
size_t
N
,
size_t
C
,
size_t
H
,
size_t
W
)
->
TensorShape
{
mgb_assert
(
C
%
4
==
0
);
if
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
)
{
return
{
N
,
C
/
4
,
H
,
W
,
4
};
}
else
{
mgb_assert
(
format
==
opr
::
ConvBias
::
Param
::
Format
::
NCHW
);
return
{
N
,
C
,
H
,
W
};
}
};
auto
x
=
mkvar
(
"x"
,
{
32
,
16
,
16
,
16
,
4
},
dtype
::
QuantizedS8
(
2.5
f
)),
for
(
auto
format
:
{
opr
::
ConvBias
::
Param
::
Format
::
NCHW
,
w
=
mkcvar
(
"w1"
,
{
64
,
16
,
3
,
3
,
4
},
dtype
::
QuantizedS8
(
2.5
f
)),
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
})
{
b
=
mkcvar
(
"b"
,
{
1
,
16
,
1
,
1
,
4
},
dtype
::
QuantizedS32
(
6.25
f
)),
auto
x
=
mkvar
(
"x"
,
mkshape
(
format
,
32
,
64
,
16
,
16
),
b1
=
mkvar
(
"b1"
,
{
32
,
16
,
16
,
16
,
4
},
dtype
::
QuantizedS8
(
2.5
f
));
dtype
::
QuantizedS8
(
2.5
f
)),
opr
::
ConvBias
::
Param
param
;
w
=
mkcvar
(
"w1"
,
mkshape
(
format
,
64
,
64
,
3
,
3
),
param
.
format
=
opr
::
ConvBias
::
Param
::
Format
::
NCHW4
;
dtype
::
QuantizedS8
(
2.5
f
)),
param
.
stride_h
=
param
.
stride_w
=
1
;
b
=
mkcvar
(
"b"
,
mkshape
(
format
,
1
,
64
,
1
,
1
),
param
.
pad_h
=
param
.
pad_w
=
1
;
dtype
::
QuantizedS32
(
6.25
f
)),
param
.
nonlineMode
=
opr
::
ConvBias
::
Param
::
NonlineMode
::
RELU
;
b1
=
mkvar
(
"b1"
,
mkshape
(
format
,
32
,
64
,
16
,
16
),
dtype
::
QuantizedS8
(
2.5
f
));
opr
::
ConvBias
::
Param
param
;
param
.
format
=
format
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
pad_h
=
param
.
pad_w
=
1
;
param
.
nonlineMode
=
opr
::
ConvBias
::
Param
::
NonlineMode
::
RELU
;
auto
y
=
opr
::
ConvBiasForward
::
make
(
auto
y
=
opr
::
ConvBiasForward
::
make
(
x
,
w
,
b
,
param
,
{},
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
x
,
w
,
b
,
param
,
{},
auto
y1
=
opr
::
ElemwiseMultiType
::
make
(
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
{
y
,
b1
},
opr
::
ElemwiseMultiType
::
Mode
::
QFUSE_ADD_RELU
,
auto
y1
=
opr
::
ElemwiseMultiType
::
make
(
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
{
y
,
b1
},
opr
::
ElemwiseMultiType
::
Mode
::
QFUSE_ADD_RELU
,
auto
y2
=
opr
::
ConvBiasForward
::
make
(
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
y
,
w
,
b
,
param
,
{},
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
auto
y2
=
opr
::
ConvBiasForward
::
make
(
auto
y3
=
opr
::
ElemwiseMultiType
::
make
(
y
,
w
,
b
,
param
,
{},
{
y
,
b1
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QSUB
,
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
auto
y3
=
opr
::
ElemwiseMultiType
::
make
(
auto
y4
=
opr
::
ElemwiseMultiType
::
make
(
{
y
,
b1
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QSUB
,
{
y1
,
y2
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QADD
,
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
auto
y4
=
opr
::
ElemwiseMultiType
::
make
(
y4
=
opr
::
ElemwiseMultiType
::
make
(
{
y1
,
y2
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QADD
,
{
y3
,
y4
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QADD
,
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
y4
=
opr
::
ElemwiseMultiType
::
make
(
y4
=
opr
::
TypeCvt
::
make
(
y4
,
dtype
::
Float32
());
{
y3
,
y4
},
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
QADD
,
SymbolVar
y_opt
;
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
SymbolVar
y_cudnn
;
y4
=
opr
::
TypeCvt
::
make
(
y4
,
dtype
::
Float32
());
{
SymbolVar
y_opt
;
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
SymbolVar
y_cudnn
;
options
.
enable_chwn4
();
{
unpack_vector
(
gopt
::
optimize_for_inference
({
y4
},
options
),
y_opt
);
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
options
.
enable_chwn4
();
unpack_vector
(
gopt
::
optimize_for_inference
({
y4
},
options
),
y_opt
);
}
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
FuseConvBiasNonlinPass
>
()
.
add_pass
<
gopt
::
FuseConvBiasZPass
>
()
.
apply
({{
y4
}})
.
endpoint_vars
(),
y_cudnn
);
ASSERT_EQ
(
opr
::
ConvBias
::
Param
::
Format
::
CHWN4
,
find_opr
<
opr
::
ConvBias
>
(
y_opt
).
param
().
format
);
HostTensorND
host_y
,
host_y_opt
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y_cudnn
,
host_y
),
make_callback_copy
(
y_opt
,
host_y_opt
)});
func
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
host_y
,
host_y_opt
);
}
}
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
FuseConvBiasNonlinPass
>
()
.
add_pass
<
gopt
::
FuseConvBiasZPass
>
()
.
apply
({{
y4
}})
.
endpoint_vars
(),
y_cudnn
);
HostTensorND
host_y
,
host_y_opt
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y_cudnn
,
host_y
),
make_callback_copy
(
y_opt
,
host_y_opt
)});
func
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
host_y
,
host_y_opt
);
}
}
TEST
(
TestGoptInference
,
EnableCHWN4WarpPespective
)
{
TEST
(
TestGoptInference
,
EnableCHWN4WarpPespective
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录