未验证 提交 a8ae87f1 编写于 作者: H heliqi 提交者: GitHub

[inference]predictor add GetInputType interface (#45143)

* predictor add GetInputType interface

* predictor change GetInputType to GetInputTypes

* predictor add tester

* predictor add tester

* predictor change GetInputType to GetInputTypes

* predictor change GetInputType to GetInputTypes

* predictor add tester
上级 30122212
......@@ -1454,6 +1454,37 @@ AnalysisPredictor::GetInputTensorShape() {
return input_shapes;
}
///
/// \brief Collect the data type of every feed (input) variable of the model.
///
/// \return map from input name to its paddle_infer::DataType.
/// \throws PreconditionNotMet if a feed name has no Var in block 0,
///         Unimplemented if the Var's proto dtype has no public mapping.
///
std::map<std::string, paddle_infer::DataType>
AnalysisPredictor::GetInputTypes() {
  std::map<std::string, paddle_infer::DataType> input_type;
  std::vector<std::string> names = GetInputNames();
  for (const auto &name : names) {
    // Every feed target must be declared in block 0 of the inference program.
    auto *var = inference_program_->Block(0).FindVar(name);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        platform::errors::PreconditionNotMet(
            "Input %s does not exist in inference_program_.", name));
    auto dtype = var->GetDataType();
    // Translate framework proto types to the public paddle_infer::DataType.
    switch (dtype) {
      case paddle::framework::proto::VarType::FP32:
        input_type[name] = paddle_infer::DataType::FLOAT32;
        break;
      case paddle::framework::proto::VarType::FP16:
        input_type[name] = paddle_infer::DataType::FLOAT16;
        break;
      case paddle::framework::proto::VarType::INT64:
        input_type[name] = paddle_infer::DataType::INT64;
        break;
      case paddle::framework::proto::VarType::INT32:
        input_type[name] = paddle_infer::DataType::INT32;
        break;
      case paddle::framework::proto::VarType::UINT8:
        input_type[name] = paddle_infer::DataType::UINT8;
        break;
      case paddle::framework::proto::VarType::INT8:
        input_type[name] = paddle_infer::DataType::INT8;
        break;
      default:
        PADDLE_THROW(paddle::platform::errors::Unimplemented(
            "Unsupported data type `%s` when getting input dtype.", dtype));
    }
  }
  return input_type;
}
std::vector<std::string> AnalysisPredictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto &item : idx2fetches_) {
......@@ -2172,6 +2203,10 @@ std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames();
}
// Returns the name -> DataType map for all model inputs by delegating to the
// wrapped predictor implementation's GetInputTypes().
std::map<std::string, DataType> Predictor::GetInputTypes() {
  return predictor_->GetInputTypes();
}
// Returns a Tensor handle for the input named `name`, delegating to the
// wrapped predictor's GetInputTensor().
std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
  return predictor_->GetInputTensor(name);
}
......
......@@ -173,6 +173,12 @@ class AnalysisPredictor : public PaddlePredictor {
/// \return the map of input names and shapes
///
std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
///
/// \brief Get all input names and their corresponding data type
/// (as paddle_infer::DataType).
///
/// \return the map of input names to their type
///
std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
///
/// \brief Run the prediction engine
......
......@@ -105,6 +105,7 @@ TEST(AnalysisPredictor, analysis_on) {
ASSERT_TRUE(predictor->sub_scope_);
ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
// 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor;
......@@ -389,6 +390,7 @@ TEST(Predictor, Run) {
config.SetModel(FLAGS_dirname);
auto predictor = CreatePredictor(config);
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw");
......
......@@ -231,6 +231,12 @@ class PD_INFER_DECL PaddlePredictor {
return {};
}
/// \brief Get the input types of the model.
/// \return A map containing all the input names and their data types as
/// defined in the model. This base implementation returns an empty map;
/// overriding subclasses supply the real mapping.
virtual std::map<std::string, paddle_infer::DataType> GetInputTypes() {
  return {};
}
/// \brief Used to get the name of the network output.
/// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
/// \return Output tensor names.
......
......@@ -93,6 +93,13 @@ class PD_INFER_DECL Predictor {
///
explicit Predictor(const Config& config);
///
/// \brief Get all input names and their corresponding data type.
///
/// \return the map of input names to their type
///
std::map<std::string, DataType> GetInputTypes();
///
/// \brief Get the input names
///
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册