未验证 提交 8591aaec 编写于 作者: H huzhiqiang 提交者: GitHub

Fix codestyle of GetInputName&GetOutputName (#2185)

* add shell file to automatically build and collect publish result test=develop

* modify codestyle of getInputNames test=develop

* test=develop

* rm publish.sh

* remove copy of func param

* test=develop

* test=devcelop

* test=develop

* test=develop

* const & test=develop

* modify variable defination test=develop

* test=develop

* test=develop

* test=develop

* test=develop
上级 4ac51a6b
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/api/cxx_api.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
......@@ -52,35 +53,36 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
}
// get inputs names
std::vector<std::string> Predictor::GetInputNames() {
std::vector<std::string> input_names;
for (auto &item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
const std::vector<std::string> &Predictor::GetInputNames() {
return input_names_;
}
// get outputnames
std::vector<std::string> Predictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto &item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
const std::vector<std::string> &Predictor::GetOutputNames() {
return output_names_;
}
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
std::vector<cpp::OpDesc *> feeds;
std::vector<cpp::OpDesc *> fetchs;
for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col");
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
feeds.push_back(op);
} else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col");
output_names_[idx] = op->Input("X").front();
fetchs.push_back(op);
}
}
input_names_.resize(feeds.size());
output_names_.resize(fetchs.size());
for (int i = 0; i < feeds.size(); i++) {
input_names_[feeds[i]->GetAttr<int>("col")] =
feeds[i]->Output("Out").front();
}
for (int i = 0; i < fetchs.size(); i++) {
output_names_[fetchs[i]->GetAttr<int>("col")] =
fetchs[i]->Input("X").front();
}
}
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
......@@ -189,16 +191,17 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
}
// get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) {
auto element = std::find(input_names_.begin(), input_names_.end(), name);
if (element == input_names_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]";
}
return NULL;
return nullptr;
} else {
int idx = idx2feeds_[name];
return GetInput(idx);
int position = std::distance(input_names_.begin(), element);
return GetInput(position);
}
}
......
......@@ -74,8 +74,8 @@ class LITE_API Predictor {
// get input by name.
lite::Tensor* GetInputByName(const std::string& name);
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
const std::vector<std::string>& GetInputNames();
const std::vector<std::string>& GetOutputNames();
void PrepareFeedFetch();
// Get offset-th col of fetch results.
......@@ -107,9 +107,8 @@ class LITE_API Predictor {
const Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::map<size_t, std::string> input_names_;
std::map<std::string, size_t> idx2feeds_;
std::map<size_t, std::string> output_names_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
};
/*
......
......@@ -37,8 +37,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::string GetVersion() const override;
// get inputs names and get outputs names
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
const std::vector<std::string> &GetInputNames() override;
const std::vector<std::string> &GetOutputNames() override;
std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string &name) const override;
......@@ -76,11 +76,11 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
}
std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
const std::vector<std::string> &CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
const std::vector<std::string> &CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames();
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/api/light_api.h"
#include <algorithm>
namespace paddle {
namespace lite {
......@@ -56,16 +57,17 @@ Tensor* LightPredictor::GetInput(size_t offset) {
// get input by name
Tensor* LightPredictor::GetInputByName(const std::string& name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) {
auto element = std::find(input_names_.begin(), input_names_.end(), name);
if (element == input_names_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]";
}
return NULL;
return nullptr;
} else {
int idx = idx2feeds_[name];
return GetInput(idx);
int position = std::distance(input_names_.begin(), element);
return GetInput(position);
}
}
......@@ -79,35 +81,36 @@ const Tensor* LightPredictor::GetOutput(size_t offset) {
return out_var->GetMutable<lite::Tensor>();
}
// get inputs names
std::vector<std::string> LightPredictor::GetInputNames() {
std::vector<std::string> input_names;
for (auto& item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
const std::vector<std::string>& LightPredictor::GetInputNames() {
return input_names_;
}
// get outputnames
std::vector<std::string> LightPredictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto& item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
const std::vector<std::string>& LightPredictor::GetOutputNames() {
return output_names_;
}
// append the names of inputs and outputs into input_names_ and output_names_
void LightPredictor::PrepareFeedFetch() {
auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
std::vector<cpp::OpDesc*> feeds;
std::vector<cpp::OpDesc*> fetchs;
for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col");
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
feeds.push_back(op);
} else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col");
output_names_[idx] = op->Input("X").front();
fetchs.push_back(op);
}
}
input_names_.resize(feeds.size());
output_names_.resize(fetchs.size());
for (int i = 0; i < feeds.size(); i++) {
input_names_[feeds[i]->GetAttr<int>("col")] =
feeds[i]->Output("Out").front();
}
for (int i = 0; i < fetchs.size(); i++) {
output_names_[fetchs[i]->GetAttr<int>("col")] =
fetchs[i]->Input("X").front();
}
}
void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
......
......@@ -64,8 +64,8 @@ class LITE_API LightPredictor {
}
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
const std::vector<std::string>& GetInputNames();
const std::vector<std::string>& GetOutputNames();
void PrepareFeedFetch();
private:
......@@ -82,9 +82,8 @@ class LITE_API LightPredictor {
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_;
std::map<size_t, std::string> input_names_;
std::map<std::string, size_t> idx2feeds_;
std::map<size_t, std::string> output_names_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
};
} // namespace lite
......
......@@ -32,8 +32,8 @@ class LightPredictorImpl : public PaddlePredictor {
void Run() override;
std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
const std::vector<std::string>& GetInputNames() override;
const std::vector<std::string>& GetOutputNames() override;
std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const override;
......@@ -78,11 +78,11 @@ std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName(
new Tensor(raw_predictor_->GetInputByName(name)));
}
std::vector<std::string> LightPredictorImpl::GetInputNames() {
const std::vector<std::string>& LightPredictorImpl::GetInputNames() {
return raw_predictor_->GetInputNames();
}
std::vector<std::string> LightPredictorImpl::GetOutputNames() {
const std::vector<std::string>& LightPredictorImpl::GetOutputNames() {
return raw_predictor_->GetOutputNames();
}
......
......@@ -36,12 +36,14 @@ TEST(LightAPI, load) {
data[i] = i;
}
std::vector<std::string> inputs = predictor.GetInputNames();
predictor.PrepareFeedFetch();
const std::vector<std::string>& inputs = predictor.GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
}
std::vector<std::string> outputs = predictor.GetOutputNames();
const std::vector<std::string>& outputs = predictor.GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
}
......
......@@ -75,9 +75,9 @@ class LITE_API PaddlePredictor {
virtual std::string GetVersion() const = 0;
// Get input names
virtual std::vector<std::string> GetInputNames() = 0;
virtual const std::vector<std::string>& GetInputNames() = 0;
// Get output names
virtual std::vector<std::string> GetOutputNames() = 0;
virtual const std::vector<std::string>& GetOutputNames() = 0;
// Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
......
......@@ -37,12 +37,12 @@ TEST(CxxApi, run) {
LOG(INFO) << "Version: " << predictor->GetVersion();
std::vector<std::string> inputs = predictor->GetInputNames();
auto& inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
}
std::vector<std::string> outputs = predictor->GetOutputNames();
auto& outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
}
......@@ -76,14 +76,14 @@ TEST(LightApi, run) {
auto predictor = lite_api::CreatePaddlePredictor(config);
std::vector<std::string> inputs = predictor->GetInputNames();
auto& inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
LOG(INFO) << "inputnames: " << inputs.at(i);
}
std::vector<std::string> outputs = predictor->GetOutputNames();
auto& outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
LOG(INFO) << "outputnames: " << outputs.at(i);
}
LOG(INFO) << "Version: " << predictor->GetVersion();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册