未验证 提交 a8ae87f1 编写于 作者: H heliqi 提交者: GitHub

[inference]predictor add GetInputType interface (#45143)

* predictor add GetInputType interface

* predictor change GetInputType to GetInputTypes

* predictor add tester

* predictor add tester

* predictor change GetInputType to GetInputTypes

* predictor change GetInputType to GetInputTypes

* predictor add tester
上级 30122212
...@@ -1454,6 +1454,37 @@ AnalysisPredictor::GetInputTensorShape() { ...@@ -1454,6 +1454,37 @@ AnalysisPredictor::GetInputTensorShape() {
return input_shapes; return input_shapes;
} }
std::map<std::string, paddle_infer::DataType>
AnalysisPredictor::GetInputTypes() {
std::map<std::string, paddle_infer::DataType> input_type;
std::vector<std::string> names = GetInputNames();
for (const auto &name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet(
"Input %s does not exist inference_program_.", name));
auto dtype = var->GetDataType();
if (dtype == paddle::framework::proto::VarType::FP32) {
input_type[name] = paddle_infer::DataType::FLOAT32;
} else if (dtype == paddle::framework::proto::VarType::FP16) {
input_type[name] = paddle_infer::DataType::FLOAT16;
} else if (dtype == paddle::framework::proto::VarType::INT64) {
input_type[name] = paddle_infer::DataType::INT64;
} else if (dtype == paddle::framework::proto::VarType::INT32) {
input_type[name] = paddle_infer::DataType::INT32;
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
input_type[name] = paddle_infer::DataType::UINT8;
} else if (dtype == paddle::framework::proto::VarType::INT8) {
input_type[name] = paddle_infer::DataType::INT8;
} else {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported data type `%s` when get input dtype ", dtype));
}
}
return input_type;
}
std::vector<std::string> AnalysisPredictor::GetOutputNames() { std::vector<std::string> AnalysisPredictor::GetOutputNames() {
std::vector<std::string> output_names; std::vector<std::string> output_names;
for (auto &item : idx2fetches_) { for (auto &item : idx2fetches_) {
...@@ -2172,6 +2203,10 @@ std::vector<std::string> Predictor::GetInputNames() { ...@@ -2172,6 +2203,10 @@ std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames(); return predictor_->GetInputNames();
} }
std::map<std::string, DataType> Predictor::GetInputTypes() {
return predictor_->GetInputTypes();
}
std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) { std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
return predictor_->GetInputTensor(name); return predictor_->GetInputTensor(name);
} }
......
...@@ -173,6 +173,12 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -173,6 +173,12 @@ class AnalysisPredictor : public PaddlePredictor {
/// \return the map of input names and shapes /// \return the map of input names and shapes
/// ///
std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override; std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
///
/// \brief Get all input names and their corresponding type
///
/// \return the map of input names and type
///
std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
/// ///
/// \brief Run the prediction engine /// \brief Run the prediction engine
......
...@@ -105,6 +105,7 @@ TEST(AnalysisPredictor, analysis_on) { ...@@ -105,6 +105,7 @@ TEST(AnalysisPredictor, analysis_on) {
ASSERT_TRUE(predictor->sub_scope_); ASSERT_TRUE(predictor->sub_scope_);
ASSERT_EQ(predictor->scope_->parent(), nullptr); ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get()); ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
// 2. Dummy Input Data // 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor; PaddleTensor tensor;
...@@ -389,6 +390,7 @@ TEST(Predictor, Run) { ...@@ -389,6 +390,7 @@ TEST(Predictor, Run) {
config.SetModel(FLAGS_dirname); config.SetModel(FLAGS_dirname);
auto predictor = CreatePredictor(config); auto predictor = CreatePredictor(config);
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
auto w0 = predictor->GetInputHandle("firstw"); auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw"); auto w1 = predictor->GetInputHandle("secondw");
......
...@@ -231,6 +231,12 @@ class PD_INFER_DECL PaddlePredictor { ...@@ -231,6 +231,12 @@ class PD_INFER_DECL PaddlePredictor {
return {}; return {};
} }
/// \brief Get the input type of the model.
/// \return A map contains all the input names and type defined in the model.
virtual std::map<std::string, paddle_infer::DataType> GetInputTypes() {
return {};
}
/// \brief Used to get the name of the network output. /// \brief Used to get the name of the network output.
/// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. /// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
/// \return Output tensor names. /// \return Output tensor names.
......
...@@ -93,6 +93,13 @@ class PD_INFER_DECL Predictor { ...@@ -93,6 +93,13 @@ class PD_INFER_DECL Predictor {
/// ///
explicit Predictor(const Config& config); explicit Predictor(const Config& config);
///
/// \brief Get all input names and their corresponding type
///
/// \return the map of input names and type
///
std::map<std::string, DataType> GetInputTypes();
/// ///
/// \brief Get the input names /// \brief Get the input names
/// ///
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册