Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cb5c5fd5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cb5c5fd5
编写于
6月 30, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opt): add nchw->nchw4 for tensorrt replace pass
GitOrigin-RevId: db114549be9af37287ea91314aa3c394020378fd
上级
2e70cf1d
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
76 addition
and
4 deletion
+76
-4
src/core/impl/graph/cg_impl.cpp
src/core/impl/graph/cg_impl.cpp
+1
-1
src/tensorrt/impl/opr_replace.cpp
src/tensorrt/impl/opr_replace.cpp
+10
-1
src/tensorrt/include/megbrain/tensorrt/opr_replace.h
src/tensorrt/include/megbrain/tensorrt/opr_replace.h
+2
-1
src/tensorrt/test/opr_replace.cpp
src/tensorrt/test/opr_replace.cpp
+63
-1
未找到文件。
src/core/impl/graph/cg_impl.cpp
浏览文件 @
cb5c5fd5
...
...
@@ -481,7 +481,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
#if MGB_ENABLE_TENSOR_RT
if
(
options
().
graph_opt
.
tensorrt
)
{
options
().
graph_opt
.
tensorrt
=
false
;
tensorrt
::
transform_dest_vars_inplace
(
dest_vars
);
tensorrt
::
transform_dest_vars_inplace
(
dest_vars
,
options
().
graph_opt
);
}
#endif
...
...
src/tensorrt/impl/opr_replace.cpp
浏览文件 @
cb5c5fd5
...
...
@@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() {
}
}
void
mgb
::
tensorrt
::
transform_dest_vars_inplace
(
mgb
::
cg
::
VarNodeArray
&
dest_vars
)
{
void
mgb
::
tensorrt
::
transform_dest_vars_inplace
(
mgb
::
cg
::
VarNodeArray
&
dest_vars
,
cg
::
GraphCommonOptimizeOptions
&
options
)
{
gopt
::
GraphOptimizer
optimizer
;
//! As in megengine, the layout is NCHW, while tensorrt pass currently
//! only support NCHW4(int8), so we transform layout to nchw4 firstly.
if
(
options
.
has_set_nchw4
())
{
options
.
disable_nchw4
();
optimizer
.
add_pass
<
FuseConvBiasNonlinPass
>
();
optimizer
.
add_pass
(
EnableNCHW4Pass
::
make_nchw4_converter
());
}
optimizer
.
add_pass
<
ExpandFusedArithPass
>
();
optimizer
.
add_pass
<
gopt
::
TensorRTReplacePass
>
();
optimizer
.
add_pass
<
ArithFusePass
>
();
...
...
src/tensorrt/include/megbrain/tensorrt/opr_replace.h
浏览文件 @
cb5c5fd5
...
...
@@ -32,7 +32,8 @@ public:
namespace
tensorrt
{
void
transform_dest_vars_inplace
(
mgb
::
cg
::
VarNodeArray
&
dest_vars
);
void
transform_dest_vars_inplace
(
mgb
::
cg
::
VarNodeArray
&
dest_vars
,
cg
::
GraphCommonOptimizeOptions
&
options
);
}
}
// namespace mgb
...
...
src/tensorrt/test/opr_replace.cpp
浏览文件 @
cb5c5fd5
...
...
@@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) {
MGB_ASSERT_TENSOR_NEAR
(
outputs
[
1
],
outputs
[
3
],
1e-3
);
}
TEST
(
TestTensorRTReplace
,
FuseConvAddNchw2nchw4
)
{
REQUIRE_GPU
(
1
);
HostTensorGenerator
<
dtype
::
Float32
,
RandomDistribution
::
UNIFORM
>
gen
{
1.2
f
,
127
*
127
};
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
mkvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
),
dtype
);
};
auto
mkcvar
=
[
&
](
const
char
*
name
,
const
TensorShape
&
shp
,
const
DType
&
dtype
)
{
return
opr
::
TypeCvt
::
make
(
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
))
.
rename
(
name
),
dtype
);
};
auto
x
=
mkvar
(
"x"
,
{
32
,
4
,
28
,
28
},
dtype
::
QuantizedS8
(
2.5
f
)),
w
=
mkcvar
(
"w"
,
{
16
,
4
,
3
,
3
},
dtype
::
QuantizedS8
(
2.5
f
)),
b
=
mkcvar
(
"b"
,
{
1
,
16
,
1
,
1
},
dtype
::
QuantizedS32
(
6.25
f
));
opr
::
ConvBias
::
Param
param
;
param
.
format
=
opr
::
ConvBias
::
Param
::
Format
::
NCHW
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
y
=
opr
::
ConvBias
::
make
(
x
,
w
,
b
,
param
,
{},
OperatorNodeConfig
{
dtype
::
QuantizedS8
{
2.5
f
}});
auto
z
=
opr
::
TypeCvt
::
make
(
y
,
dtype
::
Float32
());
SymbolVar
trt_z
;
SymbolVar
mgb_z
;
ComputingGraph
::
Options
opt
;
opt
.
graph_opt_level
=
0
;
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
FuseConvBiasNonlinPass
>
()
.
add_pass
(
gopt
::
EnableNCHW4Pass
::
make_nchw4_converter
())
.
add_pass
<
gopt
::
ExpandFusedArithPass
>
()
.
add_pass
<
gopt
::
TensorRTReplacePass
>
()
.
add_pass
<
gopt
::
ArithFusePass
>
()
.
apply
({{
z
}})
.
endpoint_vars
(),
trt_z
);
opt
.
graph_opt_level
=
0
;
unpack_vector
(
gopt
::
GraphOptimizer
{}.
apply
({{
z
}}).
endpoint_vars
(),
mgb_z
);
ComputingGraph
::
OutputSpec
outspec
(
2
);
SmallVector
<
HostTensorND
>
outputs
(
2
);
outspec
[
0
]
=
make_callback_copy
(
trt_z
,
outputs
[
0
],
false
);
outspec
[
1
]
=
make_callback_copy
(
mgb_z
,
outputs
[
1
],
false
);
graph
->
options
().
graph_opt
.
tensorrt
=
false
;
auto
func
=
graph
->
compile
(
outspec
);
func
->
execute
();
MGB_ASSERT_TENSOR_NEAR
(
outputs
[
0
],
outputs
[
1
],
1e-3
);
}
#endif // MGB_ENABLE_TENSOR_RT
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录