diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 67f9a9d1d84b88f93db560ce1a6903b711ad492f..76967fa0084ada18d1161457e4eacb90e7319db6 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -81,16 +81,16 @@ class LITE_API Predictor { const std::vector& passes = {}); std::shared_ptr Clone( - cosnst std::vector var_names) const { + const std::vector& var_names) const { // CHECK(program_desc_) << "Both program and scope of current predicotr // should be not be nullptr in Clone mode." ; // CHECK(scope_) << "Both program and scope of current predicotr should // be not be nullptr in Clone mode."; for (auto i : var_names) { - exec_scope_->Var(i); - auto* tensor = scope_->Var(i)->GetMutable(); - auto* sub_tensor = exec_scope_->Var(i)->GetMutable(); - sub_tensor->CopyDataFrom(tensor); + this->exec_scope_->Var(i); + auto* tensor = this->scope_->Var(i)->GetMutable(); + auto* sub_tensor = this->exec_scope_->Var(i)->GetMutable(); + sub_tensor->CopyDataFrom(*tensor); } auto predictor = std::make_shared(program_desc_, scope_, valid_places_); @@ -150,7 +150,7 @@ class LITE_API Predictor { Optimizer optimizer_; std::shared_ptr program_desc_; std::shared_ptr scope_; - const Scope* exec_scope_; + Scope* exec_scope_; std::unique_ptr program_; bool program_generated_{false}; std::vector input_names_; @@ -173,7 +173,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { void Run() override; - std::shared_ptr Clone() override; + std::shared_ptr Clone( + const std::vector& var_names); std::string GetVersion() const override; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 5faddd69d0bd2bbbe0efe8b341c92403601081c3..e50e4a4032588379412daefcb92c95dbea1b7ce3 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -120,7 +120,7 @@ void CxxPaddleApiImpl::Run() { } std::shared_ptr CxxPaddleApiImpl::Clone( - cosnst std::vector var_names) { + const std::vector &var_names) { std::lock_guard lock(mutex_); auto predictor = std::make_shared( raw_predictor_->Clone(var_names)); diff --git a/lite/api/light_api.h b/lite/api/light_api.h index aa25ea81c7b62238211f96265a4edc49f2d065a1..279565cc718b697309bec6ff66214420cdbf955b 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -113,7 +113,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { void Run() override; - std::shared_ptr Clone() override; + std::shared_ptr Clone( + const std::vector& var_names); std::string GetVersion() const override; std::vector GetInputNames() override; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index cdf5b7fb06df35b2e7fb72fc4e33ccb721a0f7f7..661dbccb52d8e99c3581e376a6795282ad635a8f 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -56,7 +56,8 @@ void LightPredictorImpl::Run() { raw_predictor_->Run(); } -std::shared_ptr LightPredictorImpl::Clone() { +std::shared_ptr LightPredictorImpl::Clone( + const std::vector& var_names) { LOG(FATAL) << "The Clone API is not supported in LigthPredictor"; return nullptr; } diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index b08f2f5c745f87cda2be181bdea2444b2c11313c..caa584601a417db4ac5d42ebb6c317ceebca5608 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -78,7 +78,8 @@ class LITE_API PaddlePredictor { virtual std::unique_ptr GetOutput(int i) const = 0; virtual void Run() = 0; - virtual std::shared_ptr Clone() = 0; + virtual std::shared_ptr Clone( + const std::vector& var_names) = 0; virtual std::string GetVersion() const = 0;