io.cc 7.8 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/inference/io.h"
16

17
#include <algorithm>
18
#include <fstream>
19
#include <vector>
S
Steffy-zxf 已提交
20

Y
Yi Wang 已提交
21 22
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
23
#include "paddle/fluid/framework/op_registry.h"
X
version  
Xin Pan 已提交
24
#include "paddle/fluid/framework/version.h"
T
tensor-tang 已提交
25
#include "paddle/fluid/platform/cpu_helper.h"
S
Steffy-zxf 已提交
26
#include "paddle/fluid/platform/enforce.h"
27
#include "paddle/fluid/pybind/pybind.h"
28

29
// phi
30
#include "paddle/phi/kernels/declarations.h"
31

W
wanghaoshuang 已提交
32
// Comma-separated list of device ids; parsed and handed to
// framework::InitDevices() by paddle::inference::Init().
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");

// Thread count handed to platform::SetNumThreads() by paddle::inference::Init().
DEFINE_int32(math_num_threads, 1,
             "Number of threads used to run math functions.");
36

37
namespace paddle {
38
namespace inference {
39

40 41
void Init(const std::vector<std::string> argv) {
  framework::InitGflags(argv);
T
tensor-tang 已提交
42
  platform::SetNumThreads(FLAGS_math_num_threads);
43 44 45 46 47 48 49
  // init devices
  std::vector<int> devices;
  std::string token;
  std::istringstream tokenStream(FLAGS_devices);
  while (std::getline(tokenStream, token, ',')) {
    devices.push_back(std::stoi(token));
  }
50
  framework::InitDevices(devices);
51
}
52

53
void ReadBinaryFile(const std::string& filename, std::string* contents) {
54
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
55
  PADDLE_ENFORCE_EQ(
56 57
      fin.is_open(),
      true,
58
      platform::errors::Unavailable("Failed to open file %s.", filename));
59
  fin.seekg(0, std::ios::end);
60 61
  contents->clear();
  contents->resize(fin.tellg());
62
  fin.seekg(0, std::ios::beg);
63
  fin.read(&(contents->at(0)), contents->size());
64
  fin.close();
65 66
}

L
Liu Yiqun 已提交
67 68
// Returns true when `var` is marked persistable and is not one of the
// FEED_MINIBATCH, FETCH_LIST or RAW variable kinds. LoadPersistables uses this
// to decide which variables get load ops.
bool IsPersistable(const framework::VarDesc* var) {
  if (!var->Persistable()) {
    return false;
  }
  const auto kind = var->GetType();
  return kind != framework::proto::VarType::FEED_MINIBATCH &&
         kind != framework::proto::VarType::FETCH_LIST &&
         kind != framework::proto::VarType::RAW;
}

77 78
void LoadPersistables(framework::Executor* executor,
                      framework::Scope* scope,
79
                      const framework::ProgramDesc& main_program,
80
                      const std::string& dirname,
T
Tao Luo 已提交
81
                      const std::string& param_filename,
T
Tao Luo 已提交
82
                      bool model_from_memory = false) {
K
kexinzhao 已提交
83
  const framework::BlockDesc& global_block = main_program.Block(0);
84

85 86
  framework::ProgramDesc* load_program = new framework::ProgramDesc();
  framework::BlockDesc* load_block = load_program->MutableBlock(0);
87 88
  std::vector<std::string> paramlist;

K
kexinzhao 已提交
89
  for (auto* var : global_block.AllVars()) {
L
Liu Yiqun 已提交
90
    if (IsPersistable(var)) {
91
      VLOG(4) << "persistable variable's name: " << var->Name();
92 93

      framework::VarDesc* new_var = load_block->Var(var->Name());
F
fengjiayi 已提交
94
      new_var->SetShape(var->GetShape());
95
      new_var->SetDataType(var->GetDataType());
S
Steffy-zxf 已提交
96 97
      auto var_type = var->GetType();
      new_var->SetType(var_type);
98

S
Steffy-zxf 已提交
99 100 101
      if ((var_type !=
           framework::proto::VarType::Type::VarType_Type_SELECTED_ROWS) &&
          (var_type != framework::proto::VarType::VOCAB)) {
102 103 104
        new_var->SetLoDLevel(var->GetLoDLevel());
      }

105 106
      new_var->SetPersistable(true);

107 108 109 110 111 112 113 114 115 116
      if (!param_filename.empty()) {
        paramlist.push_back(new_var->Name());
      } else {
        // append_op
        framework::OpDesc* op = load_block->AppendOp();
        op->SetType("load");
        op->SetOutput("Out", {new_var->Name()});
        op->SetAttr("file_path", {dirname + "/" + new_var->Name()});
        op->CheckAttrs();
      }
117 118
    }
  }
119 120 121 122 123 124 125 126 127

  if (!param_filename.empty()) {
    // sort paramlist to have consistent ordering
    std::sort(paramlist.begin(), paramlist.end());
    // append just the load_combine op
    framework::OpDesc* op = load_block->AppendOp();
    op->SetType("load_combine");
    op->SetOutput("Out", paramlist);
    op->SetAttr("file_path", {param_filename});
T
Tao Luo 已提交
128
    op->SetAttr("model_from_memory", {model_from_memory});
129 130 131
    op->CheckAttrs();
  }

132
  executor->Run(*load_program, scope, 0, true, true);
133

134
  delete load_program;
135
}
136

137 138
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                             framework::Scope* scope,
K
kexinzhao 已提交
139
                                             const std::string& dirname) {
140 141
  std::string model_filename = dirname + "/__model__";
  std::string program_desc_str;
M
minqiyang 已提交
142
  VLOG(3) << "loading model from " << model_filename;
143
  ReadBinaryFile(model_filename, &program_desc_str);
144 145 146

  std::unique_ptr<framework::ProgramDesc> main_program(
      new framework::ProgramDesc(program_desc_str));
147
  PADDLE_ENFORCE_EQ(
148 149
      framework::IsProgramVersionSupported(main_program->Version()),
      true,
150 151
      platform::errors::Unavailable("Model version %ld is not supported.",
                                    main_program->Version()));
152

T
tianshuo78520a 已提交
153
  // model_from_memory is false in separate parameters.
154 155 156 157 158
  LoadPersistables(executor,
                   scope,
                   *main_program,
                   dirname,
                   "",
T
Tao Luo 已提交
159
                   false /* model_from_memory */);
160 161 162
  return main_program;
}

163 164 165 166 167
// Loads an inference program from the file `prog_filename` and checks that
// its version is supported (PADDLE_ENFORCE otherwise). When `load_params` is
// true, its persistable variables are also loaded into `scope` from the
// combined parameter file `param_filename`.
// NOTE(review): with `load_params` true and an empty `param_filename`,
// LoadPersistables falls back to per-variable files at "/" + <name> — confirm
// callers never do that.
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                             framework::Scope* scope,
                                             const std::string& prog_filename,
                                             const std::string& param_filename,
                                             bool load_params) {
  std::string serialized_program;
  ReadBinaryFile(prog_filename, &serialized_program);
  std::unique_ptr<framework::ProgramDesc> program(
      new framework::ProgramDesc(serialized_program));

  const auto version = program->Version();
  PADDLE_ENFORCE_EQ(framework::IsProgramVersionSupported(version),
                    true,
                    platform::errors::Unavailable(
                        "Model version %ld is not supported.", version));

  if (!load_params) {
    return program;
  }
  LoadPersistables(executor,
                   scope,
                   *program,
                   "",
                   param_filename,
                   false /* model_from_memory */);
  return program;
}

