From a8ae87f118ddde049bd5c60c4493a667206f8055 Mon Sep 17 00:00:00 2001
From: heliqi <1101791222@qq.com>
Date: Thu, 18 Aug 2022 08:49:05 -0500
Subject: [PATCH] [inference]predictor add GetInputType interface (#45143)

* predictor add GetInputType interface

* predictor change GetInputType to GetInputTypes

* predictor add tester

* predictor add tester

* predictor change GetInputType to GetInputTypes

* predictor change GetInputType to GetInputTypes

* predictor add tester
---
 .../fluid/inference/api/analysis_predictor.cc | 35 +++++++++++++++++++
 .../fluid/inference/api/analysis_predictor.h  |  6 ++++
 .../api/analysis_predictor_tester.cc          |  2 ++
 paddle/fluid/inference/api/paddle_api.h       |  6 ++++
 .../inference/api/paddle_inference_api.h      |  7 ++++
 5 files changed, 56 insertions(+)

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 49045089ce5..fa927a7da22 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1454,6 +1454,37 @@ AnalysisPredictor::GetInputTensorShape() {
   return input_shapes;
 }
 
+std::map<std::string, paddle_infer::DataType>
+AnalysisPredictor::GetInputTypes() {
+  std::map<std::string, paddle_infer::DataType> input_type;
+  std::vector<std::string> names = GetInputNames();
+  for (const auto &name : names) {
+    auto *var = inference_program_->Block(0).FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::PreconditionNotMet(
+            "Input %s does not exist inference_program_.", name));
+    auto dtype = var->GetDataType();
+    if (dtype == paddle::framework::proto::VarType::FP32) {
+      input_type[name] = paddle_infer::DataType::FLOAT32;
+    } else if (dtype == paddle::framework::proto::VarType::FP16) {
+      input_type[name] = paddle_infer::DataType::FLOAT16;
+    } else if (dtype == paddle::framework::proto::VarType::INT64) {
+      input_type[name] = paddle_infer::DataType::INT64;
+    } else if (dtype == paddle::framework::proto::VarType::INT32) {
+      input_type[name] = paddle_infer::DataType::INT32;
+    } else if (dtype == paddle::framework::proto::VarType::UINT8) {
+      input_type[name] = paddle_infer::DataType::UINT8;
+    } else if (dtype == paddle::framework::proto::VarType::INT8) {
+      input_type[name] = paddle_infer::DataType::INT8;
+    } else {
+      PADDLE_THROW(paddle::platform::errors::Unimplemented(
+          "Unsupported data type `%s` when get input dtype ", dtype));
+    }
+  }
+  return input_type;
+}
+
 std::vector<std::string> AnalysisPredictor::GetOutputNames() {
   std::vector<std::string> output_names;
   for (auto &item : idx2fetches_) {
@@ -2172,6 +2203,10 @@ std::vector<std::string> Predictor::GetInputNames() {
   return predictor_->GetInputNames();
 }
 
+std::map<std::string, DataType> Predictor::GetInputTypes() {
+  return predictor_->GetInputTypes();
+}
+
 std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
   return predictor_->GetInputTensor(name);
 }
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index 0835f712b6e..23571425755 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -173,6 +173,12 @@ class AnalysisPredictor : public PaddlePredictor {
   /// \return the map of input names and shapes
   ///
   std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
+  ///
+  /// \brief Get all input names and their corresponding type
+  ///
+  /// \return the map of input names and type
+  ///
+  std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
 
   ///
   /// \brief Run the prediction engine
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index a7871737ad4..8856ceb61a7 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -105,6 +105,7 @@ TEST(AnalysisPredictor, analysis_on) {
   ASSERT_TRUE(predictor->sub_scope_);
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
+  ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
   // 2. Dummy Input Data
   int64_t data[4] = {1, 2, 3, 4};
   PaddleTensor tensor;
@@ -389,6 +390,7 @@ TEST(Predictor, Run) {
   config.SetModel(FLAGS_dirname);
 
   auto predictor = CreatePredictor(config);
+  ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
 
   auto w0 = predictor->GetInputHandle("firstw");
   auto w1 = predictor->GetInputHandle("secondw");
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index 293ede9a284..ffb634ce829 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -231,6 +231,12 @@ class PD_INFER_DECL PaddlePredictor {
     return {};
   }
 
+  /// \brief Get the input type of the model.
+  /// \return A map contains all the input names and type defined in the model.
+  virtual std::map<std::string, paddle_infer::DataType> GetInputTypes() {
+    return {};
+  }
+
   /// \brief Used to get the name of the network output.
   /// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
   /// \return Output tensor names.
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index c3ccb58b803..ae844f138b0 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -93,6 +93,13 @@ class PD_INFER_DECL Predictor {
   ///
   explicit Predictor(const Config& config);
 
+  ///
+  /// \brief Get all input names and their corresponding type
+  ///
+  /// \return the map of input names and type
+  ///
+  std::map<std::string, DataType> GetInputTypes();
+
   ///
   /// \brief Get the input names
   ///
-- 
GitLab