diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 504710d9fa29420b8762f31e0c675b59c6c626bd..25b529b5e5168eab94ab12ab0b92dde555a16f2a 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -88,6 +88,7 @@ class LITE_API Predictor { // Get offset-th col of fetch results. const lite::Tensor* GetOutput(size_t offset) const; + const lite::Tensor* GetOutput(const std::string& name) const; std::vector GetOutputs() const; const cpp::ProgramDesc& program_desc() const; @@ -131,6 +132,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { std::unique_ptr GetInput(int i) override; std::unique_ptr GetOutput(int i) const override; + std::unique_ptr GetOutput( + const std::string& name) const override; void Run() override; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 6fa400db6da9f029c38b496cd70d593a876628c9..28105d63c9e0ca08cd19500b13bb7a14236f68d0 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -46,6 +46,11 @@ std::unique_ptr CxxPaddleApiImpl::GetOutput( return std::unique_ptr(new lite_api::Tensor(x)); } +std::unique_ptr CxxPaddleApiImpl::GetOutput( + const std::string &name) const { + return CxxPaddleApiImpl::GetTensor(name); +} + std::vector CxxPaddleApiImpl::GetInputNames() { return raw_predictor_.GetInputNames(); } diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 3781bc4d674db5d2e8794edaf33f00627b9977bb..bc6a9baf7bde7abd89e8310176d6d49183853d14 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -57,6 +57,7 @@ class LITE_API LightPredictor { Tensor* GetInputByName(const std::string& name); // Get offset-th col of fetch outputs. const Tensor* GetOutput(size_t offset); + const Tensor* GetOutput(const std::string& name); const lite::Tensor* GetTensor(const std::string& name) const { auto* var = program_->exec_scope()->FindVar(name); @@ -94,6 +95,9 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { std::unique_ptr GetOutput(int i) const override; + std::unique_ptr GetOutput( + const std::string& name) const override; + void Run() override; std::shared_ptr Clone() override; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index a0ae28df0958403237114a3d4b94031829019339..4d23da162503998557610bbd889108869d477f1c 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -45,6 +45,11 @@ std::unique_ptr LightPredictorImpl::GetOutput( new lite_api::Tensor(raw_predictor_->GetOutput(i))); } +std::unique_ptr LightPredictorImpl::GetOutput( + const std::string& name) const { + return LightPredictorImpl::GetTensor(name); +} + void LightPredictorImpl::Run() { #ifdef LITE_WITH_ARM lite::DeviceInfo::Global().SetRunMode(mode_, threads_); @@ -63,6 +68,7 @@ std::unique_ptr LightPredictorImpl::GetTensor( return std::unique_ptr( new lite_api::Tensor(raw_predictor_->GetTensor(name))); } + std::unique_ptr LightPredictorImpl::GetInputByName( const std::string& name) { return std::unique_ptr( diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index c578769bd5159d27ad43e4e93de33f601223004b..738adbc43e561c6d2323ee5efc5124301c8a599e 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -77,6 +77,10 @@ class LITE_API PaddlePredictor { /// Get i-th output. virtual std::unique_ptr GetOutput(int i) const = 0; + /// Get output tensor with `name` existed. + virtual std::unique_ptr GetOutput( + const std::string& name) const = 0; + virtual void Run() = 0; virtual std::shared_ptr Clone() = 0;