io.cc 4.5 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/inference/io.h"
16

17
#include <algorithm>
18
#include <fstream>
Y
Yi Wang 已提交
19 20
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
21 22
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/pybind/pybind.h"
23 24

namespace paddle {
25
namespace inference {
26

W
wanghaoshuang 已提交
27
// Initializes the framework by forwarding `argv` unchanged to
// framework::Init.
void Init(const std::vector<std::string> argv) {
  framework::Init(argv);
}
28

29
// Reads the entire file `filename` in binary mode into `*contents`,
// replacing whatever `*contents` held before.
//
// Fails via PADDLE_ENFORCE if the file cannot be opened, its size cannot be
// determined, or fewer bytes than expected could be read. An empty file
// yields an empty string.
void ReadBinaryFile(const std::string& filename, std::string* contents) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);

  // Determine the file size by seeking to the end; tellg() reports -1 on
  // failure, which must not be fed to resize().
  fin.seekg(0, std::ios::end);
  const std::streamoff size = fin.tellg();
  PADDLE_ENFORCE(size >= 0, "Cannot determine size of file %s", filename);

  contents->clear();
  contents->resize(static_cast<size_t>(size));
  fin.seekg(0, std::ios::beg);

  // Skip the read for empty files: taking the address of element 0 via
  // at(0) on an empty string would throw std::out_of_range.
  if (!contents->empty()) {
    fin.read(&(*contents)[0], static_cast<std::streamsize>(contents->size()));
    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot read file %s", filename);
  }
  // `fin` is closed by its destructor.
}

L
Liu Yiqun 已提交
40 41
// Reports whether `var` is a persistable variable that should be loaded.
// Variables flagged persistable but of type FEED_MINIBATCH or FETCH_LIST
// are excluded.
bool IsPersistable(const framework::VarDesc* var) {
  if (!var->Persistable()) {
    return false;
  }
  const auto var_type = var->GetType();
  return var_type != framework::proto::VarType::FEED_MINIBATCH &&
         var_type != framework::proto::VarType::FETCH_LIST;
}

49
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
50
                      const framework::ProgramDesc& main_program,
51
                      const std::string& dirname,
52
                      const std::string& param_filename) {
K
kexinzhao 已提交
53
  const framework::BlockDesc& global_block = main_program.Block(0);
54

55 56
  framework::ProgramDesc* load_program = new framework::ProgramDesc();
  framework::BlockDesc* load_block = load_program->MutableBlock(0);
57 58
  std::vector<std::string> paramlist;

K
kexinzhao 已提交
59
  for (auto* var : global_block.AllVars()) {
L
Liu Yiqun 已提交
60 61
    if (IsPersistable(var)) {
      VLOG(3) << "persistable variable's name: " << var->Name();
62 63

      framework::VarDesc* new_var = load_block->Var(var->Name());
F
fengjiayi 已提交
64
      new_var->SetShape(var->GetShape());
65 66 67 68 69
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

70 71 72 73 74 75 76 77 78 79
      if (!param_filename.empty()) {
        paramlist.push_back(new_var->Name());
      } else {
        // append_op
        framework::OpDesc* op = load_block->AppendOp();
        op->SetType("load");
        op->SetOutput("Out", {new_var->Name()});
        op->SetAttr("file_path", {dirname + "/" + new_var->Name()});
        op->CheckAttrs();
      }
80 81
    }
  }
82 83 84 85 86 87 88 89 90 91 92 93

  if (!param_filename.empty()) {
    // sort paramlist to have consistent ordering
    std::sort(paramlist.begin(), paramlist.end());
    // append just the load_combine op
    framework::OpDesc* op = load_block->AppendOp();
    op->SetType("load_combine");
    op->SetOutput("Out", paramlist);
    op->SetAttr("file_path", {param_filename});
    op->CheckAttrs();
  }

94
  executor->Run(*load_program, scope, 0, true, true);
95

96
  delete load_program;
97
}
98

99 100
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                             framework::Scope* scope,
K
kexinzhao 已提交
101
                                             const std::string& dirname) {
102 103
  std::string model_filename = dirname + "/__model__";
  std::string program_desc_str;
104
  VLOG(3) << "loading model from " << model_filename;
105
  ReadBinaryFile(model_filename, &program_desc_str);
106 107 108 109 110 111 112 113 114

  std::unique_ptr<framework::ProgramDesc> main_program(
      new framework::ProgramDesc(program_desc_str));

  LoadPersistables(executor, scope, *main_program, dirname, "");
  return main_program;
}

std::unique_ptr<framework::ProgramDesc> Load(
115
    framework::Executor* executor, framework::Scope* scope,
116
    const std::string& prog_filename, const std::string& param_filename) {
117 118
  std::string model_filename = prog_filename;
  std::string program_desc_str;
119
  ReadBinaryFile(model_filename, &program_desc_str);
120

K
kexinzhao 已提交
121 122
  std::unique_ptr<framework::ProgramDesc> main_program(
      new framework::ProgramDesc(program_desc_str));
123

124
  LoadPersistables(executor, scope, *main_program, "", param_filename);
125 126 127 128
  return main_program;
}

}  // namespace inference
129
}  // namespace paddle