diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 7935ab275121ac55261367a304c3bd5e6d2e9a70..78b937c4ab6122fb915f8f82bc8dae4a63cf9bbf 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -64,6 +64,38 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
   return &feed_list->at(offset);
 }
 
+// get input names
+std::vector<std::string> Predictor::GetInputNames() {
+  std::vector<std::string> input_names;
+  for (auto &item : input_names_) {
+    input_names.push_back(item.second);
+  }
+  return input_names;
+}
+// get output names
+std::vector<std::string> Predictor::GetOutputNames() {
+  std::vector<std::string> output_names;
+  for (auto &item : output_names_) {
+    output_names.push_back(item.second);
+  }
+  return output_names;
+}
+// append the names of inputs and outputs into input_names_ and output_names_
+void Predictor::PrepareFeedFetch() {
+  auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
+  for (int i = 0; i < current_block->OpsSize(); i++) {
+    auto op = current_block->GetOp<cpp::OpDesc>(i);
+    if (op->Type() == "feed") {
+      int idx = op->GetAttr<int>("col");
+      input_names_[idx] = op->Output("Out").front();
+      idx2feeds_[op->Output("Out").front()] = idx;
+    } else if (op->Type() == "fetch") {
+      int idx = op->GetAttr<int>("col");
+      output_names_[idx] = op->Input("X").front();
+    }
+  }
+}
+
 const lite::Tensor *Predictor::GetOutput(size_t offset) const {
   auto *_fetch_list = exec_scope_->FindVar("fetch");
   CHECK(_fetch_list) << "no fatch variable in exec_scope";
@@ -162,6 +194,20 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
   auto *var = exec_scope_->FindVar(name);
   return &var->Get<lite::Tensor>();
 }
+// get input by name
+lite::Tensor *Predictor::GetInputByName(const std::string &name) {
+  if (idx2feeds_.find(name) == idx2feeds_.end()) {
+    LOG(ERROR) << "Model does not have an input named [" << name
+               << "]; the model's inputs are:";
+    for (int i = 0; i < input_names_.size(); i++) {
+      LOG(ERROR) << "[" << input_names_[i] << "]";
+    }
+    return NULL;
+  } else {
+    int idx = idx2feeds_[name];
+    return GetInput(idx);
+  }
+}
 
 #ifdef LITE_WITH_TRAIN
 void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 05bc1b13201df962adf39c1b599695e23b19aecf..1fe7c985afdc50f0e17a2d581876ed7d6947d2ad 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include <map>
 #include <memory>
 #include <string>
 #include <utility>
@@ -72,6 +73,12 @@ class LITE_API Predictor {
 
   // Get offset-th col of feed inputs.
   lite::Tensor* GetInput(size_t offset);
+  // get input by name.
+  lite::Tensor* GetInputByName(const std::string& name);
+  // get input names and output names.
+  std::vector<std::string> GetInputNames();
+  std::vector<std::string> GetOutputNames();
+  void PrepareFeedFetch();
   // Get offset-th col of fetch results.
   const lite::Tensor* GetOutput(size_t offset) const;
@@ -102,6 +109,9 @@ class LITE_API Predictor {
   const Scope* exec_scope_;
   std::unique_ptr<RuntimeProgram> program_;
   bool program_generated_{false};
+  std::map<int, std::string> input_names_;
+  std::map<std::string, int> idx2feeds_;
+  std::map<int, std::string> output_names_;
 };
 
 /*
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index c5aa0a00a5a0d0a5a0b6418bccd53602964a6205..8091e2ffd55cdabfbb73a1875873da39acd4f85d 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -36,9 +36,17 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
 
   std::string GetVersion() const override;
 
+  // get input names and output names
+  std::vector<std::string> GetInputNames() override;
+  std::vector<std::string> GetOutputNames() override;
+
   std::unique_ptr<const lite_api::Tensor> GetTensor(
       const std::string &name) const override;
 
+  // Get input tensor by name
+  std::unique_ptr<lite_api::Tensor> GetInputByName(
+      const std::string &name) override;
+
   void SaveOptimizedModel(const std::string &model_dir,
                           lite_api::LiteModelType model_type =
                               lite_api::LiteModelType::kProtobuf) override;
@@ -56,6 +64,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   auto places = config.valid_places();
   places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
   raw_predictor_.Build(config, places);
+  raw_predictor_.PrepareFeedFetch();
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
@@ -69,6 +78,14 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 }
 
+std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
+  return raw_predictor_.GetInputNames();
+}
+
+std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
+  return raw_predictor_.GetOutputNames();
+}
+
 void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); }
 
 std::string CxxPaddleApiImpl::GetVersion() const { return version(); }
@@ -79,6 +96,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
   return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
 }
 
+std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
+    const std::string &name) {
+  return std::unique_ptr<lite_api::Tensor>(
+      new lite_api::Tensor(raw_predictor_.GetInputByName(name)));
+}
+
 void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir,
                                           lite_api::LiteModelType model_type) {
   raw_predictor_.SaveModel(model_dir, model_type);
diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc
index 2d75a1ba8213f4455998909f82c42a8cb1f60021..f23a973c830dc62719f5c4e25b2cd2de294882e8 100644
--- a/lite/api/light_api.cc
+++ b/lite/api/light_api.cc
@@ -53,6 +53,21 @@ Tensor* LightPredictor::GetInput(size_t offset) {
   return &feed_list->at(offset);
 }
 
+// get input by name
+Tensor* LightPredictor::GetInputByName(const std::string& name) {
+  if (idx2feeds_.find(name) == idx2feeds_.end()) {
+    LOG(ERROR) << "Model does not have an input named [" << name
+               << "]; the model's inputs are:";
+    for (int i = 0; i < input_names_.size(); i++) {
+      LOG(ERROR) << "[" << input_names_[i] << "]";
+    }
+    return NULL;
+  } else {
+    int idx = idx2feeds_[name];
+    return GetInput(idx);
+  }
+}
+
 const Tensor* LightPredictor::GetOutput(size_t offset) {
   auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
   CHECK(_fetch_list) << "no fatch variable in exec_scope";
@@ -60,6 +75,37 @@ const Tensor* LightPredictor::GetOutput(size_t offset) {
   CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
   return &fetch_list.at(offset);
 }
+// get input names
+std::vector<std::string> LightPredictor::GetInputNames() {
+  std::vector<std::string> input_names;
+  for (auto& item : input_names_) {
+    input_names.push_back(item.second);
+  }
+  return input_names;
+}
+// get output names
+std::vector<std::string> LightPredictor::GetOutputNames() {
+  std::vector<std::string> output_names;
+  for (auto& item : output_names_) {
+    output_names.push_back(item.second);
+  }
+  return output_names;
+}
+// append the names of inputs and outputs into input_names_ and output_names_
+void LightPredictor::PrepareFeedFetch() {
+  auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
+  for (int i = 0; i < current_block->OpsSize(); i++) {
+    auto op = current_block->GetOp<cpp::OpDesc>(i);
+    if (op->Type() == "feed") {
+      int idx = op->GetAttr<int>("col");
+      input_names_[idx] = op->Output("Out").front();
+      idx2feeds_[op->Output("Out").front()] = idx;
+    } else if (op->Type() == "fetch") {
+      int idx = op->GetAttr<int>("col");
+      output_names_[idx] = op->Input("X").front();
+    }
+  }
+}
 
 void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
   std::vector<Instruction> insts;
diff --git a/lite/api/light_api.h b/lite/api/light_api.h
index 0d5c7006c843f904dfeff8a97dd23f5cc4aebaeb..0705e0aba42373dec9f1387573024c5b3bb98bbc 100644
--- a/lite/api/light_api.h
+++ b/lite/api/light_api.h
@@ -18,6 +18,7 @@
  */
 
 #pragma once
+#include <map>
 #include <memory>
 #include <string>
 #include <utility>
@@ -52,7 +53,8 @@ class LITE_API LightPredictor {
 
   // Get offset-th col of feed inputs.
   Tensor* GetInput(size_t offset);
-
+  // get input by name.
+  Tensor* GetInputByName(const std::string& name);
   // Get offset-th col of fetch outputs.
   const Tensor* GetOutput(size_t offset);
@@ -61,6 +63,11 @@ class LITE_API LightPredictor {
     return &var->Get<lite::Tensor>();
   }
 
+  // get input names and output names.
+  std::vector<std::string> GetInputNames();
+  std::vector<std::string> GetOutputNames();
+  void PrepareFeedFetch();
+
  private:
   void Build(
       const std::string& model_dir,
@@ -75,6 +82,9 @@ class LITE_API LightPredictor {
   std::shared_ptr<Scope> scope_;
   std::unique_ptr<RuntimeProgram> program_;
   cpp::ProgramDesc cpp_program_desc_;
+  std::map<int, std::string> input_names_;
+  std::map<std::string, int> idx2feeds_;
+  std::map<int, std::string> output_names_;
 };
 
 }  // namespace lite
diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc
index 4dedc5332db3082ee636eabc8eac1215034b2abe..c857e377fcd62644be7783c2bc7431a52e2af277 100644
--- a/lite/api/light_api_impl.cc
+++ b/lite/api/light_api_impl.cc
@@ -32,9 +32,13 @@ class LightPredictorImpl : public PaddlePredictor {
   void Run() override;
 
   std::string GetVersion() const override;
+  std::vector<std::string> GetInputNames() override;
+  std::vector<std::string> GetOutputNames() override;
 
   std::unique_ptr<const Tensor> GetTensor(
       const std::string& name) const override;
+  // Get input tensor by name
+  std::unique_ptr<Tensor> GetInputByName(const std::string& name) override;
 
   void Init(const MobileConfig& config);
@@ -49,6 +53,7 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
                                       config.param_buffer(),
                                       config.model_from_memory(),
                                       LiteModelType::kNaiveBuffer));
+  raw_predictor_->PrepareFeedFetch();
 }
 
 std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
@@ -68,6 +73,19 @@ std::unique_ptr<const Tensor> LightPredictorImpl::GetTensor(
   return std::unique_ptr<const Tensor>(
       new Tensor(raw_predictor_->GetTensor(name)));
 }
+std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName(
+    const std::string& name) {
+  return std::unique_ptr<Tensor>(
+      new Tensor(raw_predictor_->GetInputByName(name)));
+}
+
+std::vector<std::string> LightPredictorImpl::GetInputNames() {
+  return raw_predictor_->GetInputNames();
+}
+
+std::vector<std::string> LightPredictorImpl::GetOutputNames() {
+  return raw_predictor_->GetOutputNames();
+}
 
 template <>
 std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(
diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc
index 8e2fc420bc3be91e35047b823e628b80f2175496..0960b6c079f9a833206d5b41526360e14a512116 100644
--- a/lite/api/light_api_test.cc
+++ b/lite/api/light_api_test.cc
@@ -36,6 +36,17 @@ TEST(LightAPI, load) {
     data[i] = i;
   }
 
+  predictor.PrepareFeedFetch();
+  std::vector<std::string> inputs = predictor.GetInputNames();
+  LOG(INFO) << "input size: " << inputs.size();
+  for (int i = 0; i < inputs.size(); i++) {
+    LOG(INFO) << "input name: " << inputs[i];
+  }
+  std::vector<std::string> outputs = predictor.GetOutputNames();
+  for (int i = 0; i < outputs.size(); i++) {
+    LOG(INFO) << "output name: " << outputs[i];
+  }
+
   predictor.Run();
 
   const auto* output = predictor.GetOutput(0);
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index e4da4af4a942d46ffe4429e589762a8878faf54b..97b3b31bc77933fd41e862f4ec6f0d023c16911f 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -74,6 +74,14 @@ class LITE_API PaddlePredictor {
 
   virtual std::string GetVersion() const = 0;
 
+  // Get input names
+  virtual std::vector<std::string> GetInputNames() = 0;
+  // Get output names
+  virtual std::vector<std::string> GetOutputNames() = 0;
+
+  // Get input tensor by name
+  virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
+
   /// Get a readonly tensor, return null if no one called `name` exists.
   virtual std::unique_ptr<const Tensor> GetTensor(
       const std::string& name) const = 0;
diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc
index ac5388e5dd92785d4812c7e481ba8d5979b93e19..994658037735cd26bb3dcbaf905215f17f306af7 100644
--- a/lite/api/paddle_api_test.cc
+++ b/lite/api/paddle_api_test.cc
@@ -38,7 +38,16 @@ TEST(CxxApi, run) {
 
   LOG(INFO) << "Version: " << predictor->GetVersion();
 
-  auto input_tensor = predictor->GetInput(0);
+  std::vector<std::string> inputs = predictor->GetInputNames();
+  LOG(INFO) << "input size: " << inputs.size();
+  for (int i = 0; i < inputs.size(); i++) {
+    LOG(INFO) << "input name: " << inputs[i];
+  }
+  std::vector<std::string> outputs = predictor->GetOutputNames();
+  for (int i = 0; i < outputs.size(); i++) {
+    LOG(INFO) << "output name: " << outputs[i];
+  }
+  auto input_tensor = predictor->GetInputByName(inputs[0]);
   input_tensor->Resize(std::vector<int64_t>({100, 100}));
   auto* data = input_tensor->mutable_data<float>();
   for (int i = 0; i < 100 * 100; i++) {
@@ -47,7 +56,7 @@ TEST(CxxApi, run) {
 
   predictor->Run();
 
-  auto output = predictor->GetOutput(0);
+  auto output = predictor->GetTensor(outputs[0]);
   auto* out = output->data<float>();
   LOG(INFO) << out[0];
   LOG(INFO) << out[1];
@@ -68,6 +77,16 @@ TEST(LightApi, run) {
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
+  std::vector<std::string> inputs = predictor->GetInputNames();
+  LOG(INFO) << "input size: " << inputs.size();
+  for (int i = 0; i < inputs.size(); i++) {
+    LOG(INFO) << "input name: " << inputs[i];
+  }
+  std::vector<std::string> outputs = predictor->GetOutputNames();
+  for (int i = 0; i < outputs.size(); i++) {
+    LOG(INFO) << "output name: " << outputs[i];
+  }
+
   LOG(INFO) << "Version: " << predictor->GetVersion();
 
   auto input_tensor = predictor->GetInput(0);
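For reviewers, a minimal sketch of how the new name-based API is meant to be used together from application code, in the style of the tests above. It is illustrative only: the model directory, input shape, and fill values are placeholders, and it assumes a naive-buffer model loaded through MobileConfig, so that Init() has already called PrepareFeedFetch() and the name tables are populated.

// usage_sketch.cc -- illustrative only, not part of this patch
#include <iostream>
#include <string>
#include <vector>

#include "lite/api/paddle_api.h"

int main() {
  // "model_dir" is a placeholder path to a naive-buffer model.
  paddle::lite_api::MobileConfig config;
  config.set_model_dir("model_dir");

  // CreatePaddlePredictor runs Init(), which calls PrepareFeedFetch().
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);

  // Discover the model's inputs and outputs by name instead of by
  // column offset.
  std::vector<std::string> input_names = predictor->GetInputNames();
  std::vector<std::string> output_names = predictor->GetOutputNames();

  // Look up the first input by its name and fill it.
  auto input_tensor = predictor->GetInputByName(input_names[0]);
  input_tensor->Resize({100, 100});  // shape is illustrative
  auto* data = input_tensor->mutable_data<float>();
  for (int i = 0; i < 100 * 100; i++) {
    data[i] = 1.f;
  }

  predictor->Run();

  // Read the first output back through the read-only tensor interface.
  auto output = predictor->GetTensor(output_names[0]);
  std::cout << "out[0] = " << output->data<float>()[0] << std::endl;
  return 0;
}

Note that when the requested name is unknown, GetInputByName logs the available input names and returns NULL at the Predictor/LightPredictor level, so callers should stick to names obtained from GetInputNames().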