未验证 提交 6e0cf610 编写于 作者: Tian Zheng 提交者: GitHub

Fix trt runtime destroy issue (#53937)

上级 01345a51
......@@ -370,9 +370,9 @@ void TensorRTEngine::FreezeNetwork() {
#else
ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
*network(), *infer_builder_config_));
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(),
ihost_memory_->size()));
infer_runtime_.reset(createInferRuntime(&logger_));
infer_engine_.reset(infer_runtime_->deserializeCudaEngine(
ihost_memory_->data(), ihost_memory_->size()));
#endif
PADDLE_ENFORCE_NOT_NULL(
......@@ -559,31 +559,31 @@ std::unordered_map<std::string, nvinfer1::ITensor *>
void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
freshDeviceId();
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_runtime_.reset(createInferRuntime(&logger_));
if (use_dla_) {
if (precision_ != phi::DataType::INT8 &&
precision_ != phi::DataType::FLOAT16) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
} else if (runtime->getNbDLACores() == 0) {
} else if (infer_runtime_->getNbDLACores() == 0) {
LOG(WARNING)
<< "TensorRT DLA is set by config, but your device does not have "
"DLA, so DLA is not used.";
} else {
if (dla_core_ < 0 || dla_core_ >= runtime->getNbDLACores()) {
if (dla_core_ < 0 || dla_core_ >= infer_runtime_->getNbDLACores()) {
dla_core_ = 0;
LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
<< runtime->getNbDLACores() << ", but got " << dla_core_
<< ", so use use 0 as default.";
<< infer_runtime_->getNbDLACores() << ", but got "
<< dla_core_ << ", so use use 0 as default.";
}
runtime->setDLACore(dla_core_);
infer_runtime_->setDLACore(dla_core_);
LOG(INFO) << "TensorRT DLA enabled in Deserialize(), DLACore "
<< dla_core_;
}
}
infer_engine_.reset(runtime->deserializeCudaEngine(
infer_engine_.reset(infer_runtime_->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -811,6 +811,7 @@ class TensorRTEngine {
// TensorRT related internal members
infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::IRuntime> infer_runtime_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
std::unordered_map<PredictorID, infer_ptr<nvinfer1::IExecutionContext>>
infer_context_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册