From 39ad0c74c0c6177f9b5f9b5b2fb765732a5da1ed Mon Sep 17 00:00:00 2001 From: "ysh329@sina.com" Date: Tue, 26 Nov 2019 04:20:20 +0000 Subject: [PATCH] [LITE][API] Add GetOutput API for CXX and Mobile. test=develop --- lite/api/cxx_api.h | 3 +++ lite/api/cxx_api_impl.cc | 5 +++++ lite/api/light_api.h | 4 ++++ lite/api/light_api_impl.cc | 6 ++++++ lite/api/paddle_api.h | 4 ++++ 5 files changed, 22 insertions(+) diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 504710d9fa..25b529b5e5 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -88,6 +88,7 @@ class LITE_API Predictor { // Get offset-th col of fetch results. const lite::Tensor* GetOutput(size_t offset) const; + const lite::Tensor* GetOutput(const std::string& name) const; std::vector GetOutputs() const; const cpp::ProgramDesc& program_desc() const; @@ -131,6 +132,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { std::unique_ptr GetInput(int i) override; std::unique_ptr GetOutput(int i) const override; + std::unique_ptr GetOutput( + const std::string& name) const override; void Run() override; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 6fa400db6d..28105d63c9 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -46,6 +46,11 @@ std::unique_ptr CxxPaddleApiImpl::GetOutput( return std::unique_ptr(new lite_api::Tensor(x)); } +std::unique_ptr CxxPaddleApiImpl::GetOutput( + const std::string &name) const { + return CxxPaddleApiImpl::GetTensor(name); +} + std::vector CxxPaddleApiImpl::GetInputNames() { return raw_predictor_.GetInputNames(); } diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 3781bc4d67..bc6a9baf7b 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -57,6 +57,7 @@ class LITE_API LightPredictor { Tensor* GetInputByName(const std::string& name); // Get offset-th col of fetch outputs. const Tensor* GetOutput(size_t offset); + const Tensor* GetOutput(const std::string& name); const lite::Tensor* GetTensor(const std::string& name) const { auto* var = program_->exec_scope()->FindVar(name); @@ -94,6 +95,9 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { std::unique_ptr GetOutput(int i) const override; + std::unique_ptr GetOutput( + const std::string& name) const override; + void Run() override; std::shared_ptr Clone() override; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index a0ae28df09..4d23da1625 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -45,6 +45,11 @@ std::unique_ptr LightPredictorImpl::GetOutput( new lite_api::Tensor(raw_predictor_->GetOutput(i))); } +std::unique_ptr LightPredictorImpl::GetOutput( + const std::string& name) const { + return LightPredictorImpl::GetTensor(name); +} + void LightPredictorImpl::Run() { #ifdef LITE_WITH_ARM lite::DeviceInfo::Global().SetRunMode(mode_, threads_); @@ -63,6 +68,7 @@ std::unique_ptr LightPredictorImpl::GetTensor( return std::unique_ptr( new lite_api::Tensor(raw_predictor_->GetTensor(name))); } + std::unique_ptr LightPredictorImpl::GetInputByName( const std::string& name) { return std::unique_ptr( diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index c578769bd5..738adbc43e 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -77,6 +77,10 @@ class LITE_API PaddlePredictor { /// Get i-th output. virtual std::unique_ptr GetOutput(int i) const = 0; + /// Get output tensor with `name` existed. + virtual std::unique_ptr GetOutput( + const std::string& name) const = 0; + virtual void Run() = 0; virtual std::shared_ptr Clone() = 0; -- GitLab