tensorrt_subgraph_pass.cc.patch 2.9 KB
Newer Older
S
Shang Zhizhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index e4fc52b6fa..24b6f73949 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -384,6 +384,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
        graph->Has(framework::ir::kMultiheadMatmulPass)));
 
+  std::unordered_set<std::string> param_set(params.begin(), params.end());
   if (use_static_engine) {
     trt_engine_serialized_data = GetTrtEngineSerializedData(
         Get<std::string>("model_opt_cache_dir"), engine_key);
@@ -393,6 +394,19 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       LOG(INFO) << "Load TRT Optimized Info from "
                 << GetTrtEngineSerializedPath(
                        Get<std::string>("model_opt_cache_dir"), engine_key);
+      const auto* root_scope{param_scope()};
+      for (;root_scope->parent();) {
+        root_scope = root_scope->parent();
+      }
+      for (const auto& name: param_set) {
+        LOG(INFO) << " ===== Clear param: " << name;
+        root_scope->FindLocalVar(name)->Clear();
+      }
+      for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+          ++dev_id) {
+        memory::Release(platform::CUDAPlace(dev_id));
+      }
+      memory::Release(platform::CPUPlace());
       return;
     }
   }
@@ -405,12 +419,25 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
 
   auto *scope = param_scope();
   framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
-  std::unordered_set<std::string> param_set(params.begin(), params.end());
   inference::Singleton<inference::tensorrt::OpConverter>::Global()
       .ConvertBlockToTRTEngine(
           &block_desc_temp, *scope,
           std::vector<std::string>(input_names.begin(), input_names.end()),
           param_set, output_mapping, trt_engine);
+  const auto* root_scope{scope};
+  for (;root_scope->parent();) {
+    root_scope = root_scope->parent();
+  }
+  VLOG(4) << "root_scope->LocalVarNames().size: " << root_scope->LocalVarNames().size();
+  for (const auto& name: param_set) {
+    VLOG(4) << "  ===== Clear param: " << name;
+    root_scope->FindLocalVar(name)->Clear();
+  }
+  for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+       ++dev_id) {
+    memory::Release(platform::CUDAPlace(dev_id));
+  }
+  memory::Release(platform::CPUPlace());
 
   if (use_static_engine) {
     nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
@@ -425,6 +452,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
               << GetTrtEngineSerializedPath(
                      Get<std::string>("model_opt_cache_dir"), engine_key);
   }
+  trt_engine_serialized_data.clear();
+  trt_engine_serialized_data.shrink_to_fit();
 }
 
 }  // namespace analysis