Unverified · Commit 05b27695 · Authored by Shang Zhizhou · Committed by GitHub

add inference api: DisableTensorRtOps (#30109)

* snap

* add inference api: DisableTensorRtOPs

* fix code style

* update api to experimental

* update variable name
Parent 53bb1265
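For readers landing on this commit: the new `Exp_DisableTensorRtOPs` call keeps the listed op types out of the Paddle-TRT subgraph so they fall back to native Paddle kernels. Below is a minimal usage sketch pieced together from the calls that appear in this diff; the header name, namespace aliases, and model path are assumptions for illustration and are not verified against a specific Paddle release.

```cpp
#include <string>

// Assumed public inference header; adjust to your installed Paddle release.
#include "paddle_inference_api.h"

int main() {
  std::string model_dir = "./my_model";  // hypothetical model directory

  paddle_infer::Config config;      // assumed alias of paddle::AnalysisConfig
  config.EnableUseGpu(100, 0);      // 100 MB initial GPU memory on device 0
  config.SetModel(model_dir);
  config.EnableTensorRtEngine();    // enable the Paddle-TRT subgraph engine

  // Experimental API added in this PR: keep "fc" ops out of TensorRT so they
  // run on native Paddle kernels instead of being fused into a TRT engine.
  config.Exp_DisableTensorRtOPs({"fc"});

  paddle_infer::services::PredictorPool pred_pool(config, 1);
  auto predictor = pred_pool.Retrive(0);  // "Retrive" is the spelling used by the API
  (void)predictor;
  return 0;
}
```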
......@@ -202,6 +202,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
......
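`DECL_ARGUMENT_FIELD` registers the new `tensorrt_disabled_ops` field on `Argument` together with generated accessors; the setter and getter used later in this diff (`SetTensorRtDisabledOPs`, `tensorrt_disabled_ops()`) come from it. Purely as an illustration of the pattern, and not the actual macro expansion, a hand-written equivalent might look like this:

```cpp
#include <string>
#include <vector>

// Illustrative stand-in for what
// DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs,
//                     std::vector<std::string>) conceptually provides.
class ArgumentSketch {
 public:
  // Setter used by AnalysisPredictor::PrepareArgument().
  void SetTensorRtDisabledOPs(const std::vector<std::string>& ops) {
    tensorrt_disabled_ops_ = ops;
  }
  // Getter used by IRPassManager when populating pass attributes.
  const std::vector<std::string>& tensorrt_disabled_ops() const {
    return tensorrt_disabled_ops_;
  }

 private:
  std::vector<std::string> tensorrt_disabled_ops_;
};
```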
......@@ -141,6 +141,8 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("optim_input_shape",
new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
pass->Set("trt_disabled_ops", new std::vector<std::string>(
argument->tensorrt_disabled_ops()));
// Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
// not
// run fp16.
......
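In `IRPassManager::CreatePasses`, the disabled-op list is handed to the TensorRT subgraph pass as a pass attribute: the heap-allocated vector passed to `Set` is managed by the pass from then on (as the `new` in the diff suggests), and the pass reads it back with `Get<std::vector<std::string>>("trt_disabled_ops")` in the next hunk. A rough, self-contained illustration of that ownership hand-off; this is not Paddle's actual `ir::Pass` implementation:

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Sketch of an attribute map: Set() takes ownership of a heap-allocated value,
// Get<T>() returns a reference to it.
class AttrMapSketch {
 public:
  template <typename T>
  void Set(const std::string& name, T* value) {
    attrs_[name] = std::shared_ptr<void>(value);  // map now owns the pointer
  }
  template <typename T>
  T& Get(const std::string& name) const {
    return *static_cast<T*>(attrs_.at(name).get());
  }

 private:
  std::map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  AttrMapSketch pass;
  pass.Set("trt_disabled_ops", new std::vector<std::string>({"fc", "mul"}));
  auto& ops = pass.Get<std::vector<std::string>>("trt_disabled_ops");
  return ops.size() == 2 ? 0 : 1;
}
```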
......@@ -39,8 +39,15 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
bool no_calib_int8 = enable_int8 && !(use_calib_mode);
auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
auto teller = [&](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false;
if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
node->Op()->Type()) != trt_disabled_ops.end()) {
VLOG(3) << node->Op()->Type().c_str()
<< " is diabled by config in TensorRT";
return false;
}
return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
no_calib_int8);
};
......
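The core of the change is the early-out at the top of the teller lambda: any op whose type appears in `trt_disabled_ops` is rejected before `OpTeller` is even consulted, so it stays in the main graph. A self-contained sketch of that filtering idiom, with illustrative names and no Paddle dependencies:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Decide whether an op may be offloaded to TensorRT: the user-supplied
// deny-list is checked first, then the regular capability check.
bool TellOp(const std::string& op_type,
            const std::vector<std::string>& trt_disabled_ops,
            bool op_supported_by_trt) {
  if (std::find(trt_disabled_ops.begin(), trt_disabled_ops.end(), op_type) !=
      trt_disabled_ops.end()) {
    std::cout << op_type << " is disabled by config in TensorRT\n";
    return false;  // forced fallback to native Paddle kernels
  }
  return op_supported_by_trt;  // stand-in for tensorrt::OpTeller::Global().Tell(...)
}

int main() {
  const std::vector<std::string> disabled = {"fc"};
  std::cout << std::boolalpha;
  std::cout << TellOp("fc", disabled, true) << "\n";      // false: user disabled it
  std::cout << TellOp("conv2d", disabled, true) << "\n";  // true: left to the teller
  return 0;
}
```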
......@@ -125,6 +125,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(tensorrt_max_batchsize_);
CP_MEMBER(tensorrt_min_subgraph_size_);
CP_MEMBER(tensorrt_precision_mode_);
CP_MEMBER(trt_disabled_ops_);
CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_);
CP_MEMBER(trt_use_oss_);
......@@ -304,6 +305,11 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}
void AnalysisConfig::Exp_DisableTensorRtOPs(
const std::vector<std::string> &ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
// TODO(Superjomn) refactor this, buggy.
......@@ -443,6 +449,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << tensorrt_max_batchsize_;
ss << tensorrt_min_subgraph_size_;
for (auto &op : trt_disabled_ops_) ss << op.c_str();
ss << ";";
ss << enable_memory_optim_;
ss << use_mkldnn_;
......
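Appending `trt_disabled_ops_` in `SerializeInfoCache()` makes the deny-list part of the serialized config fingerprint, so changing the list yields a different cache key instead of silently reusing stale optimization results. A minimal sketch of that idea, with an illustrative subset of fields rather than the full Paddle serialization:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Build a cache-key fragment that changes whenever the disabled-op list does.
std::string SerializeTrtInfo(int max_batch, int min_subgraph_size,
                             const std::vector<std::string>& trt_disabled_ops) {
  std::stringstream ss;
  ss << max_batch;
  ss << min_subgraph_size;
  for (const auto& op : trt_disabled_ops) ss << op;
  ss << ";";  // marks the end of the op list before later config fields follow
  return ss.str();
}
```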
......@@ -476,6 +476,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
......
......@@ -313,10 +313,17 @@ struct PD_INFER_DECL AnalysisConfig {
std::map<std::string, std::vector<int>> optim_input_shape,
bool disable_trt_plugin_fp16 = false);
///
/// \brief Prevent the listed ops from running in Paddle-TRT.
/// NOTE: experimental only, not an official stable API; it may change or be removed at any time.
///
void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);
///
/// \brief Replace some TensorRT plugins with TensorRT OSS (
/// https://github.com/NVIDIA/TensorRT), which can make inference of some
/// models faster. Requires libnvinfer_plugin.so newer than V7.2.1.
///
void EnableTensorRtOSS();
///
......@@ -587,6 +594,7 @@ struct PD_INFER_DECL AnalysisConfig {
std::map<std::string, std::vector<int>> min_input_shape_{};
std::map<std::string, std::vector<int>> max_input_shape_{};
std::map<std::string, std::vector<int>> optim_input_shape_{};
std::vector<std::string> trt_disabled_ops_{};
bool disable_trt_plugin_fp16_{false};
// memory reuse related.
......
......@@ -57,6 +57,7 @@ TEST(PredictorPool, use_gpu) {
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.EnableTensorRtEngine();
config.Exp_DisableTensorRtOPs({"fc"});
services::PredictorPool pred_pool(config, 1);
auto predictor = pred_pool.Retrive(0);
......