Unverified · Commit a8dfff99 · Authored by: Shang Zhizhou · Committed by: GitHub

add DLA support:C++&&Python api (#30165) (#30810)

* add dla

* add python api
Co-authored-by: shangzhizhou <root@szth-rp-fanyi-opera49.szth.baidu.com>
Parent 370b3f36
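The change threads one switch through the whole inference stack: a trt_use_dla_/trt_dla_core_ pair in AnalysisConfig, forwarded through Argument and the IR passes into TensorRTEngine, and exposed as AnalysisConfig::EnableTensorRtDLA(int dla_core) in C++ and enable_tensorrt_dla(dla_core) in Python. Below is a minimal C++ sketch of the intended usage, modeled on the PredictorPool test further down in this diff; the header name, model path, GPU pool size, and the EnableTensorRtEngine argument list are assumptions beyond what the diff itself shows, and DLA requires an fp16 or int8 engine (with float32, FreezeNetwork() warns and skips DLA).

#include <cassert>
#include "paddle_inference_api.h"  // assumed public header name

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet");  // hypothetical model directory
  config.EnableUseGpu(100, 0);     // 100 MB initial pool on GPU 0
  // DLA only executes int8/fp16 engines, so request kHalf precision here.
  config.EnableTensorRtEngine(1 << 20 /*workspace_size*/, 1 /*max_batch_size*/,
                              3 /*min_subgraph_size*/,
                              paddle::AnalysisConfig::Precision::kHalf,
                              false /*use_static*/, false /*use_calib_mode*/);
  config.EnableTensorRtDLA(0);  // place the TensorRT subgraphs on DLA core 0
  assert(config.tensorrt_dla_enabled());

  paddle_infer::services::PredictorPool pred_pool(config, 1);
  auto predictor = pred_pool.Retrive(0);  // "Retrive" is the actual API spelling
  (void)predictor;
  return 0;
}

An out-of-range dla_core is not an error: the engine logs a warning and falls back to core 0, as the FreezeNetwork() and Deserialize() hunks below show.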
@@ -199,6 +199,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
@@ -143,6 +143,8 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->optim_input_shape()));
pass->Set("trt_disabled_ops", new std::vector<std::string>(
argument->tensorrt_disabled_ops()));
pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
// Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
// not run fp16.
@@ -320,6 +320,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
min_input_shape, max_input_shape, opt_input_shape,
disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
@@ -126,6 +126,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(tensorrt_min_subgraph_size_);
CP_MEMBER(tensorrt_precision_mode_);
CP_MEMBER(trt_disabled_ops_);
CP_MEMBER(trt_use_dla_);
CP_MEMBER(trt_dla_core_);
CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_);
CP_MEMBER(trt_use_oss_);
@@ -305,6 +307,11 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
trt_use_dla_ = true;
trt_dla_core_ = dla_core;
}
void AnalysisConfig::Exp_DisableTensorRtOPs(
const std::vector<std::string> &ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
@@ -452,6 +459,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
for (auto &op : trt_disabled_ops_) ss << op.c_str();
ss << ";";
ss << trt_use_dla_;
ss << trt_dla_core_;
ss << enable_memory_optim_;
ss << use_mkldnn_;
@@ -477,6 +477,8 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtUseDLA(config_.trt_use_dla_);
argument_.SetTensorRtDLACore(config_.trt_dla_core_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
@@ -326,6 +326,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// V7.2.1 is needed.
///
void EnableTensorRtOSS();
///
/// \brief A boolean state telling whether to use the TensorRT OSS.
///
@@ -333,6 +334,20 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool tensorrt_oss_enabled() { return trt_use_oss_; }
///
/// \brief Enable TensorRT DLA.
/// \param dla_core ID of the DLA core, which should be 0, 1,
/// ..., IBuilder::getNbDLACores() - 1
///
void EnableTensorRtDLA(int dla_core = 0);
///
/// \brief A boolean state telling whether to use the TensorRT DLA.
///
/// \return bool Whether to use the TensorRT DLA.
///
bool tensorrt_dla_enabled() { return trt_use_dla_; }
///
/// \brief Turn on the usage of Lite sub-graph engine.
///
@@ -591,6 +606,8 @@ struct PD_INFER_DECL AnalysisConfig {
bool trt_use_static_engine_{false};
bool trt_use_calib_mode_{true};
bool trt_use_oss_{false};
bool trt_use_dla_{false};
int trt_dla_core_{0};
std::map<std::string, std::vector<int>> min_input_shape_{};
std::map<std::string, std::vector<int>> max_input_shape_{};
std::map<std::string, std::vector<int>> optim_input_shape_{};
@@ -176,6 +176,29 @@ void TensorRTEngine::FreezeNetwork() {
}
}
if (use_dla_) {
if (!enable_int8 && !enable_fp16) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
} else if (infer_builder_->getNbDLACores() == 0) {
LOG(WARNING)
<< "TensorRT DLA is set by config, but your device does not have "
"DLA, so DLA is not used.";
} else {
if (dla_core_ < 0 || dla_core_ >= infer_builder_->getNbDLACores()) {
LOG(WARNING) << "Invalid DLACore, must be 0 <= DLACore < "
<< infer_builder_->getNbDLACores() << ", but got "
<< dla_core_ << ", so use 0 as default.";
dla_core_ = 0;
}
infer_builder_->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
infer_builder_->setDLACore(dla_core_);
infer_builder_->allowGPUFallback(true);
LOG(INFO) << "TensorRT DLA enabled in FreezeNetwork(), DLACore "
<< dla_core_;
}
}
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
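For reference, here is a self-contained sketch of the TensorRT 7-era builder calls that the FreezeNetwork() hunk above relies on. The logger class and its severity filter are illustrative assumptions; getNbDLACores(), setDefaultDeviceType(), setDLACore(), and allowGPUFallback() are the same IBuilder calls used in the diff.

#include <iostream>
#include "NvInfer.h"

// Minimal logger; TensorRT requires an ILogger implementation.
class SketchLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) override {
    if (severity <= Severity::kWARNING) std::cerr << msg << std::endl;
  }
};

int main() {
  SketchLogger logger;
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  const int nb_cores = builder->getNbDLACores();  // 0 on devices without DLA
  if (nb_cores > 0) {
    builder->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
    builder->setDLACore(0);           // valid core IDs: 0 .. nb_cores - 1
    builder->allowGPUFallback(true);  // let unsupported layers run on the GPU
  }
  builder->destroy();  // pre-TRT8 ownership convention
  return 0;
}

The GPU fallback flag matters because many layers are not DLA-compatible; without it, building an engine that contains an unsupported layer fails outright.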
@@ -220,6 +220,29 @@ class TensorRTEngine {
void Deserialize(const std::string& engine_serialized_data) {
freshDeviceId();
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
if (use_dla_) {
if (precision_ != AnalysisConfig::Precision::kInt8 &&
precision_ != AnalysisConfig::Precision::kHalf) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
} else if (runtime->getNbDLACores() == 0) {
LOG(WARNING)
<< "TensorRT DLA is set by config, but your device does not have "
"DLA, so DLA is not used.";
} else {
if (dla_core_ < 0 || dla_core_ >= runtime->getNbDLACores()) {
LOG(WARNING) << "Invalid DLACore, must be 0 <= DLACore < "
<< runtime->getNbDLACores() << ", but got " << dla_core_
<< ", so use 0 as default.";
dla_core_ = 0;
}
runtime->setDLACore(dla_core_);
LOG(INFO) << "TensorRT DLA enabled in Deserialize(), DLACore "
<< dla_core_;
}
}
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
infer_engine_.reset(runtime->deserializeCudaEngine(
@@ -287,6 +310,8 @@ class TensorRTEngine {
}
void SetUseOSS(bool use_oss) { use_oss_ = use_oss; }
void SetUseDLA(bool use_dla) { use_dla_ = use_dla; }
void SetDLACore(int dla_core) { dla_core_ = dla_core; }
void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; }
void ClearWeights() {
@@ -316,8 +341,8 @@ class TensorRTEngine {
ShapeMapType min_input_shape() { return min_input_shape_; }
ShapeMapType max_input_shape() { return max_input_shape_; }
ShapeMapType optim_input_shape() { return optim_input_shape_; }
- bool use_oss() { return use_oss_; };
- bool with_ernie() { return with_ernie_; };
+ bool use_oss() { return use_oss_; }
+ bool with_ernie() { return with_ernie_; }
bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
bool with_dynamic_shape() { return with_dynamic_shape_; }
@@ -354,6 +379,8 @@ class TensorRTEngine {
ShapeMapType optim_input_shape_;
bool disable_trt_plugin_fp16_{false};
bool use_oss_{false};
bool use_dla_{false};
int dla_core_{0};
bool with_ernie_{false};
nvinfer1::ILogger& logger_;
@@ -58,6 +58,7 @@ TEST(PredictorPool, use_gpu) {
config.SetModel(model_dir);
config.EnableTensorRtEngine();
config.Exp_DisableTensorRtOPs({"fc"});
config.EnableTensorRtDLA(0);
services::PredictorPool pred_pool(config, 1);
auto predictor = pred_pool.Retrive(0);
@@ -495,6 +495,9 @@ void BindAnalysisConfig(py::module *m) {
py::arg("disable_trt_plugin_fp16") = false)
.def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
.def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
.def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
py::arg("dla_core") = 0)
.def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,