提交 1cd077dc 编写于 作者: H huzhiqiang 提交者: GitHub

add GetInputNames 、 GetOutPutNames 、 GetInputByName and GetTensor method (#2154)


* add GetInputNames and GetOutPutNames and GetInputByName method test=develop
上级 2f035fec
......@@ -64,6 +64,38 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
return &feed_list->at(offset);
}
// get inputs names
std::vector<std::string> Predictor::GetInputNames() {
std::vector<std::string> input_names;
for (auto &item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
}
// get outputnames
std::vector<std::string> Predictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto &item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
}
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col");
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
} else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col");
output_names_[idx] = op->Input("X").front();
}
}
}
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
auto *_fetch_list = exec_scope_->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
......@@ -162,6 +194,20 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
auto *var = exec_scope_->FindVar(name);
return &var->Get<lite::Tensor>();
}
// get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]";
}
return NULL;
} else {
int idx = idx2feeds_[name];
return GetInput(idx);
}
}
#ifdef LITE_WITH_TRAIN
void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
......@@ -72,6 +73,12 @@ class LITE_API Predictor {
// Get offset-th col of feed inputs.
lite::Tensor* GetInput(size_t offset);
// get input by name.
lite::Tensor* GetInputByName(const std::string& name);
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
void PrepareFeedFetch();
// Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const;
......@@ -102,6 +109,9 @@ class LITE_API Predictor {
const Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::map<size_t, std::string> input_names_;
std::map<std::string, size_t> idx2feeds_;
std::map<size_t, std::string> output_names_;
};
/*
......
......@@ -36,9 +36,17 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::string GetVersion() const override;
// get inputs names and get outputs names
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string &name) const override;
// Get InputTebsor by name
std::unique_ptr<lite_api::Tensor> GetInputByName(
const std::string &name) override;
void SaveOptimizedModel(const std::string &model_dir,
lite_api::LiteModelType model_type =
lite_api::LiteModelType::kProtobuf) override;
......@@ -56,6 +64,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
auto places = config.valid_places();
places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
raw_predictor_.Build(config, places);
raw_predictor_.PrepareFeedFetch();
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......@@ -69,6 +78,14 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
}
std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames();
}
void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); }
std::string CxxPaddleApiImpl::GetVersion() const { return version(); }
......@@ -79,6 +96,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
const std::string &name) {
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetInputByName(name)));
}
void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir,
lite_api::LiteModelType model_type) {
raw_predictor_.SaveModel(model_dir, model_type);
......
......@@ -53,6 +53,21 @@ Tensor* LightPredictor::GetInput(size_t offset) {
return &feed_list->at(offset);
}
// get input by name
Tensor* LightPredictor::GetInputByName(const std::string& name) {
if (idx2feeds_.find(name) == idx2feeds_.end()) {
LOG(ERROR) << "Model do not have input named with: [" << name
<< "], model's inputs include:";
for (int i = 0; i < input_names_.size(); i++) {
LOG(ERROR) << "[" << input_names_[i] << "]";
}
return NULL;
} else {
int idx = idx2feeds_[name];
return GetInput(idx);
}
}
const Tensor* LightPredictor::GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
......@@ -60,6 +75,37 @@ const Tensor* LightPredictor::GetOutput(size_t offset) {
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
// get inputs names
std::vector<std::string> LightPredictor::GetInputNames() {
std::vector<std::string> input_names;
for (auto& item : input_names_) {
input_names.push_back(item.second);
}
return input_names;
}
// get outputnames
std::vector<std::string> LightPredictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto& item : output_names_) {
output_names.push_back(item.second);
}
return output_names;
}
// append the names of inputs and outputs into input_names_ and output_names_
void LightPredictor::PrepareFeedFetch() {
auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
for (int i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
if (op->Type() == "feed") {
int idx = op->GetAttr<int>("col");
input_names_[idx] = op->Output("Out").front();
idx2feeds_[op->Output("Out").front()] = idx;
} else if (op->Type() == "fetch") {
int idx = op->GetAttr<int>("col");
output_names_[idx] = op->Input("X").front();
}
}
}
void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
std::vector<Instruction> insts;
......
......@@ -18,6 +18,7 @@
*/
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
......@@ -52,7 +53,8 @@ class LITE_API LightPredictor {
// Get offset-th col of feed inputs.
Tensor* GetInput(size_t offset);
// get input by name.
Tensor* GetInputByName(const std::string& name);
// Get offset-th col of fetch outputs.
const Tensor* GetOutput(size_t offset);
......@@ -61,6 +63,11 @@ class LITE_API LightPredictor {
return &var->Get<lite::Tensor>();
}
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
void PrepareFeedFetch();
private:
void Build(
const std::string& model_dir,
......@@ -75,6 +82,9 @@ class LITE_API LightPredictor {
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_;
std::map<size_t, std::string> input_names_;
std::map<std::string, size_t> idx2feeds_;
std::map<size_t, std::string> output_names_;
};
} // namespace lite
......
......@@ -32,9 +32,13 @@ class LightPredictorImpl : public PaddlePredictor {
void Run() override;
std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const override;
// Get InputTebsor by name
std::unique_ptr<Tensor> GetInputByName(const std::string& name) override;
void Init(const MobileConfig& config);
......@@ -49,6 +53,7 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
config.param_buffer(),
config.model_from_memory(),
LiteModelType::kNaiveBuffer));
raw_predictor_->PrepareFeedFetch();
}
std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
......@@ -68,6 +73,19 @@ std::unique_ptr<const Tensor> LightPredictorImpl::GetTensor(
return std::unique_ptr<const Tensor>(
new Tensor(raw_predictor_->GetTensor(name)));
}
std::unique_ptr<Tensor> LightPredictorImpl::GetInputByName(
const std::string& name) {
return std::unique_ptr<Tensor>(
new Tensor(raw_predictor_->GetInputByName(name)));
}
std::vector<std::string> LightPredictorImpl::GetInputNames() {
return raw_predictor_->GetInputNames();
}
std::vector<std::string> LightPredictorImpl::GetOutputNames() {
return raw_predictor_->GetOutputNames();
}
template <>
std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(
......
......@@ -36,6 +36,17 @@ TEST(LightAPI, load) {
data[i] = i;
}
predictor.PrepareFeedFetch();
std::vector<std::string> inputs = predictor.GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
}
std::vector<std::string> outputs = predictor.GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
}
predictor.Run();
const auto* output = predictor.GetOutput(0);
......
......@@ -74,6 +74,14 @@ class LITE_API PaddlePredictor {
virtual std::string GetVersion() const = 0;
// Get input names
virtual std::vector<std::string> GetInputNames() = 0;
// Get output names
virtual std::vector<std::string> GetOutputNames() = 0;
// Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
/// Get a readonly tensor, return null if no one called `name` exists.
virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0;
......
......@@ -38,7 +38,16 @@ TEST(CxxApi, run) {
LOG(INFO) << "Version: " << predictor->GetVersion();
auto input_tensor = predictor->GetInput(0);
std::vector<std::string> inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
}
std::vector<std::string> outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
}
auto input_tensor = predictor->GetInputByName(inputs[0]);
input_tensor->Resize(std::vector<int64_t>({100, 100}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
......@@ -47,7 +56,7 @@ TEST(CxxApi, run) {
predictor->Run();
auto output = predictor->GetOutput(0);
auto output = predictor->GetTensor(outputs[0]);
auto* out = output->data<float>();
LOG(INFO) << out[0];
LOG(INFO) << out[1];
......@@ -68,6 +77,16 @@ TEST(LightApi, run) {
auto predictor = lite_api::CreatePaddlePredictor(config);
std::vector<std::string> inputs = predictor->GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
LOG(INFO) << "inputnames: " << inputs[i];
}
std::vector<std::string> outputs = predictor->GetOutputNames();
for (int i = 0; i < outputs.size(); i++) {
LOG(INFO) << "outputnames: " << outputs[i];
}
LOG(INFO) << "Version: " << predictor->GetVersion();
auto input_tensor = predictor->GetInput(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册