diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index aa8ebcb4930b997e5d4fc51fe2e7a534d6c75b48..1bf106ed7c1a1c88515b5af98b5cff6a93a2fbe4 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -202,6 +202,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
+  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs,
+                      std::vector<std::string>);
   DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                       AnalysisConfig::Precision);
   DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 3566b856912da86125f6b03943dbdf8356d635b3..a6466c32af80de3789e1d173a1738a5100675ec7 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -141,6 +141,8 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("optim_input_shape",
                 new std::map<std::string, std::vector<int>>(
                     argument->optim_input_shape()));
+      pass->Set("trt_disabled_ops", new std::vector<std::string>(
+                                        argument->tensorrt_disabled_ops()));
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not
       // run fp16.
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 158c834c256f59bbb89c2938b268d077181285eb..10204271c42d6dbf1a01adfdb8bc60f20ee2baf7 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -39,8 +39,15 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
   bool no_calib_int8 = enable_int8 && !(use_calib_mode);
+  auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
   auto teller = [&](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
+    if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
+             node->Op()->Type()) != trt_disabled_ops.end()) {
+      VLOG(3) << node->Op()->Type().c_str()
+              << " is disabled by config in TensorRT";
+      return false;
+    }
     return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
                                              no_calib_int8);
   };
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 92e1404b6adbfd30ba099e408d821d919d3f6e0e..fcef2a5cbc9ab96a226bef43206a0f0679571eef 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -125,6 +125,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
+  CP_MEMBER(trt_disabled_ops_);
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
   CP_MEMBER(trt_use_oss_);
@@ -304,6 +305,11 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
   disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
 }
 
+void AnalysisConfig::Exp_DisableTensorRtOPs(
+    const std::vector<std::string> &ops) {
+  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+}
+
 void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
 
 // TODO(Superjomn) refactor this, buggy.
@@ -443,6 +449,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_max_batchsize_;
   ss << tensorrt_min_subgraph_size_;
 
+  for (auto &op : trt_disabled_ops_) ss << op.c_str();
+  ss << ";";
+
   ss << enable_memory_optim_;
 
   ss << use_mkldnn_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 4603702cde1fc68ccdaddaef64e2c0042cede910..d47a9536abc63b260d82cf3c9fd5f354a993b612 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -476,6 +476,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
+    argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
     argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
     argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
     argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index edf2c323e82fbe2b736383f8e07de36c07098fb2..ccc971f99bb2bdc5d0b8884a758895e872edb654 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -313,10 +313,17 @@ struct PD_INFER_DECL AnalysisConfig {
       std::map<std::string, std::vector<int>> optim_input_shape,
       bool disable_trt_plugin_fp16 = false);
 
+  ///
+  /// \brief Prevent ops running in Paddle-TRT
+  /// NOTE: just experimental, not an official stable API, easy to be broken.
+  ///
+  void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);
+
   ///
   /// \brief Replace some TensorRT plugins to TensorRT OSS(
-  /// https://github.com/NVIDIA/TensorRT), with which some models's inference may
-  /// be more high-performance. Libnvinfer_plugin.so greater than V7.2.1 is needed.
+  /// https://github.com/NVIDIA/TensorRT), with which some models's inference
+  /// may be more high-performance. Libnvinfer_plugin.so greater than
+  /// V7.2.1 is needed.
   ///
   void EnableTensorRtOSS();
   ///
@@ -587,6 +594,7 @@ struct PD_INFER_DECL AnalysisConfig {
   std::map<std::string, std::vector<int>> min_input_shape_{};
   std::map<std::string, std::vector<int>> max_input_shape_{};
   std::map<std::string, std::vector<int>> optim_input_shape_{};
+  std::vector<std::string> trt_disabled_ops_{};
   bool disable_trt_plugin_fp16_{false};
 
   // memory reuse related.
diff --git a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
index c7c7356b6e8831bc0bcd0e9ea4ad0fbdec8b6be2..4a84a972bacadbc3ae4f2705018dca4628ddc6a9 100644
--- a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
@@ -57,6 +57,7 @@ TEST(PredictorPool, use_gpu) {
   config.EnableUseGpu(100, 0);
   config.SetModel(model_dir);
   config.EnableTensorRtEngine();
+  config.Exp_DisableTensorRtOPs({"fc"});
   services::PredictorPool pred_pool(config, 1);
   auto predictor = pred_pool.Retrive(0);
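
For reference, the test change above is the end-to-end usage of the new switch. Below is a minimal standalone sketch of the same flow, assuming the paddle_infer C++ wrappers used by that test; the include path and model directory are placeholders, not part of this patch:

```cpp
#include "paddle_inference_api.h"  // install-dependent include path (assumption)

int main() {
  paddle_infer::Config config;
  config.EnableUseGpu(100, 0);        // 100 MB initial GPU memory pool, device 0
  config.SetModel("./mobilenet_v1");  // placeholder model directory
  config.EnableTensorRtEngine();
  // Experimental: op types listed here are excluded from the TensorRT subgraph
  // and fall back to the native Paddle kernels.
  config.Exp_DisableTensorRtOPs({"fc"});

  paddle_infer::services::PredictorPool pred_pool(config, 1);
  auto* predictor = pred_pool.Retrive(0);
  (void)predictor;  // prepare inputs and call Run() as usual from here
  return 0;
}
```

Since `Exp_DisableTensorRtOPs` appends to `trt_disabled_ops_`, it can be called more than once; the excluded op types only affect the subgraph teller, so the rest of the TensorRT configuration is unchanged.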