// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/api/cxx_api.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/utils/io.h"

namespace paddle {
namespace lite {

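// Serializes the current program (program_desc_) under `dir`, creating the
// directory first if it does not exist.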
void Predictor::SaveModel(const std::string &dir) {
  MkDirRecur(dir);
  program_->PersistModel(dir, program_desc_);
  LOG(INFO) << "Save model to " << dir;
}

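// Returns a mutable handle to the offset-th input tensor, growing the "feed"
// list on demand so inputs can be bound in any order.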
lite::Tensor *Predictor::GetInput(size_t offset) {
  auto *_feed_list = program_->exec_scope()->FindVar("feed");
  CHECK(_feed_list) << "no feed variable in exec_scope";
  auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
  if (offset >= feed_list->size()) {
    feed_list->resize(offset + 1);
  }
  return &feed_list->at(offset);
}

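// Returns the offset-th output tensor from the "fetch" list; the offset must
// refer to an existing fetch slot.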
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
  auto *_fetch_list = program_->exec_scope()->FindVar("fetch");
  CHECK(_fetch_list) << "no fetch variable in exec_scope";
  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
  return &fetch_list.at(offset);
}

void Predictor::Build(const std::string &model_path, const Place &prefer_place,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes) {
  LoadModel(model_path, scope_.get(), &program_desc_);
  Build(program_desc_, prefer_place, valid_places, passes);
}
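
// Usage sketch (illustrative only, not part of this file): a typical call
// sequence against this API. It assumes a Run() method is declared alongside
// these members in cxx_api.h; substitute a real model path and places.
//
//   Predictor predictor;
//   const Place place{TARGET(kX86), PRECISION(kFloat)};
//   predictor.Build("/path/to/model", place, {place});
//   lite::Tensor *input = predictor.GetInput(0);
//   // ... resize `input` and fill its data here ...
//   predictor.Run();  // assumed to be declared in cxx_api.h
//   const lite::Tensor *output = predictor.GetOutput(0);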

const framework::proto::ProgramDesc &Predictor::program_desc() const {
  return program_desc_;
}

const RuntimeProgram &Predictor::runtime_program() const { return *program_; }

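// Builds the predictor from an in-memory ProgramDesc: runs the optimizer with
// a kernel-pick factor that weighs target and precision, then emits the
// executable RuntimeProgram.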
void Predictor::Build(const framework::proto::ProgramDesc &desc,
                      const Place &prefer_place,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes) {
  program_desc_ = desc;
  Program program(desc, scope_, valid_places);

  optimizer_.KernelPickPreferPlace(prefer_place);
  core::KernelPickFactor factor;
  factor.ConsiderTarget();
  factor.ConsiderPrecision();
  optimizer_.Run(std::move(program), valid_places, factor, passes);
  program_ = optimizer_.GenRuntimeProgram();
}

// Looks up a tensor by variable name in the execution scope.
const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
  auto *var = program_->exec_scope()->FindVar(name);
  CHECK(var) << "no variable named " << name << " in exec_scope";
  return &var->Get<lite::Tensor>();
}

#ifdef LITE_WITH_X86
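// Shares the given framework::Tensor buffers into the "feed" list without
// copying, so X86 callers can feed fluid tensors directly.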
void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
  auto var = scope_->FindVar("feed");
  auto &feed_list = *(var->GetMutable<std::vector<lite::Tensor>>());
  feed_list.resize(tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i)
    feed_list[i].ShareDataWith(tensors[i]);
}
#endif

}  // namespace lite
}  // namespace paddle