未验证 提交 d6b1beb0 编写于 作者: Z zlsh80826 提交者: GitHub

fix ernie serialize problem (#36769)

上级 5e9845b8
......@@ -233,11 +233,11 @@ void TensorRTEngine::FreezeNetwork() {
*network(), *infer_builder_config_));
#else
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
infer_ptr<nvinfer1::IHostMemory> plan(infer_builder_->buildSerializedNetwork(
ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
*network(), *infer_builder_config_));
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(
runtime->deserializeCudaEngine(plan->data(), plan->size()));
infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(),
ihost_memory_->size()));
#endif
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -273,7 +273,14 @@ class TensorRTEngine {
infer_engine_,
platform::errors::InvalidArgument(
"The TensorRT engine must be built first before serialization"));
#if IS_TRT_VERSION_LT(8000)
ihost_memory_.reset(infer_engine_->serialize());
#else
PADDLE_ENFORCE_NOT_NULL(
ihost_memory_,
platform::errors::InvalidArgument(
"TensorRT >= 8.0 requires that buildSerializedNetwork is called"));
#endif
return ihost_memory_.get();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册