Unverified commit 6e0cf610, authored by Tian Zheng, committed by GitHub

Fix trt runtime destroy issue (#53937)

Parent 01345a51
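Background on the fix: TensorRT expects the nvinfer1::IRuntime that deserializes an engine to outlive the resulting nvinfer1::ICudaEngine. Before this patch, FreezeNetwork() and Deserialize() held the runtime in a function-local infer_ptr, so the runtime was destroyed on return while the engine it produced lived on as a class member. The patch instead stores the runtime in a new member, infer_runtime_, declared just before infer_engine_, so the two are kept alive together and torn down in a safe order. Below is a minimal sketch of the same ownership pattern in isolation; it assumes TensorRT 8.x, where interfaces may be destroyed with plain delete, and EngineHolder is a hypothetical stand-in rather than Paddle code.

// Minimal sketch of the ownership pattern this commit adopts.
// Assumes TensorRT 8.x; EngineHolder is a hypothetical class.
#include <memory>
#include <string>

#include "NvInfer.h"

class EngineHolder {
 public:
  bool Deserialize(nvinfer1::ILogger& logger, const std::string& blob) {
    // Hold the runtime as a member so it outlives the engine it creates.
    runtime_.reset(nvinfer1::createInferRuntime(logger));
    if (runtime_ == nullptr) return false;
    engine_.reset(runtime_->deserializeCudaEngine(blob.data(), blob.size()));
    return engine_ != nullptr;
  }

 private:
  // Members are destroyed in reverse declaration order, so engine_ is
  // deleted before runtime_, matching the order TensorRT expects.
  std::unique_ptr<nvinfer1::IRuntime> runtime_;
  std::unique_ptr<nvinfer1::ICudaEngine> engine_;
};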
...
@@ -370,9 +370,9 @@ void TensorRTEngine::FreezeNetwork() {
 #else
   ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
       *network(), *infer_builder_config_));
-  infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
-  infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(),
-                                                     ihost_memory_->size()));
+  infer_runtime_.reset(createInferRuntime(&logger_));
+  infer_engine_.reset(infer_runtime_->deserializeCudaEngine(
+      ihost_memory_->data(), ihost_memory_->size()));
 #endif
   PADDLE_ENFORCE_NOT_NULL(
...
@@ -559,31 +559,31 @@ std::unordered_map<std::string, nvinfer1::ITensor *>
 void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
   freshDeviceId();
-  infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
+  infer_runtime_.reset(createInferRuntime(&logger_));
   if (use_dla_) {
     if (precision_ != phi::DataType::INT8 &&
         precision_ != phi::DataType::FLOAT16) {
       LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
                       "set float32, so DLA is not used.";
-    } else if (runtime->getNbDLACores() == 0) {
+    } else if (infer_runtime_->getNbDLACores() == 0) {
       LOG(WARNING)
           << "TensorRT DLA is set by config, but your device does not have "
              "DLA, so DLA is not used.";
     } else {
-      if (dla_core_ < 0 || dla_core_ >= runtime->getNbDLACores()) {
+      if (dla_core_ < 0 || dla_core_ >= infer_runtime_->getNbDLACores()) {
         dla_core_ = 0;
-        LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
-                     << runtime->getNbDLACores() << ", but got " << dla_core_
-                     << ", so use use 0 as default.";
+        LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
+                     << infer_runtime_->getNbDLACores() << ", but got "
+                     << dla_core_ << ", so use use 0 as default.";
       }
-      runtime->setDLACore(dla_core_);
+      infer_runtime_->setDLACore(dla_core_);
       LOG(INFO) << "TensorRT DLA enabled in Deserialize(), DLACore "
                 << dla_core_;
     }
   }
-  infer_engine_.reset(runtime->deserializeCudaEngine(
-      engine_serialized_data.c_str(), engine_serialized_data.size()));
+  infer_engine_.reset(infer_runtime_->deserializeCudaEngine(
+      engine_serialized_data.c_str(), engine_serialized_data.size()));
   PADDLE_ENFORCE_NOT_NULL(
...
@@ -811,6 +811,7 @@ class TensorRTEngine {
   // TensorRT related internal members
   infer_ptr<nvinfer1::IBuilder> infer_builder_;
   infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
+  infer_ptr<nvinfer1::IRuntime> infer_runtime_;
   infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
   std::unordered_map<PredictorID, infer_ptr<nvinfer1::IExecutionContext>>
       infer_context_;
...
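The one-line header addition is what makes the fix stick: C++ destroys non-static data members in reverse declaration order, so with infer_runtime_ declared above infer_engine_, ~TensorRTEngine() releases the engine before the runtime. The short, self-contained demo below makes that order visible; Noisy and Holder are illustrative names, not part of the codebase.

#include <cstdio>

// Demo of member destruction order; mirrors why infer_runtime_ is
// declared before infer_engine_ in the patched header.
struct Noisy {
  const char* name;
  explicit Noisy(const char* n) : name(n) {}
  ~Noisy() { std::printf("destroying %s\n", name); }
};

struct Holder {
  Noisy runtime{"runtime"};  // declared first, destroyed last
  Noisy engine{"engine"};    // declared last, destroyed first
};

int main() {
  Holder h;
  (void)h;
  return 0;  // prints "destroying engine", then "destroying runtime"
}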