未验证 提交 7cb7535e 编写于 作者: W wenbin 提交者: GitHub

fix ernie serialize problem (#36769) (#36791)

Co-authored-by: Nzlsh80826 <zlsh80826@gmail.com>
上级 e1b5b1da
...@@ -239,11 +239,11 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -239,11 +239,11 @@ void TensorRTEngine::FreezeNetwork() {
*network(), *infer_builder_config_)); *network(), *infer_builder_config_));
#else #else
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
infer_ptr<nvinfer1::IHostMemory> plan(infer_builder_->buildSerializedNetwork( ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
*network(), *infer_builder_config_)); *network(), *infer_builder_config_));
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_)); infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset( infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(),
runtime->deserializeCudaEngine(plan->data(), plan->size())); ihost_memory_->size()));
#endif #endif
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
......
...@@ -273,7 +273,14 @@ class TensorRTEngine { ...@@ -273,7 +273,14 @@ class TensorRTEngine {
infer_engine_, infer_engine_,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The TensorRT engine must be built first before serialization")); "The TensorRT engine must be built first before serialization"));
#if IS_TRT_VERSION_LT(8000)
ihost_memory_.reset(infer_engine_->serialize()); ihost_memory_.reset(infer_engine_->serialize());
#else
PADDLE_ENFORCE_NOT_NULL(
ihost_memory_,
platform::errors::InvalidArgument(
"TensorRT >= 8.0 requires that buildSerializedNetwork is called"));
#endif
return ihost_memory_.get(); return ihost_memory_.get();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册