From 05b27695f10dd2ff2e5214ecf8fe84864167dcd5 Mon Sep 17 00:00:00 2001
From: Shang Zhizhou
Date: Wed, 6 Jan 2021 15:54:40 +0800
Subject: [PATCH] add inference api: DisableTensorRtOps (#30109)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* snap
* add inference api: DisableTensorRtOPs
* fix code style
* update api to experimental
* update variable name
---
 paddle/fluid/inference/analysis/argument.h           |  2 ++
 paddle/fluid/inference/analysis/ir_pass_manager.cc   |  2 ++
 .../analysis/ir_passes/tensorrt_subgraph_pass.cc     |  7 +++++++
 paddle/fluid/inference/api/analysis_config.cc        |  9 +++++++++
 paddle/fluid/inference/api/analysis_predictor.cc     |  1 +
 paddle/fluid/inference/api/paddle_analysis_config.h  | 12 ++++++++++--
 .../fluid/inference/tests/api/trt_mobilenet_test.cc  |  1 +
 7 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index aa8ebcb493..1bf106ed7c 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -202,6 +202,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
+  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs,
+                      std::vector<std::string>);
   DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                       AnalysisConfig::Precision);
   DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 3566b85691..a6466c32af 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -141,6 +141,8 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("optim_input_shape",
                 new std::map<std::string, std::vector<int>>(
                     argument->optim_input_shape()));
+      pass->Set("trt_disabled_ops", new std::vector<std::string>(
+                                        argument->tensorrt_disabled_ops()));
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not
       // run fp16.
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index c84bba33be..61117cc603 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -39,8 +39,15 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
   bool no_calib_int8 = enable_int8 && !(use_calib_mode);
+  auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
   auto teller = [&](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
+    if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
+             node->Op()->Type()) != trt_disabled_ops.end()) {
+      VLOG(3) << node->Op()->Type().c_str()
+              << " is disabled by config in TensorRT";
+      return false;
+    }
     return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
                                              no_calib_int8);
   };
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 92e1404b6a..fcef2a5cbc 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -125,6 +125,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
+  CP_MEMBER(trt_disabled_ops_);
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
   CP_MEMBER(trt_use_oss_);
@@ -304,6 +305,11 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
   disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
 }
 
+void AnalysisConfig::Exp_DisableTensorRtOPs(
+    const std::vector<std::string> &ops) {
+  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+}
+
 void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
 
 // TODO(Superjomn) refactor this, buggy.
@@ -443,6 +449,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_max_batchsize_;
   ss << tensorrt_min_subgraph_size_;
 
+  for (auto &op : trt_disabled_ops_) ss << op.c_str();
+  ss << ";";
+
   ss << enable_memory_optim_;
 
   ss << use_mkldnn_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 4603702cde..d47a9536ab 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -476,6 +476,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
+    argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
     argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
     argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
     argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index edf2c323e8..ccc971f99b 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -313,10 +313,17 @@ struct PD_INFER_DECL AnalysisConfig {
       std::map<std::string, std::vector<int>> optim_input_shape,
       bool disable_trt_plugin_fp16 = false);
 
+  ///
+  /// \brief Prevent ops from running in Paddle-TRT.
+  /// NOTE: experimental, not an official stable API; it may change at any time.
+  ///
+  void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);
+
   ///
   /// \brief Replace some TensorRT plugins to TensorRT OSS(
-  /// https://github.com/NVIDIA/TensorRT), with which some models's inference may
-  /// be more high-performance. Libnvinfer_plugin.so greater than V7.2.1 is needed.
+  /// https://github.com/NVIDIA/TensorRT), with which some models' inference
+  /// may achieve higher performance. libnvinfer_plugin.so newer than
+  /// v7.2.1 is required.
   ///
   void EnableTensorRtOSS();
   ///
@@ -587,6 +594,7 @@ struct PD_INFER_DECL AnalysisConfig {
   std::map<std::string, std::vector<int>> min_input_shape_{};
   std::map<std::string, std::vector<int>> max_input_shape_{};
   std::map<std::string, std::vector<int>> optim_input_shape_{};
+  std::vector<std::string> trt_disabled_ops_{};
   bool disable_trt_plugin_fp16_{false};
 
   // memory reuse related.
diff --git a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
index c7c7356b6e..4a84a972ba 100644
--- a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
@@ -57,6 +57,7 @@ TEST(PredictorPool, use_gpu) {
   config.EnableUseGpu(100, 0);
   config.SetModel(model_dir);
   config.EnableTensorRtEngine();
+  config.Exp_DisableTensorRtOPs({"fc"});
   services::PredictorPool pred_pool(config, 1);
   auto predictor = pred_pool.Retrive(0);
-- 
GitLab
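
For quick reference, the snippet below sketches how a caller would use the new Exp_DisableTensorRtOPs API, mirroring the trt_mobilenet_test.cc hunk above. Only the config calls come from this patch; the include path, the CreatePaddlePredictor() call, and the model directory are assumptions for illustration.

// Minimal usage sketch, assuming the public C++ inference API header.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.EnableUseGpu(100, 0);            // 100 MB initial GPU memory, device 0
  config.SetModel("./mobilenet_model");   // placeholder model directory
  config.EnableTensorRtEngine();          // enable the Paddle-TRT subgraph engine
  config.Exp_DisableTensorRtOPs({"fc"});  // keep "fc" ops out of TensorRT subgraphs
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}

Ops listed in Exp_DisableTensorRtOPs are rejected by the subgraph teller, so they stay on the regular Paddle execution path even when the TensorRT engine is enabled.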