未验证 提交 8591aaec 编写于 作者: H huzhiqiang 提交者: GitHub

Fix codestyle of GetInputName&GetOutputName (#2185)

* add shell file to automatically build and collect publish result test=develop

* modify codestyle of getInputNames test=develop

* test=develop

* rm publish.sh

* remove copy of func param

* test=develop

* test=devcelop

* test=develop

* test=develop

* const & test=develop

* modify variable defination test=develop

* test=develop

* test=develop

* test=develop

* test=develop
上级 4ac51a6b
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/api/cxx_api.h" #include "lite/api/cxx_api.h"
#include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -52,34 +53,35 @@ lite::Tensor *Predictor::GetInput(size_t offset) { ...@@ -52,34 +53,35 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
} }
// get inputs names // get inputs names
std::vector<std::string> Predictor::GetInputNames() { const std::vector<std::string> &Predictor::GetInputNames() {
std::vector<std::string> input_names; return input_names_;
for (auto &item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
} }
// get outputnames // get outputnames
std::vector<std::string> Predictor::GetOutputNames() { const std::vector<std::string> &Predictor::GetOutputNames() {
std::vector<std::string> output_names; return output_names_;
for (auto &item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
} }
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() { void Predictor::PrepareFeedFetch() {
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0); auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
std::vector<cpp::OpDesc *> feeds;
std::vector<cpp::OpDesc *> fetchs;
for (int i = 0; i < current_block->OpsSize(); i++) { for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i); auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") { if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col"); feeds.push_back(op);
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col"); fetchs.push_back(op);
output_names_[idx] = op->Input("X").front(); }
}
input_names_.resize(feeds.size());
output_names_.resize(fetchs.size());
for (int i = 0; i < feeds.size(); i++) {
input_names_[feeds[i]->GetAttr<int>("col")] =
feeds[i]->Output("Out").front();
} }
for (int i = 0; i < fetchs.size(); i++) {
output_names_[fetchs[i]->GetAttr<int>("col")] =
fetchs[i]->Input("X").front();
} }
} }
...@@ -189,16 +191,17 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const { ...@@ -189,16 +191,17 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
} }
// get input by name // get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) { lite::Tensor *Predictor::GetInputByName(const std::string &name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) { auto element = std::find(input_names_.begin(), input_names_.end(), name);
if (element == input_names_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:"; << "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) { for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]"; LOG(ERROR) << "[" << input_names_[i] << "]";
} }
return NULL; return nullptr;
} else { } else {
int idx = idx2feeds_[name]; int position = std::distance(input_names_.begin(), element);
return GetInput(idx); return GetInput(position);
} }
} }
......
...@@ -74,8 +74,8 @@ class LITE_API Predictor { ...@@ -74,8 +74,8 @@ class LITE_API Predictor {
// get input by name. // get input by name.
lite::Tensor* GetInputByName(const std::string& name); lite::Tensor* GetInputByName(const std::string& name);
// get inputnames and get outputnames. // get inputnames and get outputnames.
std::vector<std::string> GetInputNames(); const std::vector<std::string>& GetInputNames();
std::vector<std::string> GetOutputNames(); const std::vector<std::string>& GetOutputNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
// Get offset-th col of fetch results. // Get offset-th col of fetch results.
...@@ -107,9 +107,8 @@ class LITE_API Predictor { ...@@ -107,9 +107,8 @@ class LITE_API Predictor {
const Scope* exec_scope_; const Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false}; bool program_generated_{false};
std::map<size_t, std::string> input_names_; std::vector<std::string> input_names_;
std::map<std::string, size_t> idx2feeds_; std::vector<std::string> output_names_;
std::map<size_t, std::string> output_names_;
}; };
/* /*
......
...@@ -37,8 +37,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -37,8 +37,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::string GetVersion() const override; std::string GetVersion() const override;
// get inputs names and get outputs names // get inputs names and get outputs names
std::vector<std::string> GetInputNames() override; const std::vector<std::string> &GetInputNames() override;
std::vector<std::string> GetOutputNames() override; const std::vector<std::string> &GetOutputNames() override;
std::unique_ptr<const lite_api::Tensor> GetTensor( std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string &name) const override; const std::string &name) const override;
...@@ -76,11 +76,11 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput( ...@@ -76,11 +76,11 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::vector<std::string> CxxPaddleApiImpl::GetInputNames() { const std::vector<std::string> &CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames(); return raw_predictor_.GetInputNames();
} }
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { const std::vector<std::string> &CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames(); return raw_predictor_.GetOutputNames();
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/api/light_api.h" #include "lite/api/light_api.h"
#include <algorithm>
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -56,16 +57,17 @@ Tensor* LightPredictor::GetInput(size_t offset) { ...@@ -56,16 +57,17 @@ Tensor* LightPredictor::GetInput(size_t offset) {
// get input by name // get input by name
Tensor* LightPredictor::GetInputByName(const std::string& name) { Tensor* LightPredictor::GetInputByName(const std::string& name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) { auto element = std::find(input_names_.begin(), input_names_.end(), name);
if (element == input_names_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:"; << "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) { for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]"; LOG(ERROR) << "[" << input_names_[i] << "]";
} }
return NULL; return nullptr;
} else { } else {
int idx = idx2feeds_[name]; int position = std::distance(input_names_.begin(), element);
return GetInput(idx); return GetInput(position);
} }
} }
...@@ -79,34 +81,35 @@ const Tensor* LightPredictor::GetOutput(size_t offset) { ...@@ -79,34 +81,35 @@ const Tensor* LightPredictor::GetOutput(size_t offset) {
return out_var->GetMutable<lite::Tensor>(); return out_var->GetMutable<lite::Tensor>();
} }
// get inputs names // get inputs names
std::vector<std::string> LightPredictor::GetInputNames() { const std::vector<std::string>& LightPredictor::GetInputNames() {
std::vector<std::string> input_names; return input_names_;
for (auto& item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
} }
// get outputnames // get outputnames
std::vector<std::string> LightPredictor::GetOutputNames() { const std::vector<std::string>& LightPredictor::GetOutputNames() {
std::vector<std::string> output_names; return output_names_;
for (auto& item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
} }
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void LightPredictor::PrepareFeedFetch() { void LightPredictor::PrepareFeedFetch() {
auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0); auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
std::vector<cpp::OpDesc*> feeds;
std::vector<cpp::OpDesc*> fetchs;
for (int i = 0; i < current_block->OpsSize(); i++) { for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i); auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") { if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col"); feeds.push_back(op);
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col"); fetchs.push_back(op);
output_names_[idx] = op->Input("X").front(); }
}
input_names_.resize(feeds.size());
output_names_.resize(fetchs.size());
for (int i = 0; i < feeds.size(); i++) {
input_names_[feeds[i]->GetAttr<int>("col")] =
feeds[i]->Output("Out").front();
} }
for (int i = 0; i < fetchs.size(); i++) {
output_names_[fetchs[i]->GetAttr<int>("col")] =
fetchs[i]->Input("X").front();
} }
} }
......
...@@ -64,8 +64,8 @@ class LITE_API LightPredictor { ...@@ -64,8 +64,8 @@ class LITE_API LightPredictor {
} }
// get inputnames and get outputnames. // get inputnames and get outputnames.
std::vector<std::string> GetInputNames(); const std::vector<std::string>& GetInputNames();
std::vector<std::string> GetOutputNames(); const std::vector<std::string>& GetOutputNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
private: private:
...@@ -82,9 +82,8 @@ class LITE_API LightPredictor { ...@@ -82,9 +82,8 @@ class LITE_API LightPredictor {
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_; cpp::ProgramDesc cpp_program_desc_;
std::map<size_t, std::string> input_names_; std::vector<std::string> input_names_;
std::map<std::string, size_t> idx2feeds_; std::vector<std::string> output_names_;
std::map<size_t, std::string> output_names_;
}; };
} // namespace lite } // namespace lite
......
...@@ -32,8 +32,8 @@ class LightPredictorImpl : public PaddlePredictor { ...@@ -32,8 +32,8 @@ class LightPredictorImpl : public PaddlePredictor {
void Run() override; void Run() override;
std::string GetVersion() const override; std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override; const std::vector<std::string>& GetInputNames() override;
std::vector<std::string> GetOutputNames() override; const std::vector<std::string>& GetOutputNames() override;
std::unique_ptr<const Tensor> GetTensor( std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const override; const std::string& name) const override;
...@@ -78,11 +78,11 @@ std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName( ...@@ -78,11 +78,11 @@ std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName(
new Tensor(raw_predictor_->GetInputByName(name))); new Tensor(raw_predictor_->GetInputByName(name)));
} }
std::vector<std::string> LightPredictorImpl::GetInputNames() { const std::vector<std::string>& LightPredictorImpl::GetInputNames() {
return raw_predictor_->GetInputNames(); return raw_predictor_->GetInputNames();
} }
std::vector<std::string> LightPredictorImpl::GetOutputNames() { const std::vector<std::string>& LightPredictorImpl::GetOutputNames() {
return raw_predictor_->GetOutputNames(); return raw_predictor_->GetOutputNames();
} }
......
...@@ -36,12 +36,14 @@ TEST(LightAPI, load) { ...@@ -36,12 +36,14 @@ TEST(LightAPI, load) {
data[i] = i; data[i] = i;
} }
std::vector<std::string> inputs = predictor.GetInputNames(); predictor.PrepareFeedFetch();
const std::vector<std::string>& inputs = predictor.GetInputNames();
LOG(INFO) << "input size: " << inputs.size(); LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i]; LOG(INFO) << "inputnames: " << inputs[i];
} }
std::vector<std::string> outputs = predictor.GetOutputNames(); const std::vector<std::string>& outputs = predictor.GetOutputNames();
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i]; LOG(INFO) << "outputnames: " << outputs[i];
} }
......
...@@ -75,9 +75,9 @@ class LITE_API PaddlePredictor { ...@@ -75,9 +75,9 @@ class LITE_API PaddlePredictor {
virtual std::string GetVersion() const = 0; virtual std::string GetVersion() const = 0;
// Get input names // Get input names
virtual std::vector<std::string> GetInputNames() = 0; virtual const std::vector<std::string>& GetInputNames() = 0;
// Get output names // Get output names
virtual std::vector<std::string> GetOutputNames() = 0; virtual const std::vector<std::string>& GetOutputNames() = 0;
// Get Input by name // Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0; virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
......
...@@ -37,12 +37,12 @@ TEST(CxxApi, run) { ...@@ -37,12 +37,12 @@ TEST(CxxApi, run) {
LOG(INFO) << "Version: " << predictor->GetVersion(); LOG(INFO) << "Version: " << predictor->GetVersion();
std::vector<std::string> inputs = predictor->GetInputNames(); auto& inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size(); LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i]; LOG(INFO) << "inputnames: " << inputs[i];
} }
std::vector<std::string> outputs = predictor->GetOutputNames(); auto& outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i]; LOG(INFO) << "outputnames: " << outputs[i];
} }
...@@ -76,14 +76,14 @@ TEST(LightApi, run) { ...@@ -76,14 +76,14 @@ TEST(LightApi, run) {
auto predictor = lite_api::CreatePaddlePredictor(config); auto predictor = lite_api::CreatePaddlePredictor(config);
std::vector<std::string> inputs = predictor->GetInputNames(); auto& inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size(); LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i]; LOG(INFO) << "inputnames: " << inputs.at(i);
} }
std::vector<std::string> outputs = predictor->GetOutputNames(); auto& outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i]; LOG(INFO) << "outputnames: " << outputs.at(i);
} }
LOG(INFO) << "Version: " << predictor->GetVersion(); LOG(INFO) << "Version: " << predictor->GetVersion();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册