model_parser.cc 7.8 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

S
update  
superjomn 已提交
15
#include "paddle/fluid/lite/model_parser/model_parser.h"
S
superjomn 已提交
16
#include <algorithm>
S
update  
superjomn 已提交
17
#include <fstream>
S
superjomn 已提交
18
#include <limits>
19
#include "paddle/fluid/lite/core/compatible_tensor.h"
S
update  
superjomn 已提交
20 21
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/variable.h"
S
update  
superjomn 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

namespace paddle {
namespace lite {

// Returns the size in bytes of a single element of the given protobuf data
// type. TensorFromStream uses it to compute how many raw bytes to read.
//
// Unlisted types abort through LOG(FATAL). The trailing `return -1` is kept
// for compilers that cannot tell that the fatal log does not return
// (presumably it never executes -- depends on the logging backend).
int SizeOfType(framework::proto::VarType::Type type) {
  using Type = framework::proto::VarType::Type;
  switch (static_cast<int>(type)) {
#define DO(desc, type)            \
  case Type::VarType_Type_##desc: \
    return sizeof(type);
    DO(BOOL, bool);
    // Half-precision values are 16 bits wide. No native half type is assumed
    // to be available here, so a 16-bit integer supplies the element width.
    DO(FP16, int16_t);
    DO(FP32, float);
    DO(INT8, int8_t);
    // INT16 must be listed: TensorFromStream accepts INT16 tensors and asks
    // for their element size before allocating.
    DO(INT16, int16_t);
    DO(INT32, int);
    DO(INT64, int64_t);
#undef DO
    default:
      LOG(FATAL) << "unknown data type " << type;
  }
  return -1;
}

void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
  using Type = framework::proto::VarType::Type;
  uint32_t version;
  is.read(reinterpret_cast<char *>(&version), sizeof(version));
  CHECK_EQ(version, 0U) << "Only version 0 is supported";
  // read tensor desc
  framework::proto::VarType::TensorDesc desc;
  {
    // int32_t size
    // proto buffer
    int32_t size;
    is.read(reinterpret_cast<char *>(&size), sizeof(size));
    std::unique_ptr<char[]> buf(new char[size]);
    is.read(reinterpret_cast<char *>(buf.get()), size);
    CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc";
  }

  // read tensor
63 64 65 66 67
  std::vector<int64_t> dims_vec;
  std::copy(desc.dims().begin(), desc.dims().end(),
            std::back_inserter(dims_vec));
  lite::DDim dims(dims_vec);
  tensor->Resize(dims);
S
update  
superjomn 已提交
68
  void *buf;
69
  size_t size = tensor->dims().production() * SizeOfType(desc.data_type());
S
update  
superjomn 已提交
70 71
  // alllocate memory
  switch (static_cast<int>(desc.data_type())) {
72 73 74
#define DO(desc, type)                  \
  case Type::VarType_Type_##desc:       \
    buf = tensor->mutable_data<type>(); \
S
update  
superjomn 已提交
75
    break;
76
    // DO(BOOL, bool);
S
update  
superjomn 已提交
77 78 79 80 81 82 83
    DO(FP32, float);
    DO(INT8, int8_t);
    DO(INT16, int16_t);
    DO(INT32, int32_t);
    DO(INT64, int64_t);
#undef DO
    default:
84
      LOG(FATAL) << "unknown type " << desc.data_type();
S
update  
superjomn 已提交
85 86 87 88 89 90 91
  }

  is.read(static_cast<char *>(buf), size);
}

void LoadLoDTensor(std::istream &is, Variable *var) {
  auto *tensor = var->GetMutable<lite::Tensor>();
S
Superjomn 已提交
92
  uint32_t version{};
S
update  
superjomn 已提交
93
  is.read(reinterpret_cast<char *>(&version), sizeof(version));
N
nhzlx 已提交
94
  VLOG(3) << "model version " << version;
S
update  
superjomn 已提交
95 96

  // Load LoD information
S
Superjomn 已提交
97
  uint64_t lod_level{};
S
update  
superjomn 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
  is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
  auto &lod = *tensor->mutable_lod();
  lod.resize(lod_level);
  for (uint64_t i = 0; i < lod_level; ++i) {
    uint64_t size;
    is.read(reinterpret_cast<char *>(&size), sizeof(size));
    std::vector<size_t> tmp(size / sizeof(size_t));
    is.read(reinterpret_cast<char *>(tmp.data()),
            static_cast<std::streamsize>(size));
    lod[i] = tmp;
  }

  TensorFromStream(is, tensor);
}

// TODO(Superjomn) support SelectedRows.

void ReadBinaryFile(const std::string &filename, std::string *contents) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
Y
Yan Chunwei 已提交
117
  CHECK(fin.is_open()) << "Cannot open file: " << filename;
S
update  
superjomn 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
  fin.seekg(0, std::ios::end);
  auto size = fin.tellg();
  contents->clear();
  contents->resize(size);
  fin.seekg(0, std::ios::beg);
  fin.read(&(contents->at(0)), contents->size());
  fin.close();
}

std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
    const std::string &path) {
  std::string desc_str;
  ReadBinaryFile(path, &desc_str);
  std::unique_ptr<framework::proto::ProgramDesc> main_program(
      new framework::proto::ProgramDesc);
  main_program->ParseFromString(desc_str);
  return main_program;
}

