cxx_api.cc 7.1 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/api/cxx_api.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/utils/io.h"

namespace paddle {
namespace lite {

void Predictor::SaveModel(const std::string &dir,
                          lite_api::LiteModelType model_type) {
  if (!program_) {
    GenRuntimeProgram();
  }
  program_->SaveOpInfosToProgram(&program_desc_);
31
  program_->UpdateVarsOfProgram(&program_desc_);
Y
Yan Chunwei 已提交
32 33
  switch (model_type) {
    case lite_api::LiteModelType::kProtobuf:
34
      SaveModelPb(dir, *program_->exec_scope(), program_desc_, true);
Y
Yan Chunwei 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
      break;
    case lite_api::LiteModelType::kNaiveBuffer:
      SaveModelNaive(dir, *program_->exec_scope(), program_desc_);
      break;
    default:
      LOG(FATAL) << "Unknown model type";
  }
}

// Returns the tensor at position `offset` of the "feed" tensor list stored in
// exec_scope_. The list is grown (default-constructed tensors) when `offset`
// is past its current end. Aborts via CHECK if exec_scope_ has no "feed"
// variable.
lite::Tensor *Predictor::GetInput(size_t offset) {
  auto *_feed_list = exec_scope_->FindVar("feed");
  CHECK(_feed_list) << "no feed variable in exec_scope";
  auto &feeds = *_feed_list->GetMutable<std::vector<lite::Tensor>>();
  if (feeds.size() <= offset) feeds.resize(offset + 1);
  return &feeds.at(offset);
}

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// get inputs names
std::vector<std::string> Predictor::GetInputNames() {
  std::vector<std::string> input_names;
  for (auto &item : input_names_) {
    input_names.push_back(item.second);
  }
  return input_names;
}
// get outputnames
std::vector<std::string> Predictor::GetOutputNames() {
  std::vector<std::string> output_names;
  for (auto &item : output_names_) {
    output_names.push_back(item.second);
  }
  return output_names;
}
// Rebuilds the feed/fetch bookkeeping from block 0 of program_desc_:
//   input_names_[col]  = the "Out" variable of each "feed" op,
//   idx2feeds_[name]   = that feed op's "col" (reverse lookup),
//   output_names_[col] = the "X" variable of each "fetch" op.
// The maps are cleared first. Calling this again after building a different
// program therefore does not leave names from the previous program behind.
void Predictor::PrepareFeedFetch() {
  input_names_.clear();
  output_names_.clear();
  idx2feeds_.clear();

  auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
  // Compare as int explicitly instead of mixing signed/unsigned in the loop.
  const int ops_size = static_cast<int>(current_block->OpsSize());
  for (int i = 0; i < ops_size; ++i) {
    auto op = current_block->GetOp<cpp::OpDesc>(i);
    if (op->Type() == "feed") {
      const int idx = op->GetAttr<int>("col");
      // Copy the name: Output() may return by value, so a reference to
      // front() could dangle.
      const std::string var_name = op->Output("Out").front();
      input_names_[idx] = var_name;
      idx2feeds_[var_name] = idx;
    } else if (op->Type() == "fetch") {
      const int idx = op->GetAttr<int>("col");
      output_names_[idx] = op->Input("X").front();
    }
  }
}

Y
Yan Chunwei 已提交
86 87 88 89 90 91 92 93
// Returns the tensor at position `offset` of the "fetch" tensor list stored
// in exec_scope_. Aborts via CHECK if there is no "fetch" variable or if
// `offset` is not smaller than the list size.
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
  auto *_fetch_list = exec_scope_->FindVar("fetch");
  // Diagnostic fixed: previously misspelled "fatch".
  CHECK(_fetch_list) << "no fetch variable in exec_scope";
  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
  return &fetch_list.at(offset);
}

T
TianXiaogang 已提交
94 95 96 97 98 99 100
// Returns a pointer to the whole "fetch" tensor list stored in exec_scope_.
// Aborts via CHECK if exec_scope_ has no "fetch" variable.
const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
  auto *_fetch_list = exec_scope_->FindVar("fetch");
  // Diagnostic fixed: previously misspelled "fatch".
  CHECK(_fetch_list) << "no fetch variable in exec_scope";
  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
  return &fetch_list;
}

