diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index df4b0079c79a070f86a123e3c1d64e460c854871..195053814e6a00094dd7cf59d2c1331d96e3d634 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -828,6 +828,25 @@ bool AnalysisPredictor::LoadParameters() {
   return true;
 }
 
+void AnalysisPredictor::ClearIntermediateTensor() {
+  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
+                          platform::errors::PreconditionNotMet(
+                              "The inference program should be loaded first."));
+  const auto &global_block = inference_program_->MutableBlock(0);
+  for (auto *var : global_block->AllVars()) {
+    if (!IsPersistable(var)) {
+      const std::string name = var->Name();
+      auto *variable = executor_->scope()->FindVar(name);
+      if (variable != nullptr && variable->IsType<framework::LoDTensor>() &&
+          name != "feed" && name != "fetch") {
+        VLOG(3) << "Clear Intermediate Tensor: " << name;
+        auto *t = variable->GetMutable<framework::LoDTensor>();
+        t->clear();
+      }
+    }
+  }
+}
+
 #if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
   PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index 267817829ec4598808486fd3ea5df241a1466e22..365f86c21105a7f1ffb7c300e0ab38c6aaa230fc 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -187,6 +187,12 @@ class AnalysisPredictor : public PaddlePredictor {
   ///
   void OptimizeInferenceProgram();
 
+  ///
+  /// \brief Clear the intermediate tensors of the predictor
+  ///
+  ///
+  void ClearIntermediateTensor();
+
   ///
   /// \brief Get the argument used by predictor
   ///
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index bf243bf9a45ebb67a0b6bc356ac2697decd1e300..386d20103a71acb34cd47ddf5527f580cc5bf5b1 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -313,6 +313,12 @@ class PD_INFER_DECL PaddlePredictor {
   /// \return Whether the run is successful
   virtual bool ZeroCopyRun() { return false; }
 
+  ///
+  /// \brief Clear the intermediate tensors of the predictor
+  ///
+  ///
+  virtual void ClearIntermediateTensor() {}
+
   /// \brief Clone an existing predictor
   /// When using clone, the same network will be created,
   /// and the parameters between them are shared.
diff --git a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
index 1dbdcccf41ba3a42dd21982cd9fac86f5e767382..8ffa3efdf0556bd7cde7efa615f60853ad18d903 100644
--- a/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_mobilenet_test.cc
@@ -43,6 +43,7 @@ TEST(AnalysisPredictor, use_gpu) {
   std::vector<PaddleTensor> outputs;
   for (auto& input : inputs_all) {
     ASSERT_TRUE(predictor->Run(input, &outputs));
+    predictor->ClearIntermediateTensor();
   }
 }
 
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 1b6c407e6bf1a2a38752acb3c096bbdc64c36da6..5a0b18a34f768f3fb4392abf1d796feb951990c3 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -501,6 +501,8 @@ void BindAnalysisPredictor(py::module *m) {
       .def("get_output_names", &AnalysisPredictor::GetOutputNames)
       .def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape)
       .def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
+      .def("clear_intermediate_tensor",
+           &AnalysisPredictor::ClearIntermediateTensor)
      .def("create_feed_fetch_var", &AnalysisPredictor::CreateFeedFetchVar)
      .def("prepare_feed_fetch", &AnalysisPredictor::PrepareFeedFetch)
      .def("prepare_argument", &AnalysisPredictor::PrepareArgument)
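
For reviewers, a caller-side sketch of the intended usage, mirroring the updated trt_mobilenet_test.cc: call ClearIntermediateTensor() after each Run() to release the memory held by non-persistable (intermediate) tensors while the parameters stay loaded. This is illustrative only and not part of the patch; the model directory, feed name, and input shape below are placeholders. The same call is exposed to Python as clear_intermediate_tensor through the pybind change above.

#include <vector>
#include "paddle/include/paddle_inference_api.h"  // header path as shipped with the inference library

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet");      // hypothetical model directory
  config.EnableUseGpu(100 /*MB*/, 0);  // initial GPU memory pool, device id

  auto predictor = paddle::CreatePaddlePredictor(config);

  // Hypothetical feed: a single 1x3x224x224 float input named "image".
  paddle::PaddleTensor input;
  input.name = "image";
  input.shape = {1, 3, 224, 224};
  input.dtype = paddle::PaddleDType::FLOAT32;
  std::vector<float> buf(1 * 3 * 224 * 224, 0.f);
  input.data = paddle::PaddleBuf(buf.data(), buf.size() * sizeof(float));

  std::vector<paddle::PaddleTensor> outputs;
  for (int i = 0; i < 10; ++i) {
    predictor->Run({input}, &outputs);
    // Free intermediate tensors between runs; weights remain resident.
    predictor->ClearIntermediateTensor();
  }
  return 0;
}

Note the implementation deliberately skips the "feed" and "fetch" variables and anything persistable, so inputs, outputs, and parameters are untouched; only scratch LoDTensors produced while executing the graph are cleared.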