From fdd14e09e4493acda90f084daa52afa7dba35d5c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Tue, 30 Jun 2020 15:46:23 +0800
Subject: [PATCH] feat(mgb/opt): add nchw->nchw4 for tensorrt replace pass

GitOrigin-RevId: db114549be9af37287ea91314aa3c394020378fd
---
 src/core/impl/graph/cg_impl.cpp                  |  2 +-
 src/tensorrt/impl/opr_replace.cpp                | 11 +++-
 .../include/megbrain/tensorrt/opr_replace.h      |  3 +-
 src/tensorrt/test/opr_replace.cpp                | 64 ++++++++++++++++++-
 4 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index a3b15c3c1..26e4a4dc1 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -464,7 +464,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
 #if MGB_ENABLE_TENSOR_RT
     if (options().graph_opt.tensorrt) {
         options().graph_opt.tensorrt = false;
-        tensorrt::transform_dest_vars_inplace(dest_vars);
+        tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt);
     }
 #endif
 
diff --git a/src/tensorrt/impl/opr_replace.cpp b/src/tensorrt/impl/opr_replace.cpp
index 4f61be64b..b6ebb876b 100644
--- a/src/tensorrt/impl/opr_replace.cpp
+++ b/src/tensorrt/impl/opr_replace.cpp
@@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() {
     }
 }
 
-void mgb::tensorrt::transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars) {
+void mgb::tensorrt::transform_dest_vars_inplace(
+        mgb::cg::VarNodeArray& dest_vars,
+        cg::GraphCommonOptimizeOptions& options) {
     gopt::GraphOptimizer optimizer;
+    //! As in megengine, the layout is NCHW, while tensorrt pass currently
+    //! only support NCHW4(int8), so we transform layout to nchw4 firstly.
+    if (options.has_set_nchw4()) {
+        options.disable_nchw4();
+        optimizer.add_pass<FuseConvBiasNonlinPass>();
+        optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter());
+    }
     optimizer.add_pass<ExpandFusedArithPass>();
     optimizer.add_pass<gopt::TensorRTReplacePass>();
     optimizer.add_pass<ArithFusePass>();
diff --git a/src/tensorrt/include/megbrain/tensorrt/opr_replace.h b/src/tensorrt/include/megbrain/tensorrt/opr_replace.h
index a285b7e82..f61a4e06c 100644
--- a/src/tensorrt/include/megbrain/tensorrt/opr_replace.h
+++ b/src/tensorrt/include/megbrain/tensorrt/opr_replace.h
@@ -32,7 +32,8 @@ public:
 
 namespace tensorrt {
 
-void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars);
+void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars,
+                                 cg::GraphCommonOptimizeOptions& options);
 
 }
 } // namespace mgb
diff --git a/src/tensorrt/test/opr_replace.cpp b/src/tensorrt/test/opr_replace.cpp
index cd6a20f08..5d03884c4 100644
--- a/src/tensorrt/test/opr_replace.cpp
+++ b/src/tensorrt/test/opr_replace.cpp
@@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) {
     param.stride_h = param.stride_w = 1;
     param.pad_h = param.pad_w = 1;
     auto y = opr::Convolution::make(x, w, param);
-    
+
     auto nchw2nchw4 = [](SymbolVar x) {
         auto xshp = opr::GetVarShape::make(x);
 
@@ -1978,6 +1978,68 @@
     MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3);
 }
 
+TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{
+            1.2f, 127 * 127};
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp))
+                        .rename(name),
+                dtype);
+    };
+
+    auto x = mkvar("x", {32, 4, 28, 28}, dtype::QuantizedS8(2.5f)),
+         w = mkcvar("w", {16, 4, 3, 3},
+                    dtype::QuantizedS8(2.5f)),
+         b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.stride_h = param.stride_w = 1;
+    param.pad_h = param.pad_w = 1;
+    auto y = opr::ConvBias::make(x, w, b, param, {},
+                                 OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
+    auto z = opr::TypeCvt::make(y, dtype::Float32());
+
+    SymbolVar trt_z;
+    SymbolVar mgb_z;
+
+    ComputingGraph::Options opt;
+    opt.graph_opt_level = 0;
+    unpack_vector(
+            gopt::GraphOptimizer{}
+                    .add_pass<gopt::FuseConvBiasNonlinPass>()
+                    .add_pass(gopt::EnableNCHW4Pass::make_nchw4_converter())
+                    .add_pass<gopt::ExpandFusedArithPass>()
+                    .add_pass<gopt::TensorRTReplacePass>()
+                    .add_pass<gopt::ArithFusePass>()
+                    .apply({{z}})
+                    .endpoint_vars(),
+            trt_z);
+
+    opt.graph_opt_level = 0;
+    unpack_vector(gopt::GraphOptimizer{}.apply({{z}}).endpoint_vars(),
+                  mgb_z);
+
+    ComputingGraph::OutputSpec outspec(2);
+    SmallVector<HostTensorND> outputs(2);
+    outspec[0] = make_callback_copy(trt_z, outputs[0], false);
+    outspec[1] = make_callback_copy(mgb_z, outputs[1], false);
+    graph->options().graph_opt.tensorrt = false;
+    auto func = graph->compile(outspec);
+    func->execute();
+
+    MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3);
+}
+
 #endif  // MGB_ENABLE_TENSOR_RT
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
-- 
GitLab