diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 502b28d7b4c4e27276d9ac8880c9d46ee25191b1..1060602e12f5821a1c2f110d01a87d5fc6902704 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/api/cxx_api.h" +#include #include #include #include @@ -52,35 +53,36 @@ lite::Tensor *Predictor::GetInput(size_t offset) { } // get inputs names -std::vector Predictor::GetInputNames() { - std::vector input_names; - for (auto &item : input_names_) { - input_names.push_back(item.second); - } - return input_names; +const std::vector &Predictor::GetInputNames() { + return input_names_; } // get outputnames -std::vector Predictor::GetOutputNames() { - std::vector output_names; - for (auto &item : output_names_) { - output_names.push_back(item.second); - } - return output_names; +const std::vector &Predictor::GetOutputNames() { + return output_names_; } // append the names of inputs and outputs into input_names_ and output_names_ void Predictor::PrepareFeedFetch() { auto current_block = program_desc_.GetBlock(0); + std::vector feeds; + std::vector fetchs; for (int i = 0; i < current_block->OpsSize(); i++) { auto op = current_block->GetOp(i); if (op->Type() == "feed") { - int idx = op->GetAttr("col"); - input_names_[idx] = op->Output("Out").front(); - idx2feeds_[op->Output("Out").front()] = idx; + feeds.push_back(op); } else if (op->Type() == "fetch") { - int idx = op->GetAttr("col"); - output_names_[idx] = op->Input("X").front(); + fetchs.push_back(op); } } + input_names_.resize(feeds.size()); + output_names_.resize(fetchs.size()); + for (int i = 0; i < feeds.size(); i++) { + input_names_[feeds[i]->GetAttr("col")] = + feeds[i]->Output("Out").front(); + } + for (int i = 0; i < fetchs.size(); i++) { + output_names_[fetchs[i]->GetAttr("col")] = + fetchs[i]->Input("X").front(); + } } const lite::Tensor *Predictor::GetOutput(size_t offset) const { @@ -189,16 +191,17 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const { } // get input by name lite::Tensor *Predictor::GetInputByName(const std::string &name) { - if (idx2feeds_.find(name) == idx2feeds_.end()) { + auto element = std::find(input_names_.begin(), input_names_.end(), name); + if (element == input_names_.end()) { LOG(ERROR) << "Model do not have input named with: [" << name << "], model's inputs include:"; for (int i = 0; i < input_names_.size(); i++) { LOG(ERROR) << "[" << input_names_[i] << "]"; } - return NULL; + return nullptr; } else { - int idx = idx2feeds_[name]; - return GetInput(idx); + int position = std::distance(input_names_.begin(), element); + return GetInput(position); } } diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 3d8dc2f06aca24e23a77a0b32dc85a0959290758..7226f4767ddf91c2e8d9864e4bc7a7665845179a 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -74,8 +74,8 @@ class LITE_API Predictor { // get input by name. lite::Tensor* GetInputByName(const std::string& name); // get inputnames and get outputnames. - std::vector GetInputNames(); - std::vector GetOutputNames(); + const std::vector& GetInputNames(); + const std::vector& GetOutputNames(); void PrepareFeedFetch(); // Get offset-th col of fetch results. @@ -107,9 +107,8 @@ class LITE_API Predictor { const Scope* exec_scope_; std::unique_ptr program_; bool program_generated_{false}; - std::map input_names_; - std::map idx2feeds_; - std::map output_names_; + std::vector input_names_; + std::vector output_names_; }; /* diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index b4fb3828f3b9b38aa3bcefc1df05d6453d55e771..62984ea476a901828367d74874291080667df3d8 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -37,8 +37,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { std::string GetVersion() const override; // get inputs names and get outputs names - std::vector GetInputNames() override; - std::vector GetOutputNames() override; + const std::vector &GetInputNames() override; + const std::vector &GetOutputNames() override; std::unique_ptr GetTensor( const std::string &name) const override; @@ -76,11 +76,11 @@ std::unique_ptr CxxPaddleApiImpl::GetOutput( return std::unique_ptr(new lite_api::Tensor(x)); } -std::vector CxxPaddleApiImpl::GetInputNames() { +const std::vector &CxxPaddleApiImpl::GetInputNames() { return raw_predictor_.GetInputNames(); } -std::vector CxxPaddleApiImpl::GetOutputNames() { +const std::vector &CxxPaddleApiImpl::GetOutputNames() { return raw_predictor_.GetOutputNames(); } diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index 12963285e482b2ea6c6e761f430699507d45c0c5..d28081c5152024606eb2e453aae1c7ca9eb7cd07 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/api/light_api.h" +#include namespace paddle { namespace lite { @@ -56,16 +57,17 @@ Tensor* LightPredictor::GetInput(size_t offset) { // get input by name Tensor* LightPredictor::GetInputByName(const std::string& name) { - if (idx2feeds_.find(name) == idx2feeds_.end()) { + auto element = std::find(input_names_.begin(), input_names_.end(), name); + if (element == input_names_.end()) { LOG(ERROR) << "Model do not have input named with: [" << name << "], model's inputs include:"; for (int i = 0; i < input_names_.size(); i++) { LOG(ERROR) << "[" << input_names_[i] << "]"; } - return NULL; + return nullptr; } else { - int idx = idx2feeds_[name]; - return GetInput(idx); + int position = std::distance(input_names_.begin(), element); + return GetInput(position); } } @@ -79,35 +81,36 @@ const Tensor* LightPredictor::GetOutput(size_t offset) { return out_var->GetMutable(); } // get inputs names -std::vector LightPredictor::GetInputNames() { - std::vector input_names; - for (auto& item : input_names_) { - input_names.push_back(item.second); - } - return input_names; +const std::vector& LightPredictor::GetInputNames() { + return input_names_; } // get outputnames -std::vector LightPredictor::GetOutputNames() { - std::vector output_names; - for (auto& item : output_names_) { - output_names.push_back(item.second); - } - return output_names; +const std::vector& LightPredictor::GetOutputNames() { + return output_names_; } // append the names of inputs and outputs into input_names_ and output_names_ void LightPredictor::PrepareFeedFetch() { auto current_block = cpp_program_desc_.GetBlock(0); + std::vector feeds; + std::vector fetchs; for (int i = 0; i < current_block->OpsSize(); i++) { auto op = current_block->GetOp(i); if (op->Type() == "feed") { - int idx = op->GetAttr("col"); - input_names_[idx] = op->Output("Out").front(); - idx2feeds_[op->Output("Out").front()] = idx; + feeds.push_back(op); } else if (op->Type() == "fetch") { - int idx = op->GetAttr("col"); - output_names_[idx] = op->Input("X").front(); + fetchs.push_back(op); } } + input_names_.resize(feeds.size()); + output_names_.resize(fetchs.size()); + for (int i = 0; i < feeds.size(); i++) { + input_names_[feeds[i]->GetAttr("col")] = + feeds[i]->Output("Out").front(); + } + for (int i = 0; i < fetchs.size(); i++) { + output_names_[fetchs[i]->GetAttr("col")] = + fetchs[i]->Input("X").front(); + } } void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 0705e0aba42373dec9f1387573024c5b3bb98bbc..9d69cce441f86e563ad3ed0501514ab1fe79d98e 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -64,8 +64,8 @@ class LITE_API LightPredictor { } // get inputnames and get outputnames. - std::vector GetInputNames(); - std::vector GetOutputNames(); + const std::vector& GetInputNames(); + const std::vector& GetOutputNames(); void PrepareFeedFetch(); private: @@ -82,9 +82,8 @@ class LITE_API LightPredictor { std::shared_ptr scope_; std::unique_ptr program_; cpp::ProgramDesc cpp_program_desc_; - std::map input_names_; - std::map idx2feeds_; - std::map output_names_; + std::vector input_names_; + std::vector output_names_; }; } // namespace lite diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index 90e1397d8338adb1ba732fc322ae03520bcce27f..70ab8ac0c03b8dea84da5ef1d6ca9c64c4c9d102 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -32,8 +32,8 @@ class LightPredictorImpl : public PaddlePredictor { void Run() override; std::string GetVersion() const override; - std::vector GetInputNames() override; - std::vector GetOutputNames() override; + const std::vector& GetInputNames() override; + const std::vector& GetOutputNames() override; std::unique_ptr GetTensor( const std::string& name) const override; @@ -78,11 +78,11 @@ std::unique_ptr LightPredictorImpl::GetInputByName( new Tensor(raw_predictor_->GetInputByName(name))); } -std::vector LightPredictorImpl::GetInputNames() { +const std::vector& LightPredictorImpl::GetInputNames() { return raw_predictor_->GetInputNames(); } -std::vector LightPredictorImpl::GetOutputNames() { +const std::vector& LightPredictorImpl::GetOutputNames() { return raw_predictor_->GetOutputNames(); } diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc index 418d97e9e8814b5e6e90a76cbdb6e92677c9c726..d2bbc295ad4b68e7849d5d25f34e0b5117fc846d 100644 --- a/lite/api/light_api_test.cc +++ b/lite/api/light_api_test.cc @@ -36,12 +36,14 @@ TEST(LightAPI, load) { data[i] = i; } - std::vector inputs = predictor.GetInputNames(); + predictor.PrepareFeedFetch(); + const std::vector& inputs = predictor.GetInputNames(); + LOG(INFO) << "input size: " << inputs.size(); for (int i = 0; i < inputs.size(); i++) { LOG(INFO) << "inputnames: " << inputs[i]; } - std::vector outputs = predictor.GetOutputNames(); + const std::vector& outputs = predictor.GetOutputNames(); for (int i = 0; i < outputs.size(); i++) { LOG(INFO) << "outputnames: " << outputs[i]; } diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 545ae03f6725de7649b3278835bda973ade2755e..d7e3c014b0fe37a5f1da4210972349ac4124ed6b 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -75,9 +75,9 @@ class LITE_API PaddlePredictor { virtual std::string GetVersion() const = 0; // Get input names - virtual std::vector GetInputNames() = 0; + virtual const std::vector& GetInputNames() = 0; // Get output names - virtual std::vector GetOutputNames() = 0; + virtual const std::vector& GetOutputNames() = 0; // Get Input by name virtual std::unique_ptr GetInputByName(const std::string& name) = 0; diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc index 63142d49814473e6dc9ee6e553d95fa86b4058c5..443a05d9927cfa461a306ce6c3c32ff6e5024631 100644 --- a/lite/api/paddle_api_test.cc +++ b/lite/api/paddle_api_test.cc @@ -37,12 +37,12 @@ TEST(CxxApi, run) { LOG(INFO) << "Version: " << predictor->GetVersion(); - std::vector inputs = predictor->GetInputNames(); + auto& inputs = predictor->GetInputNames(); LOG(INFO) << "input size: " << inputs.size(); for (int i = 0; i < inputs.size(); i++) { LOG(INFO) << "inputnames: " << inputs[i]; } - std::vector outputs = predictor->GetOutputNames(); + auto& outputs = predictor->GetOutputNames(); for (int i = 0; i < outputs.size(); i++) { LOG(INFO) << "outputnames: " << outputs[i]; } @@ -76,14 +76,14 @@ TEST(LightApi, run) { auto predictor = lite_api::CreatePaddlePredictor(config); - std::vector inputs = predictor->GetInputNames(); + auto& inputs = predictor->GetInputNames(); LOG(INFO) << "input size: " << inputs.size(); for (int i = 0; i < inputs.size(); i++) { - LOG(INFO) << "inputnames: " << inputs[i]; + LOG(INFO) << "inputnames: " << inputs.at(i); } - std::vector outputs = predictor->GetOutputNames(); + auto& outputs = predictor->GetOutputNames(); for (int i = 0; i < outputs.size(); i++) { - LOG(INFO) << "outputnames: " << outputs[i]; + LOG(INFO) << "outputnames: " << outputs.at(i); } LOG(INFO) << "Version: " << predictor->GetVersion();