// Placeholder with an empty body: `path` is ignored and nothing is loaded.
void LoadParams(const std::string &path) {}

C
update  
Chunwei 已提交
139 140 141
// Load directly to CPU, and latter transfer to other devices.
void LoadParam(const std::string &path, Variable *out) {
  std::ifstream fin(path, std::ios::binary);
S
Superjomn 已提交
142
  CHECK(fin.is_open()) << "failed to open file " << path;
C
update  
Chunwei 已提交
143 144 145
  LoadLoDTensor(fin, out);
}

S
superjomn 已提交
146 147
void LoadModel(const std::string &model_dir, Scope *scope,
               framework::proto::ProgramDesc *prog) {
S
update  
superjomn 已提交
148
  const std::string prog_path = model_dir + "/__model__";
S
superjomn 已提交
149
  *prog = *LoadProgram(prog_path);
S
update  
superjomn 已提交
150 151 152

  auto main_block = prog->blocks(0);
  for (auto &var : main_block.vars()) {
S
superjomn 已提交
153 154 155
    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
      continue;

S
update  
superjomn 已提交
156
    std::string file_path = model_dir + "/" + var.name();
N
nhzlx 已提交
157
    VLOG(4) << "reading weight " << var.name();
S
superjomn 已提交
158

S
update  
superjomn 已提交
159
    std::ifstream file(file_path);
S
superjomn 已提交
160 161 162 163 164 165 166
    switch (var.type().type()) {
      case framework::proto::VarType_Type_LOD_TENSOR:
        LoadLoDTensor(file, scope->Var(var.name()));
        break;
      default:
        CHECK(false) << "unknown weight type";
    }
S
update  
superjomn 已提交
167 168
  }
}
S
superjomn 已提交
169

S
Superjomn 已提交
170
// Serializes `tensor` to `os`.
//
// Fields written, in order:
//   uint32_t version (0)
//   uint64_t lod_level, then per level: uint64_t byte size + the level data
//   uint32_t version (0) again -- a LoDTensor carries two version fields
//   int32_t  desc size + serialized TensorDesc (data type is always FP32)
//   raw tensor data, tensor.data_size() bytes
//
// When built with LITE_WITH_CUDA and the tensor's target is kCUDA, the data is
// first copied to a temporary host buffer and written from there.
void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
  // the 1st field, uint32_t version
  constexpr uint32_t version = 0;
  os.write(reinterpret_cast<const char *>(&version), sizeof(version));

  {
    uint64_t size = tensor.lod().size();
    // the 2nd field, LoD information
    // uint64_t lod_level
    // uint64_t lod_level_1 size in byte.
    // int*     lod_level_1 data
    // ...
    os.write(reinterpret_cast<const char *>(&size), sizeof(size));

    for (auto &each : tensor.lod()) {
      // sizeof does not evaluate front(), so an empty level is fine here.
      size = each.size() * sizeof(each.front());
      os.write(reinterpret_cast<const char *>(&size), sizeof(size));
      os.write(reinterpret_cast<const char *>(each.data()),
               static_cast<std::streamsize>(size));
    }
  }

  // There are two version fields in a LoDTensor.
  os.write(reinterpret_cast<const char *>(&version), sizeof(version));

  {  // the 3rd field, tensor description
    // int32_t  size
    // void*    protobuf message
    framework::proto::VarType::TensorDesc desc;
    // TODO(Superjomn) support other data types.
    desc.set_data_type(framework::proto::VarType_Type_FP32);
    auto dims = tensor.dims();
    auto *pb_dims = desc.mutable_dims();
    pb_dims->Resize(static_cast<int>(dims.size()), 0);
    auto dims_vec = dims.Vectorize();
    std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin());
    int32_t size = desc.ByteSize();
    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
    auto out = desc.SerializeAsString();
    os.write(out.data(), size);
  }
  {  // the 4th field, tensor data
    uint64_t size = tensor.data_size();
    CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
        << "Index overflow when writing tensor";

#ifdef LITE_WITH_CUDA
    if (tensor.target() == TARGET(kCUDA)) {
      // Array form of unique_ptr so the buffer is released with delete[],
      // matching the new[] allocation.
      std::unique_ptr<char[]> tmp_buffer(new char[size]);
      TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<float>(),
                                    tensor.data_size(), IoDirection::DtoH);
      os.write(static_cast<const char *>(tmp_buffer.get()),
               static_cast<std::streamsize>(size));
    } else  // NOLINT
#endif      // LITE_WITH_CUDA
    {
      os.write(static_cast<const char *>(tensor.data<void>()),
               static_cast<std::streamsize>(size));
    }
  }
}

S
Superjomn 已提交
232 233
// Looks up the variable `var_name` in `scope` and writes its lite::Tensor to
// `os` using TensorToStream.
//
// Aborts through CHECK when the lookup yields a null pointer (presumably
// FindVar's way of reporting an unknown name -- confirm against Scope).
void SerializeTensor(std::ostream &os, const lite::Scope &scope,
                     const std::string &var_name) {
  auto *var = scope.FindVar(var_name);
  CHECK(var != nullptr) << "Cannot find variable " << var_name
                        << " in the scope";
  const auto &tensor = var->Get<lite::Tensor>();
  TensorToStream(os, tensor);
}

S
update  
superjomn 已提交
240 241
}  // namespace lite
}  // namespace paddle