T
Tao Luo 已提交
189
std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
190 191 192 193
    framework::Executor* executor,
    framework::Scope* scope,
    const std::string& prog_buffer,
    const std::string& param_buffer) {
T
Tao Luo 已提交
194 195
  std::unique_ptr<framework::ProgramDesc> main_program(
      new framework::ProgramDesc(prog_buffer));
196
  PADDLE_ENFORCE_EQ(
197 198
      framework::IsProgramVersionSupported(main_program->Version()),
      true,
199 200
      platform::errors::Unavailable("Model version %ld is not supported.",
                                    main_program->Version()));
T
Tao Luo 已提交
201

202 203 204 205 206
  LoadPersistables(executor,
                   scope,
                   *main_program,
                   "",
                   param_buffer,
T
Tao Luo 已提交
207 208
                   true /* model_filename */);
  return main_program;
T
Tao Luo 已提交
209 210
}

211
// Saves the variables named in `vars` from `scope` into the single file
// `dirname + "/param"`. It runs a one-op "save_combine" program on the CPU.
// NOTE(review): `predicate` is not referenced by this implementation.
void SaveVars(const framework::Scope& scope,
              const std::vector<std::string>& vars,
              const std::string& dirname,
              bool predicate) {
  framework::ProgramDesc save_program;
  framework::OpDesc* save_op = save_program.MutableBlock(0)->AppendOp();
  save_op->SetType("save_combine");
  save_op->SetInput("X", vars);
  save_op->SetAttr("file_path", dirname + "/param");
  save_op->CheckAttrs();

  platform::CPUPlace cpu_place;
  framework::Executor executor(cpu_place);
  // Executor::Run takes a mutable Scope*, hence the const_cast.
  executor.Run(save_program,
               const_cast<framework::Scope*>(&scope),
               0,
               true,
               true);
}

228
}  // namespace inference
229
}  // namespace paddle