diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index e4fc52b6fa..24b6f73949 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -384,6 +384,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass))); + std::unordered_set param_set(params.begin(), params.end()); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( Get("model_opt_cache_dir"), engine_key); @@ -393,6 +394,19 @@ void TensorRtSubgraphPass::CreateTensorRTOp( LOG(INFO) << "Load TRT Optimized Info from " << GetTrtEngineSerializedPath( Get("model_opt_cache_dir"), engine_key); + const auto* root_scope{param_scope()}; + for (;root_scope->parent();) { + root_scope = root_scope->parent(); + } + for (const auto& name: param_set) { + LOG(INFO) << " ===== Clear param: " << name; + root_scope->FindLocalVar(name)->Clear(); + } + for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount(); + ++dev_id) { + memory::Release(platform::CUDAPlace(dev_id)); + } + memory::Release(platform::CPUPlace()); return; } } @@ -405,12 +419,25 @@ void TensorRtSubgraphPass::CreateTensorRTOp( auto *scope = param_scope(); framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto()); - std::unordered_set param_set(params.begin(), params.end()); inference::Singleton::Global() .ConvertBlockToTRTEngine( &block_desc_temp, *scope, std::vector(input_names.begin(), input_names.end()), param_set, output_mapping, trt_engine); + const auto* root_scope{scope}; + for (;root_scope->parent();) { + root_scope = root_scope->parent(); + } + VLOG(4) << "root_scope->LocalVarNames().size: " << root_scope->LocalVarNames().size(); + for (const auto& name: param_set) { + VLOG(4) << " ===== Clear param: " << name; + root_scope->FindLocalVar(name)->Clear(); + } + for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount(); + ++dev_id) { + memory::Release(platform::CUDAPlace(dev_id)); + } + memory::Release(platform::CPUPlace()); if (use_static_engine) { nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize(); @@ -425,6 +452,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( << GetTrtEngineSerializedPath( Get("model_opt_cache_dir"), engine_key); } + trt_engine_serialized_data.clear(); + trt_engine_serialized_data.shrink_to_fit(); } } // namespace analysis