未验证 提交 43f9f180 编写于 作者: P Pei Yang 提交者: GitHub

Add api to clear intermediate tensors in AnalysisPredictor (#25069)

* add api to clear intermediate tensors in analysis predictor. test=develop

* add python api. test=develop
上级 6bfbb6ab
......@@ -828,6 +828,25 @@ bool AnalysisPredictor::LoadParameters() {
return true;
}
// Clears the intermediate tensors held by this predictor.
//
// Walks every variable description in block 0 of the inference program. For
// each one for which IsPersistable() returns false, the variable with the same
// name is looked up in the executor's scope; if it is found, holds a
// framework::LoDTensor, and is not named "feed" or "fetch", clear() is called
// on that tensor.
//
// The PADDLE_ENFORCE_NOT_NULL check fails with PreconditionNotMet when no
// inference program has been loaded yet.
void AnalysisPredictor::ClearIntermediateTensor() {
  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
                          platform::errors::PreconditionNotMet(
                              "The inference program should be loaded first."));
  const auto &block0 = inference_program_->MutableBlock(0);
  for (auto *var_desc : block0->AllVars()) {
    // Persistable variables are left untouched.
    if (IsPersistable(var_desc)) {
      continue;
    }
    const std::string var_name = var_desc->Name();
    auto *scope_var = executor_->scope()->FindVar(var_name);
    // Skip names that are absent from the scope and variables of other types.
    if (scope_var == nullptr || !scope_var->IsType<framework::LoDTensor>()) {
      continue;
    }
    // The "feed" and "fetch" variables are never cleared.
    if (var_name == "feed" || var_name == "fetch") {
      continue;
    }
    VLOG(3) << "Clear Intermediate Tensor: " << var_name;
    scope_var->GetMutable<framework::LoDTensor>()->clear();
  }
}
#if PADDLE_WITH_TENSORRT
bool AnalysisPredictor::SaveTrtCalibToDisk() {
PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
......
......@@ -187,6 +187,12 @@ class AnalysisPredictor : public PaddlePredictor {
///
void OptimizeInferenceProgram();
///
/// \brief Clear the intermediate tensors of the predictor
///
/// Calls clear() on every non-persistable framework::LoDTensor variable of
/// block 0 of the inference program that is found in the executor's scope,
/// except the variables named "feed" and "fetch". The inference program must
/// already be loaded; otherwise a PreconditionNotMet error is raised.
///
void ClearIntermediateTensor();
///
/// \brief Get the argument used by predictor
///
......
......@@ -313,6 +313,12 @@ class PD_INFER_DECL PaddlePredictor {
/// \return Whether the run is successful
virtual bool ZeroCopyRun() { return false; }
///
/// \brief Clear the intermediate tensors of the predictor
///
/// The default implementation has an empty body and does nothing.
/// AnalysisPredictor declares its own version of this method.
///
virtual void ClearIntermediateTensor() {}
/// \brief Clone an existing predictor
/// When using clone, the same network will be created,
/// and the parameters between them are shared.
......
......@@ -43,6 +43,7 @@ TEST(AnalysisPredictor, use_gpu) {
std::vector<PaddleTensor> outputs;
for (auto& input : inputs_all) {
ASSERT_TRUE(predictor->Run(input, &outputs));
predictor->ClearIntermediateTensor();
}
}
......
......@@ -501,6 +501,8 @@ void BindAnalysisPredictor(py::module *m) {
.def("get_output_names", &AnalysisPredictor::GetOutputNames)
.def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape)
.def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
.def("clear_intermediate_tensor",
&AnalysisPredictor::ClearIntermediateTensor)
.def("create_feed_fetch_var", &AnalysisPredictor::CreateFeedFetchVar)
.def("prepare_feed_fetch", &AnalysisPredictor::PrepareFeedFetch)
.def("prepare_argument", &AnalysisPredictor::PrepareArgument)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册