From 8591aaecdb3454a466767a81d9d08912e7b2b2c8 Mon Sep 17 00:00:00 2001
From: huzhiqiang <912790387@qq.com>
Date: Fri, 18 Oct 2019 12:46:23 +0800
Subject: [PATCH] Fix codestyle of GetInputName&GetOutputName (#2185)

* add shell file to automatically build and collect publish result test=develop
* modify codestyle of getInputNames test=develop
* test=develop
* rm publish.sh
* remove copy of func param
* test=develop
* test=devcelop
* test=develop
* test=develop
* const & test=develop
* modify variable defination test=develop
* test=develop
* test=develop
* test=develop
* test=develop
---
 lite/api/cxx_api.cc         | 45 ++++++++++++++++++++-----------------
 lite/api/cxx_api.h          |  9 ++++----
 lite/api/cxx_api_impl.cc    |  8 +++----
 lite/api/light_api.cc       | 45 ++++++++++++++++++++-----------------
 lite/api/light_api.h        |  9 ++++----
 lite/api/light_api_impl.cc  |  8 +++----
 lite/api/light_api_test.cc  |  6 +++--
 lite/api/paddle_api.h       |  4 ++--
 lite/api/paddle_api_test.cc | 12 +++++-----
 9 files changed, 76 insertions(+), 70 deletions(-)
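Note on the pattern this patch adopts: input/output names are kept in a std::vector<std::string> ordered by the feed/fetch "col" attribute, and a name is resolved back to its index with std::find and std::distance, so GetInputNames()/GetOutputNames() can return const references instead of building copies. Below is a minimal standalone C++ sketch of that lookup pattern; FindIndexByName and the sample names are hypothetical and are not Paddle-Lite code.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Resolve a tensor name to its position in a vector of names kept in
// feed/fetch "col" order; returns -1 when the name is absent.
int FindIndexByName(const std::vector<std::string>& names,
                    const std::string& name) {
  auto it = std::find(names.begin(), names.end(), name);
  if (it == names.end()) return -1;
  return static_cast<int>(std::distance(names.begin(), it));
}

int main() {
  const std::vector<std::string> names = {"image", "label"};
  std::cout << FindIndexByName(names, "label") << "\n";  // prints 1
  std::cout << FindIndexByName(names, "score") << "\n";  // prints -1
  return 0;
}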
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 502b28d7b4..1060602e12 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "lite/api/cxx_api.h"
+#include <algorithm>
 #include
 #include
 #include
@@ -52,35 +53,36 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
 }
 
 // get inputs names
-std::vector<std::string> Predictor::GetInputNames() {
-  std::vector<std::string> input_names;
-  for (auto &item : input_names_) {
-    input_names.push_back(item.second);
-  }
-  return input_names;
+const std::vector<std::string> &Predictor::GetInputNames() {
+  return input_names_;
 }
 // get outputnames
-std::vector<std::string> Predictor::GetOutputNames() {
-  std::vector<std::string> output_names;
-  for (auto &item : output_names_) {
-    output_names.push_back(item.second);
-  }
-  return output_names;
+const std::vector<std::string> &Predictor::GetOutputNames() {
+  return output_names_;
 }
 // append the names of inputs and outputs into input_names_ and output_names_
 void Predictor::PrepareFeedFetch() {
   auto current_block = program_desc_.GetBlock(0);
+  std::vector feeds;
+  std::vector fetchs;
   for (int i = 0; i < current_block->OpsSize(); i++) {
     auto op = current_block->GetOp(i);
     if (op->Type() == "feed") {
-      int idx = op->GetAttr("col");
-      input_names_[idx] = op->Output("Out").front();
-      idx2feeds_[op->Output("Out").front()] = idx;
+      feeds.push_back(op);
     } else if (op->Type() == "fetch") {
-      int idx = op->GetAttr("col");
-      output_names_[idx] = op->Input("X").front();
+      fetchs.push_back(op);
     }
   }
+  input_names_.resize(feeds.size());
+  output_names_.resize(fetchs.size());
+  for (int i = 0; i < feeds.size(); i++) {
+    input_names_[feeds[i]->GetAttr("col")] =
+        feeds[i]->Output("Out").front();
+  }
+  for (int i = 0; i < fetchs.size(); i++) {
+    output_names_[fetchs[i]->GetAttr("col")] =
+        fetchs[i]->Input("X").front();
+  }
 }
 
 const lite::Tensor *Predictor::GetOutput(size_t offset) const {
@@ -189,16 +191,17 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
 }
 
 // get input by name
 lite::Tensor *Predictor::GetInputByName(const std::string &name) {
-  if (idx2feeds_.find(name) == idx2feeds_.end()) {
+  auto element = std::find(input_names_.begin(), input_names_.end(), name);
+  if (element == input_names_.end()) {
     LOG(ERROR) << "Model do not have input named with: [" << name
                << "], model's inputs include:";
     for (int i = 0; i < input_names_.size(); i++) {
       LOG(ERROR) << "[" << input_names_[i] << "]";
     }
-    return NULL;
+    return nullptr;
   } else {
-    int idx = idx2feeds_[name];
-    return GetInput(idx);
+    int position = std::distance(input_names_.begin(), element);
+    return GetInput(position);
   }
 }
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 3d8dc2f06a..7226f4767d 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -74,8 +74,8 @@ class LITE_API Predictor {
   // get input by name.
   lite::Tensor* GetInputByName(const std::string& name);
   // get inputnames and get outputnames.
-  std::vector<std::string> GetInputNames();
-  std::vector<std::string> GetOutputNames();
+  const std::vector<std::string>& GetInputNames();
+  const std::vector<std::string>& GetOutputNames();
   void PrepareFeedFetch();
 
   // Get offset-th col of fetch results.
@@ -107,9 +107,8 @@ class LITE_API Predictor {
   const Scope* exec_scope_;
   std::unique_ptr program_;
   bool program_generated_{false};
-  std::map input_names_;
-  std::map idx2feeds_;
-  std::map output_names_;
+  std::vector<std::string> input_names_;
+  std::vector<std::string> output_names_;
 };
 
 /*
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index b4fb3828f3..62984ea476 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -37,8 +37,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
   std::string GetVersion() const override;
 
   // get inputs names and get outputs names
-  std::vector<std::string> GetInputNames() override;
-  std::vector<std::string> GetOutputNames() override;
+  const std::vector<std::string> &GetInputNames() override;
+  const std::vector<std::string> &GetOutputNames() override;
 
   std::unique_ptr<lite_api::Tensor> GetTensor(
       const std::string &name) const override;
@@ -76,11 +76,11 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 }
 
-std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
+const std::vector<std::string> &CxxPaddleApiImpl::GetInputNames() {
   return raw_predictor_.GetInputNames();
 }
 
-std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
+const std::vector<std::string> &CxxPaddleApiImpl::GetOutputNames() {
   return raw_predictor_.GetOutputNames();
 }
 
diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc
index 12963285e4..d28081c515 100644
--- a/lite/api/light_api.cc
+++ b/lite/api/light_api.cc
#include "lite/api/light_api.h" +#include namespace paddle { namespace lite { @@ -56,16 +57,17 @@ Tensor* LightPredictor::GetInput(size_t offset) { // get input by name Tensor* LightPredictor::GetInputByName(const std::string& name) { - if (idx2feeds_.find(name) == idx2feeds_.end()) { + auto element = std::find(input_names_.begin(), input_names_.end(), name); + if (element == input_names_.end()) { LOG(ERROR) << "Model do not have input named with: [" << name << "], model's inputs include:"; for (int i = 0; i < input_names_.size(); i++) { LOG(ERROR) << "[" << input_names_[i] << "]"; } - return NULL; + return nullptr; } else { - int idx = idx2feeds_[name]; - return GetInput(idx); + int position = std::distance(input_names_.begin(), element); + return GetInput(position); } } @@ -79,35 +81,36 @@ const Tensor* LightPredictor::GetOutput(size_t offset) { return out_var->GetMutable(); } // get inputs names -std::vector LightPredictor::GetInputNames() { - std::vector input_names; - for (auto& item : input_names_) { - input_names.push_back(item.second); - } - return input_names; +const std::vector& LightPredictor::GetInputNames() { + return input_names_; } // get outputnames -std::vector LightPredictor::GetOutputNames() { - std::vector output_names; - for (auto& item : output_names_) { - output_names.push_back(item.second); - } - return output_names; +const std::vector& LightPredictor::GetOutputNames() { + return output_names_; } // append the names of inputs and outputs into input_names_ and output_names_ void LightPredictor::PrepareFeedFetch() { auto current_block = cpp_program_desc_.GetBlock(0); + std::vector feeds; + std::vector fetchs; for (int i = 0; i < current_block->OpsSize(); i++) { auto op = current_block->GetOp(i); if (op->Type() == "feed") { - int idx = op->GetAttr("col"); - input_names_[idx] = op->Output("Out").front(); - idx2feeds_[op->Output("Out").front()] = idx; + feeds.push_back(op); } else if (op->Type() == "fetch") { - int idx = op->GetAttr("col"); - output_names_[idx] = op->Input("X").front(); + fetchs.push_back(op); } } + input_names_.resize(feeds.size()); + output_names_.resize(fetchs.size()); + for (int i = 0; i < feeds.size(); i++) { + input_names_[feeds[i]->GetAttr("col")] = + feeds[i]->Output("Out").front(); + } + for (int i = 0; i < fetchs.size(); i++) { + output_names_[fetchs[i]->GetAttr("col")] = + fetchs[i]->Input("X").front(); + } } void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 0705e0aba4..9d69cce441 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -64,8 +64,8 @@ class LITE_API LightPredictor { } // get inputnames and get outputnames. 
-  std::vector<std::string> GetInputNames();
-  std::vector<std::string> GetOutputNames();
+  const std::vector<std::string>& GetInputNames();
+  const std::vector<std::string>& GetOutputNames();
   void PrepareFeedFetch();
 
  private:
@@ -82,9 +82,8 @@ class LITE_API LightPredictor {
   std::shared_ptr scope_;
   std::unique_ptr program_;
   cpp::ProgramDesc cpp_program_desc_;
-  std::map input_names_;
-  std::map idx2feeds_;
-  std::map output_names_;
+  std::vector<std::string> input_names_;
+  std::vector<std::string> output_names_;
 };
 
 }  // namespace lite
diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc
index 90e1397d83..70ab8ac0c0 100644
--- a/lite/api/light_api_impl.cc
+++ b/lite/api/light_api_impl.cc
@@ -32,8 +32,8 @@ class LightPredictorImpl : public PaddlePredictor {
   void Run() override;
 
   std::string GetVersion() const override;
 
-  std::vector<std::string> GetInputNames() override;
-  std::vector<std::string> GetOutputNames() override;
+  const std::vector<std::string>& GetInputNames() override;
+  const std::vector<std::string>& GetOutputNames() override;
 
   std::unique_ptr<Tensor> GetTensor(
       const std::string& name) const override;
@@ -78,11 +78,11 @@ std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName(
       new Tensor(raw_predictor_->GetInputByName(name)));
 }
 
-std::vector<std::string> LightPredictorImpl::GetInputNames() {
+const std::vector<std::string>& LightPredictorImpl::GetInputNames() {
   return raw_predictor_->GetInputNames();
 }
 
-std::vector<std::string> LightPredictorImpl::GetOutputNames() {
+const std::vector<std::string>& LightPredictorImpl::GetOutputNames() {
   return raw_predictor_->GetOutputNames();
 }
 
diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc
index 418d97e9e8..d2bbc295ad 100644
--- a/lite/api/light_api_test.cc
+++ b/lite/api/light_api_test.cc
@@ -36,12 +36,14 @@ TEST(LightAPI, load) {
     data[i] = i;
   }
 
-  std::vector<std::string> inputs = predictor.GetInputNames();
+  predictor.PrepareFeedFetch();
+  const std::vector<std::string>& inputs = predictor.GetInputNames();
+  LOG(INFO) << "input size: " << inputs.size();
   for (int i = 0; i < inputs.size(); i++) {
     LOG(INFO) << "inputnames: " << inputs[i];
   }
 
-  std::vector<std::string> outputs = predictor.GetOutputNames();
+  const std::vector<std::string>& outputs = predictor.GetOutputNames();
   for (int i = 0; i < outputs.size(); i++) {
     LOG(INFO) << "outputnames: " << outputs[i];
   }
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index 545ae03f67..d7e3c014b0 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -75,9 +75,9 @@ class LITE_API PaddlePredictor {
   virtual std::string GetVersion() const = 0;
 
   // Get input names
-  virtual std::vector<std::string> GetInputNames() = 0;
+  virtual const std::vector<std::string>& GetInputNames() = 0;
   // Get output names
-  virtual std::vector<std::string> GetOutputNames() = 0;
+  virtual const std::vector<std::string>& GetOutputNames() = 0;
 
   // Get Input by name
   virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc
index 63142d4981..443a05d992 100644
--- a/lite/api/paddle_api_test.cc
+++ b/lite/api/paddle_api_test.cc
@@ -37,12 +37,12 @@ TEST(CxxApi, run) {
   LOG(INFO) << "Version: " << predictor->GetVersion();
 
-  std::vector<std::string> inputs = predictor->GetInputNames();
+  auto& inputs = predictor->GetInputNames();
   LOG(INFO) << "input size: " << inputs.size();
   for (int i = 0; i < inputs.size(); i++) {
     LOG(INFO) << "inputnames: " << inputs[i];
  }
-  std::vector<std::string> outputs = predictor->GetOutputNames();
+  auto& outputs = predictor->GetOutputNames();
   for (int i = 0; i < outputs.size(); i++) {
     LOG(INFO) << "outputnames: " << outputs[i];
   }
@@ -76,14 +76,14 @@ TEST(LightApi, run) {
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
-  std::vector<std::string> inputs = predictor->GetInputNames();
+  auto& inputs = predictor->GetInputNames();
   LOG(INFO) << "input size: " << inputs.size();
   for (int i = 0; i < inputs.size(); i++) {
-    LOG(INFO) << "inputnames: " << inputs[i];
+    LOG(INFO) << "inputnames: " << inputs.at(i);
   }
-  std::vector<std::string> outputs = predictor->GetOutputNames();
+  auto& outputs = predictor->GetOutputNames();
   for (int i = 0; i < outputs.size(); i++) {
-    LOG(INFO) << "outputnames: " << outputs[i];
+    LOG(INFO) << "outputnames: " << outputs.at(i);
   }
 
   LOG(INFO) << "Version: " << predictor->GetVersion();
 
--
GitLab