From 526790e652502a3299b079203ec1b69f5633334a Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 28 Jan 2019 14:35:31 +0800 Subject: [PATCH] infer get program (#15511) --- paddle/fluid/inference/api/analysis_predictor.cc | 4 ++++ paddle/fluid/inference/api/analysis_predictor.h | 2 ++ paddle/fluid/inference/api/analysis_predictor_tester.cc | 2 ++ paddle/fluid/inference/api/paddle_api.h | 8 ++++++++ 4 files changed, 16 insertions(+) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7d97aea714a..3a5f21d4756 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -726,6 +726,10 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() { return need; } +std::string AnalysisPredictor::GetSeriazlizedProgram() const { + return inference_program_->Proto()->SerializeAsString(); +} + template <> std::unique_ptr CreatePaddlePredictor( const contrib::AnalysisConfig &config) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 921aa90952d..fa1d0d596df 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -75,6 +75,8 @@ class AnalysisPredictor : public PaddlePredictor { void SetMkldnnThreadID(int tid); + std::string GetSeriazlizedProgram() const override; + protected: // For memory optimization. bool need_collect_var_shapes_for_memory_optim(); diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 4688e93d710..20b61344da9 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -215,6 +215,8 @@ TEST(AnalysisPredictor, memory_optim) { { // The first predictor help to cache the memory optimize strategy. auto predictor = CreatePaddlePredictor(config); + LOG(INFO) << "serialized program: " << predictor->GetSeriazlizedProgram(); + ASSERT_FALSE(predictor->GetSeriazlizedProgram().empty()); // Run several times to check the parameters are not reused by mistake. for (int i = 0; i < 5; i++) { diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 46b510fd1ec..4fc12c294ac 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -215,6 +215,14 @@ class PaddlePredictor { */ virtual ~PaddlePredictor() = default; + /** \brief Get the serialized model program that executes in inference phase. + * Its data type is ProgramDesc, which is a protobuf message. + */ + virtual std::string GetSeriazlizedProgram() const { + assert(false); // Force raise error. + return "NotImplemented"; + }; + /** The common configs for all the predictors. */ struct Config { -- GitLab