diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index cec8e2ec3781bdc0f9cb9fd238301c65322cabec..67f9a9d1d84b88f93db560ce1a6903b711ad492f 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -80,11 +80,18 @@ class LITE_API Predictor { const std::vector& valid_places, const std::vector& passes = {}); - std::shared_ptr Clone() const { + std::shared_ptr Clone( + cosnst std::vector var_names) const { // CHECK(program_desc_) << "Both program and scope of current predicotr // should be not be nullptr in Clone mode." ; // CHECK(scope_) << "Both program and scope of current predicotr should // be not be nullptr in Clone mode."; + for (auto i : var_names) { + exec_scope_->Var(i); + auto* tensor = scope_->Var(i)->GetMutable(); + auto* sub_tensor = exec_scope_->Var(i)->GetMutable(); + sub_tensor->CopyDataFrom(tensor); + } auto predictor = std::make_shared(program_desc_, scope_, valid_places_); return predictor; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index c327d47c99167dca9f6fbece94ffa60387afeb77..5faddd69d0bd2bbbe0efe8b341c92403601081c3 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -119,10 +119,11 @@ void CxxPaddleApiImpl::Run() { raw_predictor_->Run(); } -std::shared_ptr CxxPaddleApiImpl::Clone() { +std::shared_ptr CxxPaddleApiImpl::Clone( + cosnst std::vector var_names) { std::lock_guard lock(mutex_); - auto predictor = - std::make_shared(raw_predictor_->Clone()); + auto predictor = std::make_shared( + raw_predictor_->Clone(var_names)); status_is_cloned_ = true; predictor->Init(config_); return predictor;