未验证 提交 c6e411f7 编写于 作者: H huzhiqiang 提交者: GitHub

[develop API] add `GetMutableTensor` API into `cxx_predictor` (#3099)

上级 263a4ead
...@@ -151,6 +151,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; } ...@@ -151,6 +151,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
// get outputnames // get outputnames
std::vector<std::string> Predictor::GetOutputNames() { return output_names_; } std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// get param names
std::vector<std::string> Predictor::GetParamNames() {
return exec_scope_->AttributeVarNames();
}
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() { void Predictor::PrepareFeedFetch() {
if (!program_) { if (!program_) {
...@@ -346,9 +351,16 @@ void Predictor::GenRuntimeProgram() { ...@@ -346,9 +351,16 @@ void Predictor::GenRuntimeProgram() {
const lite::Tensor *Predictor::GetTensor(const std::string &name) const { const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
auto *var = exec_scope_->FindVar(name); auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return &var->Get<lite::Tensor>(); return &var->Get<lite::Tensor>();
} }
lite::Tensor *Predictor::GetMutableTensor(const std::string &name) {
auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return var->GetMutable<lite::Tensor>();
}
// get input by name // get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) { lite::Tensor *Predictor::GetInputByName(const std::string &name) {
auto element = std::find(input_names_.begin(), input_names_.end(), name); auto element = std::find(input_names_.begin(), input_names_.end(), name);
......
...@@ -85,6 +85,9 @@ class LITE_API Predictor { ...@@ -85,6 +85,9 @@ class LITE_API Predictor {
// get inputnames and get outputnames. // get inputnames and get outputnames.
std::vector<std::string> GetInputNames(); std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames(); std::vector<std::string> GetOutputNames();
// get param names
std::vector<std::string> GetParamNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
// Get offset-th col of fetch results. // Get offset-th col of fetch results.
...@@ -92,6 +95,9 @@ class LITE_API Predictor { ...@@ -92,6 +95,9 @@ class LITE_API Predictor {
std::vector<const lite::Tensor*> GetOutputs() const; std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const; const cpp::ProgramDesc& program_desc() const;
// get a mutable tensor according to its name
lite::Tensor* GetMutableTensor(const std::string& name);
// get a const tensor according to its name
const lite::Tensor* GetTensor(const std::string& name) const; const lite::Tensor* GetTensor(const std::string& name) const;
const RuntimeProgram& runtime_program() const; const RuntimeProgram& runtime_program() const;
...@@ -142,9 +148,15 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -142,9 +148,15 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
// get inputs names and get outputs names // get inputs names and get outputs names
std::vector<std::string> GetInputNames() override; std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override; std::vector<std::string> GetOutputNames() override;
// get param names
std::vector<std::string> GetParamNames() override;
// get tensor according to tensor's name
std::unique_ptr<const lite_api::Tensor> GetTensor( std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string& name) const override; const std::string& name) const override;
// get a mutable tensor according to tensor's name
std::unique_ptr<lite_api::Tensor> GetMutableTensor(
const std::string& name) override;
// Get InputTebsor by name // Get InputTebsor by name
std::unique_ptr<lite_api::Tensor> GetInputByName( std::unique_ptr<lite_api::Tensor> GetInputByName(
......
...@@ -97,6 +97,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() { ...@@ -97,6 +97,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames(); return raw_predictor_.GetInputNames();
} }
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames(); return raw_predictor_.GetOutputNames();
} }
...@@ -123,6 +127,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor( ...@@ -123,6 +127,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor(
const std::string &name) {
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetMutableTensor(name)));
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName( std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
const std::string &name) { const std::string &name) {
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
......
...@@ -167,6 +167,20 @@ lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); } ...@@ -167,6 +167,20 @@ lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); }
void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); } void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); }
std::unique_ptr<Tensor> PaddlePredictor::GetMutableTensor(
const std::string &name) {
LOG(FATAL)
<< "The GetMutableTensor API is only supported by CxxConfig predictor.";
return nullptr;
}
std::vector<std::string> PaddlePredictor::GetParamNames() {
std::vector<std::string> null_result = {};
LOG(FATAL)
<< "The GetParamNames API is only supported by CxxConfig predictor.";
return null_result;
}
void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir, void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir,
LiteModelType model_type, LiteModelType model_type,
bool record_info) { bool record_info) {
......
...@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor { ...@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor {
virtual std::vector<std::string> GetInputNames() = 0; virtual std::vector<std::string> GetInputNames() = 0;
// Get output names // Get output names
virtual std::vector<std::string> GetOutputNames() = 0; virtual std::vector<std::string> GetOutputNames() = 0;
// Get output names
virtual std::vector<std::string> GetParamNames();
// Get Input by name // Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0; virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
...@@ -93,6 +95,9 @@ class LITE_API PaddlePredictor { ...@@ -93,6 +95,9 @@ class LITE_API PaddlePredictor {
/// Get a readonly tensor, return null if no one called `name` exists. /// Get a readonly tensor, return null if no one called `name` exists.
virtual std::unique_ptr<const Tensor> GetTensor( virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0; const std::string& name) const = 0;
/// Get a mutable tensor, return null if on one called `name` exists
/// internal infereces API, not recommanded.
virtual std::unique_ptr<Tensor> GetMutableTensor(const std::string& name);
/// Persist the optimized model to disk. This API is only supported by /// Persist the optimized model to disk. This API is only supported by
/// CxxConfig, and the persisted model can be reused for MobileConfig. /// CxxConfig, and the persisted model can be reused for MobileConfig.
......
...@@ -60,6 +60,29 @@ Variable *Scope::FindLocalVar(const std::string &name) const { ...@@ -60,6 +60,29 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
return nullptr; return nullptr;
} }
// AttributeVarNames will get persistive attribute names stored in parent scope
std::vector<std::string> Scope::AttributeVarNames() const {
std::vector<std::string> resulted_keys;
const Scope *cur_scope = this;
while (cur_scope->parent()) {
cur_scope = cur_scope->parent();
auto keys = cur_scope->LocalVarNames();
resulted_keys.insert(resulted_keys.end(), keys.begin(), keys.end());
}
// remove feed and fetch
std::vector<std::string> skiped_vars = {"feed", "fetch"};
for (int i = 0; i < skiped_vars.size(); i++) {
auto iter =
std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]);
while (iter != resulted_keys.end()) {
resulted_keys.erase(iter);
iter =
std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]);
}
}
return resulted_keys;
}
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys; std::vector<std::string> keys;
for (const auto &item : vars_) { for (const auto &item : vars_) {
......
...@@ -45,6 +45,8 @@ class Scope final { ...@@ -45,6 +45,8 @@ class Scope final {
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
// Get attribute params stored in parent scopes.
std::vector<std::string> AttributeVarNames() const;
// Following the legacy scope interface. // Following the legacy scope interface.
std::vector<std::string> LocalVarNames() const; std::vector<std::string> LocalVarNames() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册