MegEngine 天元 / MegEngine · Commit 45e2beea
Authored May 18, 2020 by Megvii Engine Team
feat(mgb/gopt): add nchw4 optpass
GitOrigin-RevId: 551b6b828d33916b8e0a8bec73e6d3c6abd65536
Parent: f2e1bb41
Showing 10 changed files with 554 additions and 50 deletions (+554 / -50)
python_module/megengine/_internal/__init__.py    +3   -0
python_module/megengine/jit/__init__.py          +1   -0
python_module/src/swig/misc.i                    +1   -0
sdk/load-and-run/dump_with_testcase_mge.py       +7   -0
sdk/load-and-run/src/mgblar.cpp                  +1   -0
src/core/include/megbrain/graph/cg.h             +2   -0
src/gopt/impl/framework.cpp                      +7   -0
src/gopt/impl/tensor_reformat.cpp                +393 -50
src/gopt/include/megbrain/gopt/inference.h       +13  -0
src/gopt/test/inference.cpp                      +126 -0
python_module/megengine/_internal/__init__.py
@@ -541,6 +541,7 @@ def optimize_for_inference(
    fuse_conv_bias_nonlinearity=False,
    use_nchw32=False,
    fuse_conv_bias_with_z=False,
    use_nchw4=False,
    use_nchw88=False,
    use_nchw44=False,
    use_chwn4=False,
@@ -561,6 +562,7 @@ def optimize_for_inference(
        OpenCL devices
    :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
        into one opr. This is supported only in NHWCD4 format.
    :param use_nchw4: whether to use NCHW4 tensor format.
    :param use_nchw88: whether to use NCHW88 tensor format. This maybe faster some
        times.
    :param use_nchw44: whether to use NCHW44 tensor format. This maybe faster some
@@ -588,6 +590,7 @@ def optimize_for_inference(
    layout_tranform = None
    for k, v in {
        "use_nchw4": "nchw4",
        "use_nhwcd4": "nhwcd4",
        "use_nchw32": "nchw32",
        "use_nchw88": "nchw88",
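The `use_nchw4` switch added above is threaded through `optimize_for_inference` like the existing layout options. Below is a minimal usage sketch under the assumption that `net_outputs` is a list of output `SymbolVar`s built with this module; the call site itself is not part of this commit.

```python
# Hypothetical usage of the new option; `net_outputs` is assumed to be the
# list of output SymbolVars of an already-built inference graph.
from megengine._internal import optimize_for_inference

optimized_outputs = optimize_for_inference(
    net_outputs,
    use_nchw4=True,  # relayout tensors from NCHW to NCHW4 (requires C % 4 == 0)
)
```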
python_module/megengine/jit/__init__.py
@@ -463,6 +463,7 @@ class trace:
"enable_io16xc32"
:
"f16_io_f32_comp"
,
"enable_ioc16"
:
"f16_io_comp"
,
"enable_hwcd4"
:
"use_nhwcd4"
,
"enable_nchw4"
:
"use_nchw4"
,
"enable_nchw88"
:
"use_nchw88"
,
"enable_nchw32"
:
"use_nchw32"
,
"enable_nchw44"
:
"use_nchw44"
,
...
...
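The mapping above renames the user-facing `enable_*` switches to the `use_*` keyword arguments understood by the internal `optimize_for_inference`. A self-contained sketch of that translation step follows (a hypothetical helper mirroring the table, not code from this commit).

```python
# Illustration of the enable_* -> use_* renaming performed by the table above.
_TRANSLATION = {
    "enable_nchw4": "use_nchw4",
    "enable_nchw88": "use_nchw88",
    "enable_nchw32": "use_nchw32",
    "enable_nchw44": "use_nchw44",
}

def translate_options(user_kwargs):
    """Rename recognized enable_* keys to their internal use_* counterparts."""
    return {_TRANSLATION[k]: v for k, v in user_kwargs.items() if k in _TRANSLATION}

print(translate_options({"enable_nchw4": True}))  # {'use_nchw4': True}
```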
python_module/src/swig/misc.i
@@ -80,6 +80,7 @@ struct _OptimizeForInferenceOptions {
#define SET(_trans, _trans_capital) \
    void enable_##_trans(); \
    SET(nchw4, NCHW4);
    SET(nhwcd4, NHWCD4);
    SET(nchw88, NCHW88);
    SET(nchw44, NCHW44);
sdk/load-and-run/dump_with_testcase_mge.py
@@ -252,6 +252,7 @@ def optimize_for_inference(args, outputs):
        'enable_io16xc32': 'f16_io_f32_comp',
        'enable_ioc16': 'f16_io_comp',
        'enable_hwcd4': 'use_nhwcd4',
        'enable_nchw4': 'use_nchw4',
        'enable_nchw88': 'use_nchw88',
        'enable_nchw44': 'use_nchw44',
        'enable_nchw32': 'use_nchw32',
@@ -381,6 +382,12 @@ def main():
        'for inference; you may need to disable CUDA and set '
        'MGB_USE_MEGDNN_DBG=2'
    )
    parser.add_argument(
        '--enable-nchw4',
        action='store_true',
        help='transform the model format from NCHW to NCHW4 '
        'for inference'
    )
    parser.add_argument(
        '--enable-nchw88',
        action='store_true',
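The new `--enable-nchw4` flag follows the same `store_true` pattern as the other layout switches in this script. A self-contained sketch of the flag wiring (a standalone parser for illustration, not the full `main()` of the script):

```python
# Minimal reproduction of the argument added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--enable-nchw4', action='store_true',
    help='transform the model format from NCHW to NCHW4 for inference')

args = parser.parse_args(['--enable-nchw4'])
assert args.enable_nchw4 is True
```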
sdk/load-and-run/src/mgblar.cpp
@@ -980,6 +980,7 @@ Args Args::from_argv(int argc, char **argv) {
        continue; \
    }
    cb(nchw4);
    cb(chwn4);
    cb(nchw44);
    cb(nchw88);
src/core/include/megbrain/graph/cg.h
@@ -97,6 +97,7 @@ struct GraphCommonOptimizeOptions {
    bool fuse_conv_bias_with_z = false;
    enum LayoutTransform : uint32_t {
        DEFAULT,
        NCHW4,   ///< compute using NCHW4 tensor format
        NHWCD4,  ///< compute using NHWCD4 tensor format
        NCHW88,  ///< compute using NCHW88 tensor format
        NCHW44,  ///< compute using NCHW44 tensor format
@@ -137,6 +138,7 @@ struct GraphCommonOptimizeOptions {
        return layout_transform == LayoutTransform::_trans_capital; \
    }
    SET(nchw4, NCHW4);
    SET(nhwcd4, NHWCD4);
    SET(nchw88, NCHW88);
    SET(nchw44, NCHW44);
src/gopt/impl/framework.cpp
@@ -725,6 +725,13 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
src/gopt/impl/tensor_reformat.cpp
@@ -63,10 +63,15 @@ public:
        NCHW32_TO_NCHW4,              //!< from nchw32 layout to nchw4 layout
        NCHW4_TO_CHWN4,               //!< from nchw4 layout to chwn4 layout
        CHWN4_TO_NCHW4,               //!< from chwn4 layout to nchw4 layout
        NCHW_TO_NCHW4,                //!< from nchw layout to nchw4 layout
        NCHW4_TO_NCHW,                //!< from nchw4 layout to nchw layout
        NCHW_TO_NCHW88,               //!< from nchw layout to nchw88 layout
        NCHW_TO_NCHW44,               //!< from nchw layout to nchw44 layout
        NCHW88_TO_NCHW,               //!< from nchw88 layout to nchw layout
        NCHW44_TO_NCHW,               //!< from nchw44 layout to nchw layout
        WEIGHT_NCHW_TO_NCHW4_DENSE,   //!< weight from nchw layout to nchw4
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW4_GROUP,   //!< group weight from nchw layout to
                                      //!< nchw4 layout
        WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                      //!< layout
@@ -167,6 +172,42 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
        dst[3] = inp_shape[2];
        dst[4] = inp_shape[4];
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
        dst.ndim = 5;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1] / 4;
        dst[2] = inp_shape[2];
        dst[3] = inp_shape[3];
        dst[4] = 4;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
        mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
        dst.ndim = 4;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1] * 4;
        dst[2] = inp_shape[2];
        dst[3] = inp_shape[3];
    } else if (layout_type() ==
               RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
        dst.ndim = 5;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1] / 4;
        dst[2] = inp_shape[2];
        dst[3] = inp_shape[3];
        dst[4] = 4;
    } else if (layout_type() ==
               RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP) {
        mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
        dst.ndim = 6;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1];
        dst[2] = inp_shape[2] / 4;
        dst[3] = inp_shape[3];
        dst[4] = inp_shape[4];
        dst[5] = 4;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
        dst.ndim = 5;
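The shape inference added above maps NCHW to NCHW4 by folding the channel axis into a pack of 4, and the group-weight variant does the same on the per-group input-channel axis. A NumPy sketch of the intended shape mapping (illustration only, not MegBrain code):

```python
import numpy as np

def nchw_to_nchw4_shape(shape):
    # (N, C, H, W) -> (N, C//4, H, W, 4); mirrors the mgb_assert above
    n, c, h, w = shape
    assert c % 4 == 0, "channel count must be divisible by 4"
    return (n, c // 4, h, w, 4)

def nchw4_to_nchw_shape(shape):
    # (N, C//4, H, W, 4) -> (N, C, H, W)
    n, c4, h, w, pack = shape
    assert pack == 4
    return (n, c4 * 4, h, w)

x = np.zeros((2, 8, 16, 16), dtype=np.float32)
packed = nchw_to_nchw4_shape(x.shape)
assert packed == (2, 2, 16, 16, 4)
assert nchw4_to_nchw_shape(packed) == x.shape
```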
@@ -226,23 +267,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
        dst[2] = inp_shape[3];
        dst[3] = inp_shape[1];
        dst[4] = 8;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
        dst.ndim = 5;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1] / 4;
        dst[2] = inp_shape[2];
        dst[3] = inp_shape[3];
        dst[4] = 4;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
        mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
        dst.ndim = 4;
        dst[0] = inp_shape[0];
        dst[1] = inp_shape[1] * 4;
        dst[2] = inp_shape[2];
        dst[3] = inp_shape[3];
    } else if (layout_type() ==
               RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
@@ -394,6 +418,66 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
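The `NCHW_TO_NCHW4` lambda above builds a Reshape, a Dimshuffle with pattern {0, 1, 3, 4, 2}, and a final Reshape; `NCHW4_TO_NCHW` is the inverse Dimshuffle {0, 1, 4, 2, 3} followed by a Reshape. The same data movement in NumPy, as a sketch of the relayout semantics rather than of the MegBrain operator API:

```python
import numpy as np

def nchw_to_nchw4(x):
    # split C into (C//4, 4) and move the pack axis to the end,
    # matching Reshape -> Dimshuffle{0,1,3,4,2} -> Reshape above
    n, c, h, w = x.shape
    return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

def nchw4_to_nchw(x):
    # inverse: Dimshuffle{0,1,4,2,3} followed by a merge of (C//4, 4)
    n, c4, h, w, _ = x.shape
    return x.transpose(0, 1, 4, 2, 3).reshape(n, c4 * 4, h, w)

x = np.arange(2 * 8 * 4 * 4, dtype=np.float32).reshape(2, 8, 4, 4)
assert np.array_equal(nchw4_to_nchw(nchw_to_nchw4(x)), x)
```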
@@ -492,34 +576,6 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
@@ -1239,6 +1295,293 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
    return ret;
}

/* ================ EnableNCHW4Pass ================ */
VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
                                                VarNode* orig_var) const {
    if (!orig_var->shape().eq_shape(new_var->shape())) {
        return RelayoutPlaceholder::make(
                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
                .node();
    }
    return new_var;
}

std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
    auto ret = std::make_unique<EnableNCHW4Pass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    megdnn::param::Convolution::Format conv_format =
            megdnn::param::Convolution::Format::NCHW4;
    megdnn::param::ConvBias::Format conv_bias_format =
            megdnn::param::ConvBias::Format::NCHW4;
    megdnn::param::BatchConvBias::Format batch_conv_bias_format =
            megdnn::param::BatchConvBias::Format::NCHW4;
    RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
    RelayoutMode weight_to_nchw4_mode_dense =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
    RelayoutMode weight_to_nchw4_mode_group =
            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
    auto trans_nchw4 = [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group](
                               const megdnn::param::Convolution::Sparse conv_mode,
                               const VarNode* filter) -> RelayoutMode {
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            mgb_assert(filter->shape().ndim == 4,
                       "The origin filter is not NCHW mode");
            size_t IC = filter->shape()[1];
            mgb_assert(IC % 4 == 0,
                       "The input channel should be divisible by 4");
            return weight_to_nchw4_mode_dense;
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            mgb_assert(filter->shape().ndim == 5,
                       "The origin filter if not NCHW mode");
            size_t IC = filter->shape()[2];
            mgb_assert(IC % 4 == 0,
                       "The input channel should be divisible by 4");
            return weight_to_nchw4_mode_group;
        }
    };
    auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
                                    OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHW4");
        VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
        // src: NCHW --> NCWH4
        if (new_inp[0]->shape().ndim != 5) {
            mgb_assert(new_inp[0]->shape().ndim == 4);
            auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                     src_to_nchw4_mode);
            conv_src = new_src.node();
        }
        // weight: NCHW --> NCHW4
        auto weight_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
        conv_filter = new_filter.node();
        // format: NCHW --> NCHW4
        auto new_param = conv_opr.param();
        new_param.format = conv_format;
        // dst
        auto new_conv_opr = opr::Convolution::make(
                conv_src, conv_filter, new_param,
                conv_opr.execution_policy(), conv_opr.config());
        OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
        mgb_assert(new_conv_opr.shape().ndim == 5,
                   "The conv dst dim is not trans to nchw4");
        return new_opr;
    };

    auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
                                        src_to_nchw4_mode](
                                               OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& batch_conv_bias_opr =
                opr->cast_final_safe<opr::BatchConvBiasForward>();
        mgb_assert(batch_conv_bias_opr.param().format ==
                           megdnn::param::BatchConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHW4");
        // what should be converted: src, weight
        VarNode *src = new_inp[0], *filter = new_inp[1];
        // src: NCHW --> NCHW4
        if (new_inp[0]->shape().ndim != 5) {
            mgb_assert(new_inp[0]->shape().ndim == 4);
            auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                     src_to_nchw4_mode);
            src = new_src.node();
        }
        // weight: BNCHW --> BNCHW4
        // only support dense mode, which is similar with conv->group.
        auto weight_mode =
                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
        filter = new_filter.node();
        // format: NCHW --> NCHW4
        auto new_param = batch_conv_bias_opr.param();
        new_param.format = batch_conv_bias_format;
        if (new_inp.size() == 2) {
            auto dst = opr::BatchConvBias::make(
                    src, filter, new_param,
                    batch_conv_bias_opr.execution_policy(),
                    batch_conv_bias_opr.config());
            OperatorNodeBase* new_opr = dst.node()->owner_opr();
            mgb_assert(dst.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchw4");
            return new_opr;
        }
        // bias: NCHW --> NCHW4
        VarNode* bias = new_inp[2];
        if (new_inp[2]->shape().ndim == 4) {
            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                      src_to_nchw4_mode);
            bias = new_bias.node();
        }
        if (new_inp.size() == 3) {
            auto dst = opr::BatchConvBias::make(
                    src, filter, bias, new_param,
                    batch_conv_bias_opr.execution_policy(),
                    batch_conv_bias_opr.config());
            OperatorNodeBase* new_opr = dst.node()->owner_opr();
            mgb_assert(dst.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchw4");
            return new_opr;
        }
        // z_inp: NCHW --> NCHW4
        VarNode* z_inp = new_inp[3];
        if (new_inp[3]->shape().ndim == 4) {
            auto new_z = RelayoutPlaceholder::make(new_inp[3],
                                                   src_to_nchw4_mode);
            z_inp = new_z.node();
        }
        auto dst = opr::BatchConvBias::make(
                src, filter, bias, z_inp, new_param,
                batch_conv_bias_opr.execution_policy(),
                batch_conv_bias_opr.config());
        OperatorNodeBase* new_opr = dst.node()->owner_opr();
        mgb_assert(dst.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    };

    auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
                                  src_to_nchw4_mode](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHW4");
        // what should be converted: src, weight
        VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
        // src: NCHW --> NCHW4
        if (new_inp[0]->shape().ndim != 5) {
            mgb_assert(new_inp[0]->shape().ndim == 4);
            auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                     src_to_nchw4_mode);
            conv_bias_src = new_src.node();
        }
        // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
        auto weight_mode =
                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
        conv_bias_filter = new_filter.node();
        // format: NCHW --> NCHW4
        auto new_param = conv_bias_opr.param();
        new_param.format = conv_bias_format;
        if (new_inp.size() == 2) {
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchw4");
            return new_opr;
        }
        // bias: NCHW --> NCHW4
        VarNode* conv_bias_bias = new_inp[2];
        if (new_inp[2]->shape().ndim == 4) {
            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                      src_to_nchw4_mode);
            conv_bias_bias = new_bias.node();
        }
        if (new_inp.size() == 3) {
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchw4");
            return new_opr;
        }
        // z_inp: NCHW --> NCHW4
        VarNode* z_inp = new_inp[3];
        if (new_inp[3]->shape().ndim == 4) {
            auto new_z = RelayoutPlaceholder::make(new_inp[3],
                                                   src_to_nchw4_mode);
            z_inp = new_z.node();
        }
        auto new_conv_bias_opr = opr::ConvBias::make(
                conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
                new_param, conv_bias_opr.execution_policy(),
                conv_bias_opr.config());
        OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
        mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                   "The conv_bias dst dim is not trans to nchw4");
        return new_opr;
    };

    auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool has_inp_changed = false;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (new_inp[i]->shape().ndim == 5) {
                has_inp_changed = true;
                break;
            }
        }
        if (has_inp_changed) {
            auto temp_inp = new_inp;
            for (size_t i = 0; i < opr->input().size(); i++) {
                if (new_inp[i]->shape().ndim == 4) {
                    auto new_var = RelayoutPlaceholder::make(
                            new_inp[i], src_to_nchw4_mode);
                    temp_inp[i] = new_var.node();
                } else {
                    mgb_assert((new_inp[i]->shape().ndim == 5) ||
                               new_inp[i]->shape().is_scalar());
                }
            }
            return serialization::copy_opr_shallow(*opr, temp_inp,
                                                   opr->config());
        } else {
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };

    auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray temp_inp = new_inp;
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                mgb_assert(opr->input(i)->shape().ndim == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5);
                auto new_var = RelayoutPlaceholder::make(new_inp[i],
                                                         src_to_nchw_mode);
                temp_inp[i] = new_var.node();
            }
        }
        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
    };

    auto&& replace_func = ret->m_opr_replace_func;
    //! supportted nchw4
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
    //! not supported nchw4
    replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
    return ret;
}

/* ================ EnableNchwxxPass =============== */
VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                 VarNode* orig_var) const {
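For quick reference, the `replace_func` table registered above rewrites the following operator types to NCHW4 in place and relays every other registered type back to NCHW before it executes. The listing below just mirrors the C++ registrations for scanning; it is not an executable MegBrain API:

```python
# Reference summary of the replace_func registrations above.
CONVERTED_TO_NCHW4 = [
    "Convolution", "ConvBias", "BatchConvBias",
    "Elemwise", "TypeCvt", "ElemwiseMultiType", "PowC",
]
RELAYOUT_BACK_TO_NCHW = [
    "PoolingForward", "Concat", "ConvolutionBackwardData", "Subtensor",
    "GetVarShape", "Dimshuffle", "Reduce", "AssertEqual", "IncrSubtensor",
    "ResizeForward", "WarpPerspectiveForward", "WarpAffineForward",
]
```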
@@ -1251,7 +1594,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
    } else if (m_pack_c_size == 4) {
        return RelayoutPlaceholder::make(
                       new_var,
-                      RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                      RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
                .node();
    }
}
@@ -1287,8 +1630,8 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
        weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
        weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
        hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
-       src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
-       src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+       src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
+       src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
        conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
        conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
        pooling_format = megdnn::param::Pooling::Format::NCHW44;
src/gopt/include/megbrain/gopt/inference.h
@@ -229,6 +229,19 @@ namespace gopt {
    static std::unique_ptr<EnableCHWN4Pass> make_chwn4_converter();
};

/*!
 * \brief convert tensor format to nchw4 to speed up inference on CUDA
 */
class EnableNCHW4Pass final : public TensorReformatPass {
    VarNode* on_graph_endpoint_var(VarNode* new_var,
                                   VarNode* orig_var) const override;

public:
    const char* name() const override {
        return mgb_cstr_log("tensor_format_nchw4");
    }

    //! make nchw -> nchw4 converter opt pass
    static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
};

/*!
 * \brief convert tensor format to nchwxx to speed up inference on certain
 * devices
src/gopt/test/inference.cpp
@@ -2327,8 +2327,134 @@ TEST(TestGoptInference, EnableCHWN4ShuffleRemove) {
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}

TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 61) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 61);
        return;
    }

    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkvar = [&](const char* name, const TensorShape& shp,
                     const DType& dtype) {
        return opr::TypeCvt::make(
                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
                dtype);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp,
                      const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                        .rename(name),
                dtype);
    };

    auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
    param_conv_bias.stride_h = param_conv_bias.stride_w = 1;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
    // dense
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
    auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
         b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
    auto conv1 = opr::ConvBiasForward::make(
            x, w1, b1, param_conv_bias, {},
            OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
    // group
    // icpg != 1 && ocpg != 1
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
    auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
         b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
    auto conv2 = opr::ConvBiasForward::make(
            conv1, w2, b2, param_conv_bias, {},
            OperatorNodeConfig{dtype::QuantizedS8{2.5f}});

    auto y = opr::TypeCvt::make(conv2, dtype::Float32());

    SymbolVar y_opt;
    {
        auto options = gopt::OptimizeForInferenceOptions{};
        options.enable_nchw4();
        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    }
    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
              find_opr<opr::ConvBias>(y_opt).param().format);

    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.ConvertFormatNCHW4GPU.json"));

    HostTensorND host_y, host_y_opt;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
#endif

TEST(TestGoptInference, ConvertFormatNCHW4) {
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                .rename(name);
    };

    auto x = mkvar("x", {2, 4, 16, 16});
    // ConvBias
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
    auto w1 = mkcvar("w1", {8, 4, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1});
    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
    auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1});
    auto conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
    // Convolution
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
    auto w3 = mkcvar("w3", {8, 8, 3, 3});
    auto y = opr::Convolution::make(conv2, w3, param_conv);

    SymbolVar y_opt;
    {
        auto options = gopt::OptimizeForInferenceOptions{};
        options.enable_nchw4();
        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    }
    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
              find_opr<opr::ConvBias>(y_opt).param().format);

    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(
                    output_file("TestGoptInference.ConvertFormatNCHW4.json"));

    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptInference, ConvertFormatNCHW88) {
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");