Y
Yan Chunwei 已提交
101 102 103 104 105
// Read-only view of the program description held by this predictor
// (a reference to the member program_desc_, not a copy).
const cpp::ProgramDesc &Predictor::program_desc() const {
  const cpp::ProgramDesc &desc = program_desc_;
  return desc;
}
// Returns the runtime program produced by GenRuntimeProgram().
// program_ may still be empty here (SaveModel() explicitly handles that
// case). Fail with a clear message instead of dereferencing a null pointer.
const RuntimeProgram &Predictor::runtime_program() const {
  CHECK(program_) << "runtime program has not been generated yet";
  return *program_;
}

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
// Convenience overload. It reads the model directory, model/param file
// names, preferred place and load-from-memory flag from `config`. It then
// forwards them, with `valid_places`, `passes` and `model_type`, to the
// path-based Build() overload.
void Predictor::Build(const lite_api::CxxConfig &config,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes,
                      lite_api::LiteModelType model_type) {
  const bool model_from_memory = config.model_from_memory();
  LOG(INFO) << "load from memory " << model_from_memory;

  Build(config.model_dir(),
        config.model_file(),
        config.param_file(),
        config.preferred_place(),
        valid_places,
        passes,
        model_type,
        model_from_memory);
}
Y
Yan Chunwei 已提交
126
void Predictor::Build(const std::string &model_path,
127 128
                      const std::string &model_file,
                      const std::string &param_file,
Y
Yan Chunwei 已提交
129 130 131
                      const Place &prefer_place,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes,
132 133
                      lite_api::LiteModelType model_type,
                      bool model_from_memory) {
Y
Yan Chunwei 已提交
134
  switch (model_type) {
135 136 137 138 139 140 141 142 143 144
    case lite_api::LiteModelType::kProtobuf: {
      bool combined_param = false;
      if (!model_file.empty() && !param_file.empty()) {
        combined_param = true;
      }
      LoadModelPb(model_path,
                  model_file,
                  param_file,
                  scope_.get(),
                  &program_desc_,
145 146
                  combined_param,
                  model_from_memory);
147
    } break;
Y
Yan Chunwei 已提交
148
    case lite_api::LiteModelType::kNaiveBuffer:
149 150
      CHECK(!model_path.empty())
          << "NaiveBuffer backend only supported combined param";
Y
Yan Chunwei 已提交
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
      LoadModelNaive(model_path, scope_.get(), &program_desc_);
      break;
    default:
      LOG(FATAL) << "Unknown model type";
  }
  Build(program_desc_, prefer_place, valid_places, passes);
}

// Optimizes `desc` into an executable program:
//  1. keeps a copy of it in program_desc_,
//  2. wraps it in a Program bound to scope_ and `valid_places`,
//  3. passes `prefer_place` to the optimizer via KernelPickPreferPlace and
//     runs it with `passes`, using a kernel-pick factor that considers
//     target, precision and data layout.
// Afterwards exec_scope_ is set to the optimizer's exec scope.
void Predictor::Build(const cpp::ProgramDesc &desc,
                      const Place &prefer_place,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes) {
  core::KernelPickFactor factor;
  factor.ConsiderTarget();
  factor.ConsiderPrecision();
  factor.ConsiderDataLayout();

  // `desc` may alias program_desc_ (the path-based Build() passes the member
  // itself), so the assignment can be a self-assignment.
  program_desc_ = desc;
  Program program(desc, scope_, valid_places);
  optimizer_.KernelPickPreferPlace(prefer_place);
  optimizer_.Run(std::move(program), valid_places, factor, passes);
  exec_scope_ = optimizer_.exec_scope();
}

// Asks the optimizer for the runtime program and stores it in program_.
// Checks that the program's exec scope is the same exec_scope_ recorded by
// Build(), then marks the program as generated.
void Predictor::GenRuntimeProgram() {
  program_ = optimizer_.GenRuntimeProgram();
  // Must run only after Build(), which sets exec_scope_.
  CHECK_EQ(exec_scope_, program_->exec_scope());
  program_generated_ = true;
}

// Returns the tensor held by variable `name` in exec_scope_.
// Aborts via CHECK if no such variable exists. Previously a missing variable
// was dereferenced as a null pointer.
const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
  auto *var = exec_scope_->FindVar(name);
  CHECK(var) << "no variable named with '" << name << "' in exec_scope";
  return &var->Get<lite::Tensor>();
}
184 185 186 187 188 189 190 191 192 193 194 195 196 197
// Returns the input tensor for the feed variable `name`, or nullptr when the
// model has no such input. In that case the known input names are logged.
// idx2feeds_ and input_names_ are filled by PrepareFeedFetch().
lite::Tensor *Predictor::GetInputByName(const std::string &name) {
  auto feed_it = idx2feeds_.find(name);
  if (feed_it == idx2feeds_.end()) {
    LOG(ERROR) << "Model do not have input named with: [" << name
               << "], model's inputs include:";
    // Iterate the map directly. Indexing input_names_ with 0..size()-1 via
    // operator[] would default-insert empty names whenever the feed columns
    // are not contiguous, corrupting later GetInputNames() results.
    for (const auto &col_and_name : input_names_) {
      LOG(ERROR) << "[" << col_and_name.second << "]";
    }
    return nullptr;
  }
  return GetInput(feed_it->second);
}
Y
Yan Chunwei 已提交
198 199 200 201 202 203 204 205 206 207 208 209 210 211

#ifdef LITE_WITH_TRAIN
// Resizes the "feed" tensor list found in scope_ to tensors.size(). Each
// entry then shares its data with the corresponding element of `tensors`.
// Aborts via CHECK if scope_ has no "feed" variable; previously this was a
// null dereference.
// NOTE(review): GetInput() looks "feed" up in exec_scope_, while this uses
// scope_. Confirm that "feed" is visible from scope_.
void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
  auto var = scope_->FindVar("feed");
  CHECK(var) << "no feed variable in scope";
  auto &feed_list = *(var->GetMutable<std::vector<lite::Tensor>>());
  feed_list.resize(tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    feed_list[i].ShareDataWith(tensors[i]);
  }
}
#endif

}  // namespace lite
}  // namespace paddle