diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index c1492b4ab20e04e03cb6b85b534c5a41ba68c755..9839b6099a6fcf41faef8de9cdd66ada6af6684f 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -481,7 +481,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( #if MGB_ENABLE_TENSOR_RT if (options().graph_opt.tensorrt) { options().graph_opt.tensorrt = false; - tensorrt::transform_dest_vars_inplace(dest_vars); + tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt); } #endif diff --git a/src/tensorrt/impl/opr_replace.cpp b/src/tensorrt/impl/opr_replace.cpp index 4f61be64bbdf710c94f276700b64ceee1b6e8403..b6ebb876b4c0128016c6d36d20f45e9813875333 100644 --- a/src/tensorrt/impl/opr_replace.cpp +++ b/src/tensorrt/impl/opr_replace.cpp @@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() { } } -void mgb::tensorrt::transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars) { +void mgb::tensorrt::transform_dest_vars_inplace( + mgb::cg::VarNodeArray& dest_vars, + cg::GraphCommonOptimizeOptions& options) { gopt::GraphOptimizer optimizer; + //! As in megengine, the layout is NCHW, while tensorrt pass currently + //! only support NCHW4(int8), so we transform layout to nchw4 firstly. 
+ if (options.has_set_nchw4()) { + options.disable_nchw4(); + optimizer.add_pass(); + optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter()); + } optimizer.add_pass(); optimizer.add_pass(); optimizer.add_pass(); diff --git a/src/tensorrt/include/megbrain/tensorrt/opr_replace.h b/src/tensorrt/include/megbrain/tensorrt/opr_replace.h index a285b7e8293cc85416875cfec0dfee88e699f5ee..f61a4e06ce3417e49539fcb37664729c0615c151 100644 --- a/src/tensorrt/include/megbrain/tensorrt/opr_replace.h +++ b/src/tensorrt/include/megbrain/tensorrt/opr_replace.h @@ -32,7 +32,8 @@ public: namespace tensorrt { -void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars); +void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars, + cg::GraphCommonOptimizeOptions& options); } } // namespace mgb diff --git a/src/tensorrt/test/opr_replace.cpp b/src/tensorrt/test/opr_replace.cpp index cd6a20f08051e117a36a93801677b22a7e51426d..5d03884c453f0c54eba2b88b74fdf548b5843ac4 100644 --- a/src/tensorrt/test/opr_replace.cpp +++ b/src/tensorrt/test/opr_replace.cpp @@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) { param.stride_h = param.stride_w = 1; param.pad_h = param.pad_w = 1; auto y = opr::Convolution::make(x, w, param); - + auto nchw2nchw4 = [](SymbolVar x) { auto xshp = opr::GetVarShape::make(x); @@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) { MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); } +TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) { + REQUIRE_GPU(1); + HostTensorGenerator gen{ + 1.2f, 127 * 127}; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp)) + 
+ .rename(name), + dtype); + }; + + auto x = mkvar("x", {32, 4, 28, 28}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {16, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); + auto z = opr::TypeCvt::make(y, dtype::Float32()); + + SymbolVar trt_z; + SymbolVar mgb_z; + + ComputingGraph::Options opt; + opt.graph_opt_level = 0; + unpack_vector( + gopt::GraphOptimizer{} + .add_pass() + .add_pass(gopt::EnableNCHW4Pass::make_nchw4_converter()) + .add_pass() + .add_pass() + .add_pass() + .apply({{z}}) + .endpoint_vars(), + trt_z); + + opt.graph_opt_level = 0; + unpack_vector(gopt::GraphOptimizer{}.apply({{z}}).endpoint_vars(), + mgb_z); + + ComputingGraph::OutputSpec outspec(2); + SmallVector<HostTensorND> outputs(2); + outspec[0] = make_callback_copy(trt_z, outputs[0], false); + outspec[1] = make_callback_copy(mgb_z, outputs[1], false); + graph->options().graph_opt.tensorrt = false; + auto func = graph->compile(outspec); + func->execute(); + + MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3); +} + #endif // MGB_ENABLE_TENSOR_RT // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}