diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index f123c2a9c7db336b6e94b0c1934fc2e284d50f67..5c89c24325e2aeff0f8b0ed7a5cd621f26318b8f 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -151,6 +151,11 @@ std::vector Predictor::GetInputNames() { return input_names_; } // get outputnames std::vector Predictor::GetOutputNames() { return output_names_; } +// get param names +std::vector Predictor::GetParamNames() { + return exec_scope_->AttributeVarNames(); +} + // append the names of inputs and outputs into input_names_ and output_names_ void Predictor::PrepareFeedFetch() { if (!program_) { @@ -346,9 +351,16 @@ void Predictor::GenRuntimeProgram() { const lite::Tensor *Predictor::GetTensor(const std::string &name) const { auto *var = exec_scope_->FindVar(name); + CHECK(var) << "no variable named with " << name << " in exec_scope"; return &var->Get(); } +lite::Tensor *Predictor::GetMutableTensor(const std::string &name) { + auto *var = exec_scope_->FindVar(name); + CHECK(var) << "no variable named with " << name << " in exec_scope"; + return var->GetMutable(); +} + // get input by name lite::Tensor *Predictor::GetInputByName(const std::string &name) { auto element = std::find(input_names_.begin(), input_names_.end(), name); diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 146556756af7e0b56ae38b5303e622c97dfe58af..cd542e87ed3bf4632bce141f019e974af6ef4308 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -85,6 +85,9 @@ class LITE_API Predictor { // get inputnames and get outputnames. std::vector GetInputNames(); std::vector GetOutputNames(); + // get param names + std::vector GetParamNames(); + void PrepareFeedFetch(); // Get offset-th col of fetch results. @@ -92,6 +95,9 @@ class LITE_API Predictor { std::vector GetOutputs() const; const cpp::ProgramDesc& program_desc() const; + // get a mutable tensor according to its name + lite::Tensor* GetMutableTensor(const std::string& name); + // get a const tensor according to its name const lite::Tensor* GetTensor(const std::string& name) const; const RuntimeProgram& runtime_program() const; @@ -142,9 +148,15 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { // get inputs names and get outputs names std::vector GetInputNames() override; std::vector GetOutputNames() override; + // get param names + std::vector GetParamNames() override; + // get tensor according to tensor's name std::unique_ptr GetTensor( const std::string& name) const override; + // get a mutable tensor according to tensor's name + std::unique_ptr GetMutableTensor( + const std::string& name) override; // Get InputTebsor by name std::unique_ptr GetInputByName( diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 28e87dca394ba06844269746c19a892c26e0c653..18eb0b3545eeb27c6661c48b9a91dbf413757606 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -97,6 +97,10 @@ std::vector CxxPaddleApiImpl::GetInputNames() { return raw_predictor_.GetInputNames(); } +std::vector CxxPaddleApiImpl::GetParamNames() { + return raw_predictor_.GetParamNames(); +} + std::vector CxxPaddleApiImpl::GetOutputNames() { return raw_predictor_.GetOutputNames(); } @@ -123,6 +127,12 @@ std::unique_ptr CxxPaddleApiImpl::GetTensor( return std::unique_ptr(new lite_api::Tensor(x)); } +std::unique_ptr CxxPaddleApiImpl::GetMutableTensor( + const std::string &name) { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_.GetMutableTensor(name))); +} + std::unique_ptr CxxPaddleApiImpl::GetInputByName( const std::string &name) { return std::unique_ptr( diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index daef2c66dda5188a1eec25c3d5f045f1fa705e1e..4b13ae4ed241eb1a3164a1213feec12306df89f6 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -167,6 +167,20 @@ lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); } void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); } +std::unique_ptr PaddlePredictor::GetMutableTensor( + const std::string &name) { + LOG(FATAL) + << "The GetMutableTensor API is only supported by CxxConfig predictor."; + return nullptr; +} + +std::vector PaddlePredictor::GetParamNames() { + std::vector null_result = {}; + LOG(FATAL) + << "The GetParamNames API is only supported by CxxConfig predictor."; + return null_result; +} + void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir, LiteModelType model_type, bool record_info) { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 31600bda3017861a9f43b1f5b844ab0157395627..dfb0a7fa68579e24eac22a7edee89a8cf9e12d5c 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -86,6 +86,8 @@ class LITE_API PaddlePredictor { virtual std::vector GetInputNames() = 0; // Get output names virtual std::vector GetOutputNames() = 0; + // Get output names + virtual std::vector GetParamNames(); // Get Input by name virtual std::unique_ptr GetInputByName(const std::string& name) = 0; @@ -93,6 +95,9 @@ class LITE_API PaddlePredictor { /// Get a readonly tensor, return null if no one called `name` exists. virtual std::unique_ptr GetTensor( const std::string& name) const = 0; + /// Get a mutable tensor, return null if on one called `name` exists + /// internal infereces API, not recommanded. + virtual std::unique_ptr GetMutableTensor(const std::string& name); /// Persist the optimized model to disk. This API is only supported by /// CxxConfig, and the persisted model can be reused for MobileConfig. diff --git a/lite/core/scope.cc b/lite/core/scope.cc index 775652e2a0d3c962c17dc796ef5f1d381411fa50..d87360a1da8215332c71739bbfa2660977f4f74c 100644 --- a/lite/core/scope.cc +++ b/lite/core/scope.cc @@ -60,6 +60,29 @@ Variable *Scope::FindLocalVar(const std::string &name) const { return nullptr; } +// AttributeVarNames will get persistive attribute names stored in parent scope +std::vector Scope::AttributeVarNames() const { + std::vector resulted_keys; + const Scope *cur_scope = this; + while (cur_scope->parent()) { + cur_scope = cur_scope->parent(); + auto keys = cur_scope->LocalVarNames(); + resulted_keys.insert(resulted_keys.end(), keys.begin(), keys.end()); + } + // remove feed and fetch + std::vector skiped_vars = {"feed", "fetch"}; + for (int i = 0; i < skiped_vars.size(); i++) { + auto iter = + std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]); + while (iter != resulted_keys.end()) { + resulted_keys.erase(iter); + iter = + std::find(resulted_keys.begin(), resulted_keys.end(), skiped_vars[i]); + } + } + return resulted_keys; +} + std::vector Scope::LocalVarNames() const { std::vector keys; for (const auto &item : vars_) { diff --git a/lite/core/scope.h b/lite/core/scope.h index 2593c365224a0564caa27cf10eee1f917b90c342..aa3a8a1bfb7f4bf1cc00b548c0b0962ce8d73663 100644 --- a/lite/core/scope.h +++ b/lite/core/scope.h @@ -45,6 +45,8 @@ class Scope final { const Scope* parent() const { return parent_; } + // Get attribute params stored in parent scopes. + std::vector AttributeVarNames() const; // Following the legacy scope interface. std::vector LocalVarNames() const;