From 69793a27e92888ec57dc2f7feeb443e39f5914ea Mon Sep 17 00:00:00 2001
From: Leo Chen <39020268+leo0519@users.noreply.github.com>
Date: Fri, 11 Feb 2022 20:46:52 +0800
Subject: [PATCH] Add TensorRT inspector into Paddle-TRT (#38362)

---
 AUTHORS.md                                     |  1 +
 paddle/fluid/inference/analysis/argument.h     |  1 +
 .../inference/analysis/ir_pass_manager.cc      |  1 +
 .../ir_passes/tensorrt_subgraph_pass.cc        |  2 +
 paddle/fluid/inference/api/analysis_config.cc  |  3 +
 .../fluid/inference/api/analysis_predictor.cc  |  1 +
 .../inference/api/paddle_analysis_config.h     |  4 +
 paddle/fluid/inference/tensorrt/engine.cc      | 22 ++++-
 paddle/fluid/inference/tensorrt/engine.h       | 16 ++--
 .../operators/tensorrt/tensorrt_engine_op.h    |  6 +-
 paddle/fluid/pybind/inference_api.cc           |  4 +
 .../unittests/ir/inference/CMakeLists.txt      |  1 +
 .../ir/inference/inference_pass_test.py        | 16 +++-
 .../ir/inference/test_trt_inspector.py         | 82 +++++++++++++++++++
 14 files changed, 143 insertions(+), 17 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_inspector.py

diff --git a/AUTHORS.md b/AUTHORS.md
index 60f5b424abb..e5481d83de1 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -83,3 +83,4 @@
 | jeng1220 | Bai-Cheng(Ryan) Jeng (NVIDIA) |
 | mingxu1067 | Ming Huang (NVIDIA) |
 | zlsh80826 | Reese Wang (NVIDIA) |
+| leo0519 | Leo Chen (NVIDIA) |
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index febfdec0b5c..f474ccd260e 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -219,6 +219,7 @@ struct Argument {
                       bool);
   DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
                       TensorRtAllowBuildAtRuntime, bool);
+  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
 
   DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
   DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 3abda782ab6..837b83004de 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -156,6 +156,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
       pass->Set("use_static_engine", new bool(use_static_engine));
       pass->Set("model_from_memory", new bool(argument->model_from_memory()));
+      pass->Set("use_inspector", new bool(argument->tensorrt_use_inspector()));
 
       // tuned trt dynamic_shape
       pass->Set("trt_shape_range_info_path",
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 55bbc554508..904baebcb0b 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -265,6 +265,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   op_desc->SetAttr("parameters", params);
   op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime);
   op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
+  op_desc->SetAttr("use_inspector", Get<bool>("use_inspector"));
 
   // we record all inputs' shapes in attr to check if they are consistent
   // with the real inputs' shapes retrieved from scope when trt runs.
@@ -375,6 +376,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
   trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
   trt_engine->SetDLACore(Get<int>("trt_dla_core"));
+  trt_engine->SetUseInspector(Get<bool>("use_inspector"));
   trt_engine->SetWithErnie(
       graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 27369071933..57e49733b32 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -194,6 +194,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(trt_allow_build_at_runtime_);
   CP_MEMBER(collect_shape_range_info_);
   CP_MEMBER(shape_range_info_path_);
+  CP_MEMBER(trt_use_inspector_);
   // Dlnne related
   CP_MEMBER(use_dlnne_);
   CP_MEMBER(dlnne_min_subgraph_size_);
@@ -427,6 +428,8 @@ void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
   trt_dla_core_ = dla_core;
 }
 
+void AnalysisConfig::EnableTensorRtInspector() { trt_use_inspector_ = true; }
+
 void AnalysisConfig::Exp_DisableTensorRtOPs(
     const std::vector<std::string> &ops) {
   trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 628d974c123..694e84ad756 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -615,6 +615,7 @@ void AnalysisPredictor::PrepareArgument() {
         config_.tuned_tensorrt_dynamic_shape());
     argument_.SetTensorRtAllowBuildAtRuntime(
         config_.trt_allow_build_at_runtime());
+    argument_.SetTensorRtUseInspector(config_.trt_use_inspector_);
   }
 
   if (config_.dlnne_enabled()) {
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index f65170daccb..4b13ca073bc 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -521,6 +521,9 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   bool tensorrt_dla_enabled() { return trt_use_dla_; }
 
+  void EnableTensorRtInspector();
+  bool tensorrt_inspector_enabled() { return trt_use_inspector_; }
+
   void EnableDlnne(int min_subgraph_size = 3);
   bool dlnne_enabled() const { return use_dlnne_; }
 
@@ -807,6 +810,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool trt_allow_build_at_runtime_{false};
   // tune to get dynamic_shape info.
   bool trt_tuned_dynamic_shape_{false};
+  bool trt_use_inspector_{false};
 
   // In CollectShapeInfo mode, we will collect the shape information of
   // all intermediate tensors in the compute graph and calculate the
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index aa69463674f..794475dfc10 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -57,7 +57,6 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
   } else {
 #if IS_TRT_VERSION_GE(6000)
     infer_context->enqueueV2(buffers->data(), stream, nullptr);
-    GetEngineInfo();
 #endif
   }
   SetRuntimeBatch(batch_size);
@@ -244,8 +243,10 @@ void TensorRTEngine::FreezeNetwork() {
 #endif
   }
 #if IS_TRT_VERSION_GE(8200)
-  infer_builder_config_->setProfilingVerbosity(
-      nvinfer1::ProfilingVerbosity::kDETAILED);
+  if (use_inspector_) {
+    infer_builder_config_->setProfilingVerbosity(
+        nvinfer1::ProfilingVerbosity::kDETAILED);
+  }
 #endif
 
 #if IS_TRT_VERSION_LT(8000)
@@ -411,6 +412,21 @@ void TensorRTEngine::freshDeviceId() {
   platform::SetDeviceId(device_id_);
 }
 
+void TensorRTEngine::GetEngineInfo() {
+#if IS_TRT_VERSION_GE(8200)
+  LOG(INFO) << "====== engine info ======";
+  std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
+      infer_engine_->createEngineInspector());
+  auto infer_context = context();
+  infer_inspector->setExecutionContext(infer_context);
+  LOG(INFO) << infer_inspector->getEngineInformation(
+      nvinfer1::LayerInformationFormat::kONELINE);
+  LOG(INFO) << "====== engine info end ======";
+#else
+  LOG(INFO) << "Inspector needs TensorRT version 8.2 and after.";
+#endif
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 1f90ff216ad..b2764ca61c1 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -580,17 +580,10 @@ class TensorRTEngine {
   }
 
   void SetProfileNum(int num) { max_profile_num_ = num; }
-  void GetEngineInfo() {
-#if IS_TRT_VERSION_GE(8200)
-    std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
-        infer_engine_->createEngineInspector());
-    infer_inspector->setExecutionContext(context());
-    VLOG(3) << infer_inspector->getEngineInformation(
-        nvinfer1::LayerInformationFormat::kJSON);
-#else
-    VLOG(3) << "Inspector needs TensorRT version 8.2 and after.";
-#endif
-  }
+
+  void GetEngineInfo();
+
+  void SetUseInspector(bool use_inspector) { use_inspector_ = use_inspector; }
 
  private:
   // Each ICudaEngine object is bound to a specific GPU when it is instantiated,
@@ -664,6 +657,7 @@ class TensorRTEngine {
   std::vector<std::unique_ptr<plugin::DynamicPluginTensorRT>> owned_pluginv2_;
 #endif
   std::mutex mutex_;
+  bool use_inspector_;
 };  // class TensorRTEngine
 
 // Add a layer__ into engine__ with args ARGS.
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index 9357eb4b229..f67295a5dbf 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -140,6 +140,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
   bool enable_int8_;
   bool enable_fp16_;
   bool use_calib_mode_;
+  bool use_inspector_;
   std::string calibration_data_;
   std::string engine_key_;
   std::string calibration_engine_key_;
@@ -175,6 +176,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
     shape_range_info_path_ = Attr<std::string>("shape_range_info_path");
     allow_build_at_runtime_ = Attr<bool>("allow_build_at_runtime");
     use_static_engine_ = Attr<bool>("use_static_engine");
+    use_inspector_ = HasAttr("use_inspector") && Attr<bool>("use_inspector");
     if (use_static_engine_) {
       model_opt_cache_dir_ = Attr<std::string>("model_opt_cache_dir");
     }
@@ -285,6 +287,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
       return;
     }
     auto *trt_engine = GetEngine(scope, dev_place);
+    if (use_inspector_) {
+      trt_engine->GetEngineInfo();
+    }
     if (trt_engine->with_dynamic_shape()) {
       // get runtime input shapes.
       std::map<std::string, std::vector<int32_t>> runtime_input_shape;
@@ -331,7 +336,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
       anc = &scope;
     }
     PrepareTRTEngine(*anc, trt_engine);
-
     // update shape_range_info_pbtxt
     if (!shape_range_info_path_.empty()) {
       inference::UpdateShapeRangeInfo(
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 5b788caeb12..eafd5baab7d 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -615,6 +615,10 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
            py::arg("dla_core") = 0)
       .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
+      .def("enable_tensorrt_inspector",
+           &AnalysisConfig::EnableTensorRtInspector)
+      .def("tensorrt_inspector_enabled",
+           &AnalysisConfig::tensorrt_inspector_enabled)
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
       .def("enable_dlnne", &AnalysisConfig::EnableDlnne,
            py::arg("min_subgraph_size") = 3)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index e3680104251..94ffbe1fe1a 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -75,6 +75,7 @@ set_tests_properties(test_trt_activation_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120)
 #set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
 set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
+set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60)
 if(WITH_NV_JETSON)
   set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 450)
   set_tests_properties(test_trt_pool3d_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 450)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
index b5a3e1a257e..20d9b9d972d 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
@@ -122,6 +122,11 @@ class InferencePassTest(unittest.TestCase):
                 self.trt_parameters.precision,
                 self.trt_parameters.use_static,
                 self.trt_parameters.use_calib_mode)
+            if self.trt_parameters.use_inspector:
+                config.enable_tensorrt_inspector()
+                self.assertTrue(
+                    config.tensorrt_inspector_enabled(),
+                    "The inspector option is not set correctly.")
 
             if self.dynamic_shape_params:
                 config.set_trt_dynamic_shape_info(
@@ -244,14 +249,21 @@ class InferencePassTest(unittest.TestCase):
         Prepare TensorRT subgraph engine parameters.
         '''
 
-        def __init__(self, workspace_size, max_batch_size, min_subgraph_size,
-                     precision, use_static, use_calib_mode):
+        def __init__(self,
+                     workspace_size,
+                     max_batch_size,
+                     min_subgraph_size,
+                     precision,
+                     use_static,
+                     use_calib_mode,
+                     use_inspector=False):
             self.workspace_size = workspace_size
             self.max_batch_size = max_batch_size
             self.min_subgraph_size = min_subgraph_size
             self.precision = precision
             self.use_static = use_static
             self.use_calib_mode = use_calib_mode
+            self.use_inspector = use_inspector
 
     class DynamicShapeParam:
         '''
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_inspector.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_inspector.py
new file mode 100644
index 00000000000..3d4b2dc10c2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_inspector.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import threading
+import time
+import unittest
+import numpy as np
+from inference_pass_test import InferencePassTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.core import PassVersionChecker
+from paddle.fluid.core import AnalysisConfig
+import subprocess
+
+
+class TensorRTInspectorTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(name="data", shape=[1, 16, 16], dtype="float32")
+            matmul_out = fluid.layers.matmul(
+                x=data,
+                y=data,
+                transpose_x=self.transpose_x,
+                transpose_y=self.transpose_y,
+                alpha=self.alpha)
+            out = fluid.layers.batch_norm(matmul_out, is_test=True)
+
+        self.feeds = {"data": np.ones([1, 16, 16]).astype("float32"), }
+        self.enable_trt = True
+        self.trt_parameters = InferencePassTest.TensorRTParam(
+            1 << 30, 1, 0, AnalysisConfig.Precision.Float32, False, False,
+            True)
+        self.fetch_list = [out]
+
+    def set_params(self):
+        self.transpose_x = True
+        self.transpose_y = True
+        self.alpha = 2.0
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            build_engine = subprocess.run(
+                [sys.executable, 'test_trt_inspector.py', '--build-engine'],
+                stderr=subprocess.PIPE)
+            engine_info = build_engine.stderr.decode('ascii')
+            trt_compile_version = paddle.inference.get_trt_compile_version()
+            trt_runtime_version = paddle.inference.get_trt_runtime_version()
+            valid_version = (8, 2, 0)
+            if trt_compile_version >= valid_version and trt_runtime_version >= valid_version:
+                self.assertTrue('====== engine info ======' in engine_info)
+                self.assertTrue('====== engine info end ======' in engine_info)
+                self.assertTrue('matmul' in engine_info)
+                self.assertTrue('LayerType: Scale' in engine_info)
+                self.assertTrue('batch_norm' in engine_info)
+            else:
+                self.assertTrue(
+                    'Inspector needs TensorRT version 8.2 and after.' in
+                    engine_info)
+
+
+if __name__ == "__main__":
+    if '--build-engine' in sys.argv:
+        test = TensorRTInspectorTest()
+        test.setUp()
+        use_gpu = True
+        test.check_output_with_option(use_gpu)
+    else:
+        unittest.main()
-- 
GitLab
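
For reference, a minimal Python sketch (not part of the patch) of how the inspector could be enabled through the enable_tensorrt_inspector / tensorrt_inspector_enabled bindings added above. The model files, input shape, and GPU memory pool size are placeholder assumptions; with TensorRT 8.2 or newer the per-layer engine information is printed to the GLOG output between the "====== engine info ======" markers, otherwise the fallback message from engine.cc is logged.

import numpy as np
import paddle.inference as paddle_infer

# Placeholder model files; substitute a real exported inference model.
config = paddle_infer.Config("sample_model/inference.pdmodel",
                             "sample_model/inference.pdiparams")
config.enable_use_gpu(1000, 0)  # 1000 MB initial GPU memory pool, device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
# Option added by this patch: dump per-layer TensorRT engine information.
config.enable_tensorrt_inspector()
assert config.tensorrt_inspector_enabled()

predictor = paddle_infer.create_predictor(config)
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
input_handle.copy_from_cpu(np.ones([1, 16, 16], dtype="float32"))
predictor.run()  # the engine info block is logged when the TensorRT op runs

The same switch is reachable from C++ through AnalysisConfig::EnableTensorRtInspector(), since the Python method is only a pybind wrapper around it.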