未验证 提交 8d0922ed 编写于 作者: W Wilber 提交者: GitHub

fix trt problem (#35938)

上级 9b8aafe5
...@@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204) set(LITE_GIT_TAG 4ab64daecc11fbf74fffdc6a4733f388472e7d5d)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
......
...@@ -686,9 +686,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -686,9 +686,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Note, please do NOT use any member variables, because member variables may // Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads. // have been destructed in multiple threads.
#if PADDLE_WITH_TENSORRT #if PADDLE_WITH_TENSORRT
paddle::inference::Singleton< auto &block = prog->Block(0);
inference::tensorrt::TRTEngineManager>::Global() for (auto &op_desc : block.AllOps()) {
.DeleteAll(); if (op_desc->Type() == "tensorrt_engine") {
std::string engine_key =
BOOST_GET_CONST(std::string, op_desc->GetAttr("engine_key"));
int engine_predictor_id =
BOOST_GET_CONST(int, op_desc->GetAttr("predictor_id"));
std::string engine_name =
engine_key + std::to_string(engine_predictor_id);
if (paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Has(engine_name)) {
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteKey(engine_name);
}
}
}
#endif #endif
delete prog; delete prog;
}); });
......
...@@ -631,6 +631,14 @@ class TRTEngineManager { ...@@ -631,6 +631,14 @@ class TRTEngineManager {
} }
} }
void DeleteKey(const std::string& key) {
auto iter = engines_.find(key);
if (iter != engines_.end()) {
iter->second.reset(nullptr);
engines_.erase(iter);
}
}
private: private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_; std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册