From 8e0a526ff4150ca5ce37041c4e4db84faa517d84 Mon Sep 17 00:00:00 2001 From: DannyIsFunny <912790387@qq.com> Date: Fri, 6 Mar 2020 13:24:01 +0000 Subject: [PATCH] add GetMutableTensor test=develop --- lite/api/cxx_api.cc | 7 +++++++ lite/api/cxx_api.h | 7 +++++++ lite/api/cxx_api_impl.cc | 6 ++++++ lite/api/paddle_api.h | 8 ++++++++ 4 files changed, 28 insertions(+) diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index f6f7ec75e6..3fcaaac65b 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -314,9 +314,16 @@ void Predictor::GenRuntimeProgram() { const lite::Tensor *Predictor::GetTensor(const std::string &name) const { auto *var = exec_scope_->FindVar(name); + CHECK(var) << "no variable named with " << name << " in exec_scope"; return &var->Get(); } +lite::Tensor *Predictor::GetMutableTensor(const std::string &name) const { + auto *var = exec_scope_->FindVar(name); + CHECK(var) << "no variable named with " << name << " in exec_scope"; + return var->GetMutable(); +} + // get input by name lite::Tensor *Predictor::GetInputByName(const std::string &name) { auto element = std::find(input_names_.begin(), input_names_.end(), name); diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 504710d9fa..0348ab8ef1 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -91,6 +91,9 @@ class LITE_API Predictor { std::vector GetOutputs() const; const cpp::ProgramDesc& program_desc() const; + // get a mutable tensor according to its name + lite::Tensor* GetMutableTensor(const std::string& name); + // get a const tensor according to its name const lite::Tensor* GetTensor(const std::string& name) const; const RuntimeProgram& runtime_program() const; @@ -142,8 +145,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { std::vector GetInputNames() override; std::vector GetOutputNames() override; + // get tensor according to tensor's name std::unique_ptr GetTensor( const std::string& name) const override; + // get a mutable tensor according to tensor's name + std::unique_ptr GetMutableTensor( + const std::string& name) override; // Get InputTebsor by name std::unique_ptr GetInputByName( diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 972210c8f9..24500dd043 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -101,6 +101,12 @@ std::unique_ptr CxxPaddleApiImpl::GetTensor( return std::unique_ptr(new lite_api::Tensor(x)); } +std::unique_ptr CxxPaddleApiImpl::GetMutableTensor( + const std::string &name) { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_.GetMutableTensor(name))); +} + std::unique_ptr CxxPaddleApiImpl::GetInputByName( const std::string &name) { return std::unique_ptr( diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 307eeb74e8..9f540adc28 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -94,6 +94,14 @@ class LITE_API PaddlePredictor { virtual std::unique_ptr GetTensor( const std::string& name) const = 0; + /// Get a readonly tensor, return null if no one called `name` exists. + virtual std::unique_ptr GetTensor( + const std::string& name) const = 0; + /// Get a mutable tensor, return null if on one called `name` exists + /// internal infereces API, not recommanded. + virtual std::unique_ptr GetMutableTensor( + const std::string& name) const = 0; + /// Persist the optimized model to disk. This API is only supported by /// CxxConfig, and the persisted model can be reused for MobileConfig. virtual void SaveOptimizedModel( -- GitLab