未验证 提交 ae78940a 编写于 作者: W Wilber 提交者: GitHub

[cherry-pick] inference fix trt problem (#35939)

* update xpu version
上级 0e19aeb9
...@@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG 1c4698c6efd9a5f57a4f8369bd5b6374166f5ba4) set(LITE_GIT_TAG 4ab64daecc11fbf74fffdc6a4733f388472e7d5d)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
......
...@@ -35,7 +35,7 @@ ELSE () ...@@ -35,7 +35,7 @@ ELSE ()
ENDIF() ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210921")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
...@@ -686,9 +686,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -686,9 +686,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Note, please do NOT use any member variables, because member variables may // Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads. // have been destructed in multiple threads.
#if PADDLE_WITH_TENSORRT #if PADDLE_WITH_TENSORRT
auto &block = prog->Block(0);
for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") {
std::string engine_key =
BOOST_GET_CONST(std::string, op_desc->GetAttr("engine_key"));
int engine_predictor_id =
BOOST_GET_CONST(int, op_desc->GetAttr("predictor_id"));
std::string engine_name =
engine_key + std::to_string(engine_predictor_id);
if (paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Has(engine_name)) {
paddle::inference::Singleton< paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global() inference::tensorrt::TRTEngineManager>::Global()
.DeleteAll(); .DeleteKey(engine_name);
}
}
}
#endif #endif
delete prog; delete prog;
}); });
......
...@@ -631,6 +631,14 @@ class TRTEngineManager { ...@@ -631,6 +631,14 @@ class TRTEngineManager {
} }
} }
void DeleteKey(const std::string& key) {
auto iter = engines_.find(key);
if (iter != engines_.end()) {
iter->second.reset(nullptr);
engines_.erase(iter);
}
}
private: private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_; std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册