diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 49045089ce5a045ddbc4fdb1ca060ab00baa3015..fa927a7da225f7527297e71b0f4913fb19196fe1 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1454,6 +1454,37 @@ AnalysisPredictor::GetInputTensorShape() {
   return input_shapes;
 }
 
+std::map<std::string, paddle_infer::DataType>
+AnalysisPredictor::GetInputTypes() {
+  std::map<std::string, paddle_infer::DataType> input_type;
+  std::vector<std::string> names = GetInputNames();
+  for (const auto &name : names) {
+    auto *var = inference_program_->Block(0).FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::PreconditionNotMet(
+            "Input %s does not exist inference_program_.", name));
+    auto dtype = var->GetDataType();
+    if (dtype == paddle::framework::proto::VarType::FP32) {
+      input_type[name] = paddle_infer::DataType::FLOAT32;
+    } else if (dtype == paddle::framework::proto::VarType::FP16) {
+      input_type[name] = paddle_infer::DataType::FLOAT16;
+    } else if (dtype == paddle::framework::proto::VarType::INT64) {
+      input_type[name] = paddle_infer::DataType::INT64;
+    } else if (dtype == paddle::framework::proto::VarType::INT32) {
+      input_type[name] = paddle_infer::DataType::INT32;
+    } else if (dtype == paddle::framework::proto::VarType::UINT8) {
+      input_type[name] = paddle_infer::DataType::UINT8;
+    } else if (dtype == paddle::framework::proto::VarType::INT8) {
+      input_type[name] = paddle_infer::DataType::INT8;
+    } else {
+      PADDLE_THROW(paddle::platform::errors::Unimplemented(
+          "Unsupported data type `%s` when get input dtype ", dtype));
+    }
+  }
+  return input_type;
+}
+
 std::vector<std::string> AnalysisPredictor::GetOutputNames() {
   std::vector<std::string> output_names;
   for (auto &item : idx2fetches_) {
@@ -2172,6 +2203,10 @@ std::vector<std::string> Predictor::GetInputNames() {
   return predictor_->GetInputNames();
 }
 
+std::map<std::string, DataType> Predictor::GetInputTypes() {
+  return predictor_->GetInputTypes();
+}
+
 std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
   return predictor_->GetInputTensor(name);
 }
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index 0835f712b6e4607d780c0c18f168d12f8e272f8e..235714257558aaafff17d68ee308fef9990bb090 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -173,6 +173,12 @@ class AnalysisPredictor : public PaddlePredictor {
   /// \return the map of input names and shapes
   ///
   std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
+  ///
+  /// \brief Get all input names and their corresponding type
+  ///
+  /// \return the map of input names and type
+  ///
+  std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
 
   ///
   /// \brief Run the prediction engine
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index a7871737ad4b1f03a53b6307d59294f7364d8058..8856ceb61a76f2fd9d822f3d15eda5de3e17bd10 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -105,6 +105,7 @@ TEST(AnalysisPredictor, analysis_on) {
   ASSERT_TRUE(predictor->sub_scope_);
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
+  ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
   // 2. Dummy Input Data
   int64_t data[4] = {1, 2, 3, 4};
   PaddleTensor tensor;
@@ -389,6 +390,7 @@ TEST(Predictor, Run) {
   config.SetModel(FLAGS_dirname);
 
   auto predictor = CreatePredictor(config);
+  ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
 
   auto w0 = predictor->GetInputHandle("firstw");
   auto w1 = predictor->GetInputHandle("secondw");
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index 293ede9a2848abf3d5aa8a086799f6883f02ee9e..ffb634ce829683b2725090890d7f167fc51a08de 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -231,6 +231,12 @@ class PD_INFER_DECL PaddlePredictor {
     return {};
   }
 
+  /// \brief Get the input type of the model.
+  /// \return A map contains all the input names and type defined in the model.
+  virtual std::map<std::string, paddle_infer::DataType> GetInputTypes() {
+    return {};
+  }
+
   /// \brief Used to get the name of the network output.
   /// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
   /// \return Output tensor names.
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index c3ccb58b8031ce19d04fe01dc1893e56573215fe..ae844f138b0f680148012546f106da448a31a97d 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -93,6 +93,13 @@ class PD_INFER_DECL Predictor {
   ///
   explicit Predictor(const Config& config);
 
+  ///
+  /// \brief Get all input names and their corresponding type
+  ///
+  /// \return the map of input names and type
+  ///
+  std::map<std::string, DataType> GetInputTypes();
+
   ///
   /// \brief Get the input names
   ///