diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index e151ef124a7510d3ea9bc16606bdbffe5c69091a..f20ef14177a0b9889f775e40f24bd99c8a65215a 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -239,11 +239,11 @@ void TensorRTEngine::FreezeNetwork() { *network(), *infer_builder_config_)); #else infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - infer_ptr plan(infer_builder_->buildSerializedNetwork( + ihost_memory_.reset(infer_builder_->buildSerializedNetwork( *network(), *infer_builder_config_)); infer_ptr runtime(createInferRuntime(&logger_)); - infer_engine_.reset( - runtime->deserializeCudaEngine(plan->data(), plan->size())); + infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(), + ihost_memory_->size())); #endif PADDLE_ENFORCE_NOT_NULL( diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 8cb0d36277def86ead541873f321c233af6e2104..c7a5eca5a11b49f76b95268a128c199907148472 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -273,7 +273,14 @@ class TensorRTEngine { infer_engine_, platform::errors::InvalidArgument( "The TensorRT engine must be built first before serialization")); +#if IS_TRT_VERSION_LT(8000) ihost_memory_.reset(infer_engine_->serialize()); +#else + PADDLE_ENFORCE_NOT_NULL( + ihost_memory_, + platform::errors::InvalidArgument( + "TensorRT >= 8.0 requires that buildSerializedNetwork is called")); +#endif return ihost_memory_.get(); }