// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/jit/serializer.h"

#include <set>

#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"

#include "paddle/fluid/jit/engine/executor_engine.h"
#include "paddle/fluid/jit/engine/interpreter_engine.h"
#include "paddle/fluid/jit/engine/pe_engine.h"
#include "paddle/fluid/jit/engine/predictor_engine.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/property.h"
#include "paddle/fluid/jit/serializer_utils.h"

DECLARE_string(jit_engine_type);

namespace paddle {
namespace jit {
using FunctionInfoMap =
    std::unordered_map<std::string, std::shared_ptr<FunctionInfo>>;
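
// Deserializes a Layer saved on disk: for every .pdmodel file found under
// `path`, loads its ProgramDesc, collects the persistable (parameter)
// variable names, reads the parameter tensors (and optional properties), and
// finally attaches an execution engine to each function.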
Layer Deserializer::operator()(const std::string& path,
                               const phi::Place& place) {
  const auto& pdmodel_paths = utils::PdmodelFilePaths(path);
  // Use std::set so parameter names stay unique and sorted across functions.
  std::set<std::string> param_names_set;
  FunctionInfoMap info_map;
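
  // Build a FunctionInfo for each .pdmodel file and record its persistable
  // variable names.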
  for (auto& it : pdmodel_paths) {
    auto& func_name = it.first;
    auto program_desc = LoadProgram(it.second);

    std::vector<std::string> persist_var_names;
    auto all_var_desc = program_desc.Block(0).AllVars();
    for (auto* desc_ptr : all_var_desc) {
      if (utils::IsPersistable(desc_ptr)) {
        persist_var_names.emplace_back(desc_ptr->Name());
      }
    }

    param_names_set.insert(persist_var_names.begin(), persist_var_names.end());
    info_map[func_name] = std::make_shared<FunctionInfo>(
        func_name, persist_var_names, program_desc);
    info_map[func_name]->SetProgramFilePath(it.second);
  }

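  // Read the parameter tensors; attribute properties are optional and only
  // loaded when a property file exists next to the model.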
  VariableMap params_dict;
  VariableMap attrs_dict;
  ReadTensorData(path + PDPARAMS_SUFFIX, param_names_set, place, &params_dict);

  if (utils::FileExists(path + PROPERTY_SUFFIX)) {
    ReadAttributeData(path + PROPERTY_SUFFIX, &attrs_dict);
    VLOG(3) << "Read Property Success!";
  }

  Layer layer(params_dict, attrs_dict, info_map, place);

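  // Attach an execution engine to every function according to
  // FLAGS_jit_engine_type.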
  for (auto it = info_map.begin(); it != info_map.end(); ++it) {
    const std::string& func_name = it->first;
    auto& info = it->second;
    VLOG(3) << "Add function type: " << FLAGS_jit_engine_type
            << " Function name: " << func_name;
    if (FLAGS_jit_engine_type == "Executor") {
      layer.SetEngine(
          func_name,
          utils::MakeEngine<ExecutorEngine>(info, params_dict, place));
    } else if (FLAGS_jit_engine_type == "PE") {
      layer.SetEngine(func_name,
                      utils::MakeEngine<PEEngine>(info, params_dict, place));
    } else if (FLAGS_jit_engine_type == "New") {
      layer.SetEngine(
          func_name,
          utils::MakeEngine<InterpreterEngine>(info, params_dict, place));
    } else if (FLAGS_jit_engine_type == "Predictor") {
      layer.SetEngine(
          func_name,
          utils::MakeEngine<PredictorEngine>(info, params_dict, place));
    } else {
      PD_THROW("Invalid JitLayer engine type.");
    }
  }

  return layer;
}

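// Reads the tensors named in `var_name` (an ordered set, so deserialization
// happens in sorted-name order) from a single combined parameter file into
// *params_dict.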
void Deserializer::ReadTensorData(const std::string& file_name,
                                  const std::set<std::string>& var_name,
                                  const phi::Place& place,
                                  VariableMap* params_dict) const {
  VLOG(3) << "ReadTensorData from: " << file_name;
  std::ifstream fin(file_name, std::ios::binary);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& dev_ctx = *pool.Get(place);
  for (const auto& name : var_name) {
    VLOG(3) << "load Tensor: " << name;
    Variable v;
    // TODO(dev): Support framework::Vocab
    DenseTensor* dense_tensor = v.GetMutable<DenseTensor>();
    framework::DeserializeFromStream(fin, dense_tensor, dev_ctx);
    (*params_dict)[name] = std::make_shared<Variable>(v);
  }
}

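// Loads attribute values from the property file at `file_path` into
// *attrs_dict.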
void Deserializer::ReadAttributeData(const std::string& file_path,
                                     VariableMap* attrs_dict) const {
  VLOG(3) << "ReadPropertyData from: " << file_path;
  Property p;
  p.Deserialization(file_path);
  *attrs_dict = static_cast<VariableMap>(p.Values());
}

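// Reads an entire .pdmodel file into memory and parses it into a
// framework::ProgramDesc.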
framework::ProgramDesc Deserializer::LoadProgram(const std::string& file_name) {
  VLOG(3) << "LoadProgram from: " << file_name;
  std::ifstream fin(file_name, std::ios::in | std::ios::binary);
  fin.seekg(0, std::ios::end);
  std::string buffer(fin.tellg(), ' ');
  fin.seekg(0, std::ios::beg);
  fin.read(&buffer[0], buffer.size());
  fin.close();
  return framework::ProgramDesc(buffer);
}

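// Entry point for users: deserializes the model saved under the path prefix
// `file_path` (the .pdmodel/.pdiparams suffixes are appended internally) and
// returns a runnable Layer.
//
// Usage sketch (the path prefix below is only an illustration):
//   auto layer = paddle::jit::Load("/path/to/model_prefix", phi::CPUPlace());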
Layer Load(const std::string& file_path, const phi::Place& place) {
  Deserializer deserializer;
  return deserializer(file_path, place);
}

}  // namespace jit
}  // namespace paddle