提交 39ad0c74 编写于 作者: 开心的小妮

[LITE][API] Add GetOutput API for CXX and Mobile. test=develop

上级 af661abe
...@@ -88,6 +88,7 @@ class LITE_API Predictor { ...@@ -88,6 +88,7 @@ class LITE_API Predictor {
// Get offset-th col of fetch results. // Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const; const lite::Tensor* GetOutput(size_t offset) const;
const lite::Tensor* GetOutput(const std::string& name) const;
std::vector<const lite::Tensor*> GetOutputs() const; std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const; const cpp::ProgramDesc& program_desc() const;
...@@ -131,6 +132,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -131,6 +132,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::unique_ptr<lite_api::Tensor> GetInput(int i) override; std::unique_ptr<lite_api::Tensor> GetInput(int i) override;
std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override; std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override;
std::unique_ptr<const lite_api::Tensor> GetOutput(
const std::string& name) const override;
void Run() override; void Run() override;
......
...@@ -46,6 +46,11 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput( ...@@ -46,6 +46,11 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
} }
// Look up a fetch (output) tensor by its variable name.
// NOTE(review): this simply forwards to GetTensor, which resolves any
// variable in the exec scope — presumably callers only pass genuine
// output names; confirm whether name validation is needed here.
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
    const std::string &name) const {
  auto tensor = CxxPaddleApiImpl::GetTensor(name);
  return tensor;
}
std::vector<std::string> CxxPaddleApiImpl::GetInputNames() { std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames(); return raw_predictor_.GetInputNames();
} }
......
...@@ -57,6 +57,7 @@ class LITE_API LightPredictor { ...@@ -57,6 +57,7 @@ class LITE_API LightPredictor {
Tensor* GetInputByName(const std::string& name); Tensor* GetInputByName(const std::string& name);
// Get offset-th col of fetch outputs. // Get offset-th col of fetch outputs.
const Tensor* GetOutput(size_t offset); const Tensor* GetOutput(size_t offset);
const Tensor* GetOutput(const std::string& name);
const lite::Tensor* GetTensor(const std::string& name) const { const lite::Tensor* GetTensor(const std::string& name) const {
auto* var = program_->exec_scope()->FindVar(name); auto* var = program_->exec_scope()->FindVar(name);
...@@ -94,6 +95,9 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { ...@@ -94,6 +95,9 @@ class LightPredictorImpl : public lite_api::PaddlePredictor {
std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override; std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override;
std::unique_ptr<const lite_api::Tensor> GetOutput(
const std::string& name) const override;
void Run() override; void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override; std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
......
...@@ -45,6 +45,11 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput( ...@@ -45,6 +45,11 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput(
new lite_api::Tensor(raw_predictor_->GetOutput(i))); new lite_api::Tensor(raw_predictor_->GetOutput(i)));
} }
// Fetch an output tensor by variable name (mobile/light predictor).
// NOTE(review): delegates to GetTensor, which can resolve any scope
// variable, not only fetch targets — verify callers pass output names.
std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput(
    const std::string& name) const {
  auto result = LightPredictorImpl::GetTensor(name);
  return result;
}
void LightPredictorImpl::Run() { void LightPredictorImpl::Run() {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_); lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
...@@ -63,6 +68,7 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor( ...@@ -63,6 +68,7 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor(
return std::unique_ptr<const lite_api::Tensor>( return std::unique_ptr<const lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_->GetTensor(name))); new lite_api::Tensor(raw_predictor_->GetTensor(name)));
} }
std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInputByName( std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInputByName(
const std::string& name) { const std::string& name) {
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
......
...@@ -77,6 +77,10 @@ class LITE_API PaddlePredictor { ...@@ -77,6 +77,10 @@ class LITE_API PaddlePredictor {
/// Get i-th output. /// Get i-th output.
virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0; virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0;
/// Get output tensor with `name` existed.
virtual std::unique_ptr<const Tensor> GetOutput(
const std::string& name) const = 0;
virtual void Run() = 0; virtual void Run() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone() = 0; virtual std::shared_ptr<PaddlePredictor> Clone() = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册