diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index f6f7ec75e65ff54e3f3642822e51057d3522ae3a..3fcaaac65be82e062cb554a596fb82b531b258ba 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -314,9 +314,16 @@ void Predictor::GenRuntimeProgram() {
 const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
   auto *var = exec_scope_->FindVar(name);
+  CHECK(var) << "no variable named with " << name << " in exec_scope";
   return &var->Get<lite::Tensor>();
 }
 
+lite::Tensor *Predictor::GetMutableTensor(const std::string &name) {
+  auto *var = exec_scope_->FindVar(name);
+  CHECK(var) << "no variable named with " << name << " in exec_scope";
+  return var->GetMutable<lite::Tensor>();
+}
+
 // get input by name
 lite::Tensor *Predictor::GetInputByName(const std::string &name) {
   auto element = std::find(input_names_.begin(), input_names_.end(), name);
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 504710d9fa29420b8762f31e0c675b59c6c626bd..0348ab8ef13d95c469f49e069dd6f5fc1f76f07b 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -91,6 +91,9 @@ class LITE_API Predictor {
   std::vector<const lite::Tensor *> GetOutputs() const;
   const cpp::ProgramDesc& program_desc() const;
+  // get a mutable tensor according to its name
+  lite::Tensor* GetMutableTensor(const std::string& name);
+  // get a const tensor according to its name
   const lite::Tensor* GetTensor(const std::string& name) const;
 
   const RuntimeProgram& runtime_program() const;
@@ -142,8 +145,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
   std::vector<std::string> GetInputNames() override;
   std::vector<std::string> GetOutputNames() override;
 
+  // get tensor according to tensor's name
   std::unique_ptr<lite_api::Tensor> GetTensor(
       const std::string& name) const override;
+  // get a mutable tensor according to tensor's name
+  std::unique_ptr<lite_api::Tensor> GetMutableTensor(
+      const std::string& name) override;
 
   // Get InputTebsor by name
   std::unique_ptr<lite_api::Tensor> GetInputByName(
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index 
972210c8f9ea05ba1b041382c43efad64aeacc1b..24500dd04318551206a65a33d2e93ce32fcd6029 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -101,6 +101,12 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 }
 
+std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor(
+    const std::string &name) {
+  return std::unique_ptr<lite_api::Tensor>(
+      new lite_api::Tensor(raw_predictor_.GetMutableTensor(name)));
+}
+
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
     const std::string &name) {
   return std::unique_ptr<lite_api::Tensor>(
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index 307eeb74e8b4cdc3b2d6188eb18490e4dcf89b8f..9f540adc28a001c18e41e9c27b35e12b04b366eb 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -94,6 +94,12 @@
+  /// Get a readonly tensor, return null if no one called `name` exists.
   virtual std::unique_ptr<Tensor> GetTensor(
       const std::string& name) const = 0;
+  /// Get a mutable tensor, return null if no one called `name` exists.
+  /// Internal inference API, not recommended for general use.
+  virtual std::unique_ptr<Tensor> GetMutableTensor(
+      const std::string& name) = 0;
+
   /// Persist the optimized model to disk. This API is only supported by
   /// CxxConfig, and the persisted model can be reused for MobileConfig.
   virtual void SaveOptimizedModel(