diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 7a9c5b889d..c847a5d523 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -418,6 +418,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( graph->Has(framework::ir::kEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass)); + std::unordered_set param_set(params.begin(), params.end()); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( Get("model_opt_cache_dir"), engine_key); @@ -427,6 +428,19 @@ void TensorRtSubgraphPass::CreateTensorRTOp( LOG(INFO) << "Load TRT Optimized Info from " << GetTrtEngineSerializedPath( Get("model_opt_cache_dir"), engine_key); + const auto* root_scope{param_scope()}; + for (;root_scope->parent();) { + root_scope = root_scope->parent(); + } + for (const auto& name: param_set) { + LOG(INFO) << " ===== Clear param: " << name; + root_scope->FindLocalVar(name)->Clear(); + } + for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount(); + ++dev_id) { + memory::Release(platform::CUDAPlace(dev_id)); + } + memory::Release(platform::CPUPlace()); return; } } @@ -439,7 +453,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp( auto *scope = param_scope(); framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto()); - std::unordered_set param_set(params.begin(), params.end()); inference::Singleton::Global() .ConvertBlockToTRTEngine( &block_desc_temp, @@ -449,6 +462,21 @@ void TensorRtSubgraphPass::CreateTensorRTOp( output_mapping, trt_engine); + const auto* root_scope{scope}; + for (;root_scope->parent();) { + root_scope = root_scope->parent(); + } + VLOG(4) << "root_scope->LocalVarNames().size: " << root_scope->LocalVarNames().size(); + for (const auto& name: param_set) { + VLOG(4) << " ===== Clear param: " << name; + root_scope->FindLocalVar(name)->Clear(); + } + for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount(); + ++dev_id) { + memory::Release(platform::CUDAPlace(dev_id)); + } + memory::Release(platform::CPUPlace()); + if (use_static_engine) { nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize(); trt_engine_serialized_data = @@ -462,6 +490,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( << GetTrtEngineSerializedPath( Get("model_opt_cache_dir"), engine_key); } + trt_engine_serialized_data.clear(); + trt_engine_serialized_data.shrink_to_fit(); } } // namespace analysis