提交 a4b881dd 编写于 作者: C chenhaoze

fix bugs. test=develop

上级 6df6fd08
......@@ -81,16 +81,16 @@ class LITE_API Predictor {
const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone(
cosnst std::vector<std::string> var_names) const {
const std::vector<std::string>& var_names) const {
// CHECK(program_desc_) << "Both program and scope of current predicotr
// should be not be nullptr in Clone mode." ;
// CHECK(scope_) << "Both program and scope of current predicotr should
// be not be nullptr in Clone mode.";
for (auto i : var_names) {
exec_scope_->Var(i);
auto* tensor = scope_->Var(i)->GetMutable<lite::Tensor>();
auto* sub_tensor = exec_scope_->Var(i)->GetMutable<lite::Tensor>();
sub_tensor->CopyDataFrom(tensor);
this->exec_scope_->Var(i);
auto* tensor = this->scope_->Var(i)->GetMutable<lite::Tensor>();
auto* sub_tensor = this->exec_scope_->Var(i)->GetMutable<lite::Tensor>();
sub_tensor->CopyDataFrom(*tensor);
}
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
......@@ -150,7 +150,7 @@ class LITE_API Predictor {
Optimizer optimizer_;
std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_;
const Scope* exec_scope_;
Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::vector<std::string> input_names_;
......@@ -173,7 +173,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone(
const std::vector<std::string>& var_names);
std::string GetVersion() const override;
......
......@@ -120,7 +120,7 @@ void CxxPaddleApiImpl::Run() {
}
std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone(
cosnst std::vector<std::string> var_names) {
const std::vector<std::string> &var_names) {
std::lock_guard<std::mutex> lock(mutex_);
auto predictor = std::make_shared<lite::CxxPaddleApiImpl>(
raw_predictor_->Clone(var_names));
......
......@@ -113,7 +113,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor {
void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone(
const std::vector<std::string>& var_names);
std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override;
......
......@@ -56,7 +56,8 @@ void LightPredictorImpl::Run() {
raw_predictor_->Run();
}
std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() {
std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone(
const std::vector<std::string>& var_names) {
LOG(FATAL) << "The Clone API is not supported in LigthPredictor";
return nullptr;
}
......
......@@ -78,7 +78,8 @@ class LITE_API PaddlePredictor {
virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0;
virtual void Run() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone(
const std::vector<std::string>& var_names) = 0;
virtual std::string GetVersion() const = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册