提交 fdd14e09 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb/opt): add nchw->nchw4 for tensorrt replace pass

GitOrigin-RevId: db114549be9af37287ea91314aa3c394020378fd
上级 3eb29f5e
...@@ -464,7 +464,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( ...@@ -464,7 +464,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
#if MGB_ENABLE_TENSOR_RT #if MGB_ENABLE_TENSOR_RT
if (options().graph_opt.tensorrt) { if (options().graph_opt.tensorrt) {
options().graph_opt.tensorrt = false; options().graph_opt.tensorrt = false;
tensorrt::transform_dest_vars_inplace(dest_vars); tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt);
} }
#endif #endif
......
...@@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() { ...@@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() {
} }
} }
void mgb::tensorrt::transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars) { void mgb::tensorrt::transform_dest_vars_inplace(
mgb::cg::VarNodeArray& dest_vars,
cg::GraphCommonOptimizeOptions& options) {
gopt::GraphOptimizer optimizer; gopt::GraphOptimizer optimizer;
//! As in megengine, the layout is NCHW, while tensorrt pass currently
//! only support NCHW4(int8), so we transform layout to nchw4 firstly.
if (options.has_set_nchw4()) {
options.disable_nchw4();
optimizer.add_pass<FuseConvBiasNonlinPass>();
optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter());
}
optimizer.add_pass<ExpandFusedArithPass>(); optimizer.add_pass<ExpandFusedArithPass>();
optimizer.add_pass<gopt::TensorRTReplacePass>(); optimizer.add_pass<gopt::TensorRTReplacePass>();
optimizer.add_pass<ArithFusePass>(); optimizer.add_pass<ArithFusePass>();
......
...@@ -32,7 +32,8 @@ public: ...@@ -32,7 +32,8 @@ public:
namespace tensorrt { namespace tensorrt {
void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars); void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars,
cg::GraphCommonOptimizeOptions& options);
} }
} // namespace mgb } // namespace mgb
......
...@@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) { ...@@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) {
param.stride_h = param.stride_w = 1; param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1; param.pad_h = param.pad_w = 1;
auto y = opr::Convolution::make(x, w, param); auto y = opr::Convolution::make(x, w, param);
auto nchw2nchw4 = [](SymbolVar x) { auto nchw2nchw4 = [](SymbolVar x) {
auto xshp = opr::GetVarShape::make(x); auto xshp = opr::GetVarShape::make(x);
...@@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) { ...@@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) {
MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3);
} }
TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{
1.2f, 127 * 127};
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp))
.rename(name),
dtype);
};
auto x = mkvar("x", {32, 4, 28, 28}, dtype::QuantizedS8(2.5f)),
w = mkcvar("w", {16, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(x, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto z = opr::TypeCvt::make(y, dtype::Float32());
SymbolVar trt_z;
SymbolVar mgb_z;
ComputingGraph::Options opt;
opt.graph_opt_level = 0;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass(gopt::EnableNCHW4Pass::make_nchw4_converter())
.add_pass<gopt::ExpandFusedArithPass>()
.add_pass<gopt::TensorRTReplacePass>()
.add_pass<gopt::ArithFusePass>()
.apply({{z}})
.endpoint_vars(),
trt_z);
opt.graph_opt_level = 0;
unpack_vector(gopt::GraphOptimizer{}.apply({{z}}).endpoint_vars(),
mgb_z);
ComputingGraph::OutputSpec outspec(2);
SmallVector<HostTensorND> outputs(2);
outspec[0] = make_callback_copy(trt_z, outputs[0], false);
outspec[1] = make_callback_copy(mgb_z, outputs[1], false);
graph->options().graph_opt.tensorrt = false;
auto func = graph->compile(outspec);
func->execute();
MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3);
}
#endif // MGB_ENABLE_TENSOR_RT #endif // MGB_ENABLE_TENSOR_RT
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册