提交 8e0a526f 编写于 作者: D DannyIsFunny

add GetMutableTensor test=develop

上级 56a3e8d2
...@@ -314,9 +314,16 @@ void Predictor::GenRuntimeProgram() { ...@@ -314,9 +314,16 @@ void Predictor::GenRuntimeProgram() {
const lite::Tensor *Predictor::GetTensor(const std::string &name) const { const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
auto *var = exec_scope_->FindVar(name); auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return &var->Get<lite::Tensor>(); return &var->Get<lite::Tensor>();
} }
lite::Tensor *Predictor::GetMutableTensor(const std::string &name) const {
auto *var = exec_scope_->FindVar(name);
CHECK(var) << "no variable named with " << name << " in exec_scope";
return var->GetMutable<lite::Tensor>();
}
// get input by name // get input by name
lite::Tensor *Predictor::GetInputByName(const std::string &name) { lite::Tensor *Predictor::GetInputByName(const std::string &name) {
auto element = std::find(input_names_.begin(), input_names_.end(), name); auto element = std::find(input_names_.begin(), input_names_.end(), name);
......
...@@ -91,6 +91,9 @@ class LITE_API Predictor { ...@@ -91,6 +91,9 @@ class LITE_API Predictor {
std::vector<const lite::Tensor*> GetOutputs() const; std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const; const cpp::ProgramDesc& program_desc() const;
// get a mutable tensor according to its name
lite::Tensor* GetMutableTensor(const std::string& name);
// get a const tensor according to its name
const lite::Tensor* GetTensor(const std::string& name) const; const lite::Tensor* GetTensor(const std::string& name) const;
const RuntimeProgram& runtime_program() const; const RuntimeProgram& runtime_program() const;
...@@ -142,8 +145,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -142,8 +145,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::vector<std::string> GetInputNames() override; std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override; std::vector<std::string> GetOutputNames() override;
// get tensor according to tensor's name
std::unique_ptr<const lite_api::Tensor> GetTensor( std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string& name) const override; const std::string& name) const override;
// get a mutable tensor according to tensor's name
std::unique_ptr<lite_api::Tensor> GetMutableTensor(
const std::string& name) override;
// Get InputTebsor by name // Get InputTebsor by name
std::unique_ptr<lite_api::Tensor> GetInputByName( std::unique_ptr<lite_api::Tensor> GetInputByName(
......
...@@ -101,6 +101,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor( ...@@ -101,6 +101,12 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor(
const std::string &name) {
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetMutableTensor(name)));
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName( std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
const std::string &name) { const std::string &name) {
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
......
...@@ -94,6 +94,14 @@ class LITE_API PaddlePredictor { ...@@ -94,6 +94,14 @@ class LITE_API PaddlePredictor {
virtual std::unique_ptr<const Tensor> GetTensor( virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0; const std::string& name) const = 0;
/// Get a readonly tensor, return null if no one called `name` exists.
virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0;
/// Get a mutable tensor, return null if on one called `name` exists
/// internal infereces API, not recommanded.
virtual std::unique_ptr<Tensor> GetMutableTensor(
const std::string& name) const = 0;
/// Persist the optimized model to disk. This API is only supported by /// Persist the optimized model to disk. This API is only supported by
/// CxxConfig, and the persisted model can be reused for MobileConfig. /// CxxConfig, and the persisted model can be reused for MobileConfig.
virtual void SaveOptimizedModel( virtual void SaveOptimizedModel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册