提交 12c64f3f 编写于 作者: H huzhiqiang 提交者: GitHub

[develop API] add `GetMutableTensor` API into `cxx_predictor` (#3099)

上级 a08a7237
......@@ -151,6 +151,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
// get outputnames
std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// get param names
std::vector<std::string> Predictor::GetParamNames() {
return exec_scope_->AttributeVarNames();
}
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
if (!program_) {
......@@ -346,9 +351,16 @@ void Predictor::GenRuntimeProgram() {
const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return &var->Get<lite::Tensor>();
}
lite::Tensor *Predictor::GetMutableTensor(const std::string &name) {
auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return var->GetMutable<lite::Tensor>();
}
// get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) {
auto element = std::find(input_names_.begin(), input_names_.end(), name);
......
......@@ -85,6 +85,9 @@ class LITE_API Predictor {
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
// get param names
std::vector<std::string> GetParamNames();
void PrepareFeedFetch();
// Get offset-th col of fetch results.
......@@ -92,6 +95,9 @@ class LITE_API Predictor {
std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const;
// get a mutable tensor according to its name
lite::Tensor* GetMutableTensor(const std::string& name);
// get a const tensor according to its name
const lite::Tensor* GetTensor(const std::string& name) const;
const RuntimeProgram& runtime_program() const;
......@@ -142,9 +148,15 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
// get inputs names and get outputs names
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
// get param names
std::vector<std::string> GetParamNames() override;
// get tensor according to tensor's name
std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string& name) const override;
// get a mutable tensor according to tensor's name
std::unique_ptr<lite_api::Tensor> GetMutableTensor(
const std::string& name) override;
// Get InputTebsor by name
std::unique_ptr<lite_api::Tensor> GetInputByName(
......
......@@ -97,6 +97,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames();
}
......@@ -123,6 +127,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor(
const std::string &name) {
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetMutableTensor(name)));
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
const std::string &name) {
return std::unique_ptr<lite_api::Tensor>(
......
......@@ -167,6 +167,20 @@ lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); }
void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); }
std::unique_ptr<Tensor> PaddlePredictor::GetMutableTensor(
const std::string &name) {
LOG(FATAL)
<< "The GetMutableTensor API is only supported by CxxConfig predictor.";
return nullptr;
}
std::vector<std::string> PaddlePredictor::GetParamNames() {
std::vector<std::string> null_result = {};
LOG(FATAL)
<< "The GetParamNames API is only supported by CxxConfig predictor.";
return null_result;
}
void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir,
LiteModelType model_type,
bool record_info) {
......
......@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor {
virtual std::vector<std::string> GetInputNames() = 0;
// Get output names
virtual std::vector<std::string> GetOutputNames() = 0;
// Get output names
virtual std::vector<std::string> GetParamNames();
// Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
......@@ -93,6 +95,9 @@ class LITE_API PaddlePredictor {
/// Get a readonly tensor, return null if no one called `name` exists.
virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0;
/// Get a mutable tensor, return null if on one called `name` exists
/// internal infereces API, not recommanded.
virtual std::unique_ptr<Tensor> GetMutableTensor(const std::string& name);
/// Persist the optimized model to disk. This API is only supported by
/// CxxConfig, and the persisted model can be reused for MobileConfig.
......
......@@ -60,6 +60,29 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
return nullptr;
}
// AttributeVarNames will get persistive attribute names stored in parent scope
std::vector<std::string> Scope::AttributeVarNames() const {
std::vector<std::string> resulted_keys;
const Scope *cur_scope = this;
while (cur_scope->parent()) {
cur_scope = cur_scope->parent();
auto keys = cur_scope->LocalVarNames();
resulted_keys.insert(resulted_keys.end(), keys.begin(), keys.end());
}
// remove feed and fetch
std::vector<std::string> skiped_vars = {"feed", "fetch"};
for (int i = 0; i < skiped_vars.size(); i++) {
auto iter =
std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]);
while (iter != resulted_keys.end()) {
resulted_keys.erase(iter);
iter =
std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]);
}
}
return resulted_keys;
}
std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys;
for (const auto &item : vars_) {
......
......@@ -45,6 +45,8 @@ class Scope final {
const Scope* parent() const { return parent_; }
// Get attribute params stored in parent scopes.
std::vector<std::string> AttributeVarNames() const;
// Following the legacy scope interface.
std::vector<std::string> LocalVarNames() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册