// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/frontend/paddle_model_convertor.h"

#include <glog/logging.h>

#include <algorithm>
#include <unordered_set>
#include <utility>

#include "paddle/cinn/frontend/op_mappers/use_op_mappers.h"
#include "paddle/cinn/frontend/paddle/cpp/op_desc.h"
#include "paddle/cinn/frontend/paddle/cpp/program_desc.h"
#include "paddle/cinn/frontend/paddle/model_parser.h"
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/cinn/hlir/op/use_ops.h"

DECLARE_double(cinn_infer_model_version);

namespace cinn {
namespace frontend {

using cinn::utils::Attribute;

// Default constructor: uses the default target and passes null for both the
// builder and the scope, so the delegated-to constructor creates its own.
PaddleModelConvertor::PaddleModelConvertor()
    : PaddleModelConvertor(common::DefaultTarget(),
                           /*builder=*/nullptr,
                           /*scope=*/nullptr) {}

// Creates a convertor that emits into `builder` and loads parameters into
// `scope`.
//
// Either argument may be null; a fresh uniquely named NetBuilder and/or a
// newly created Scope is then used instead. The shared pointers are taken by
// value and moved into the members.
PaddleModelConvertor::PaddleModelConvertor(
    const common::Target& target,
    std::shared_ptr<NetBuilder> builder,
    std::shared_ptr<hlir::framework::Scope> scope)
    : target_(target), builder_(std::move(builder)), scope_(std::move(scope)) {
  if (!builder_) {
    // No builder supplied: create one with a unique name.
    builder_ =
        std::make_shared<NetBuilder>(cinn::UniqName("PaddleModelConvertor"));
  }
  if (!scope_) {
    // No scope supplied: create a new one.
    scope_ = hlir::framework::Scope::Create();
  }
  // The context is handed pointers to this convertor's own variable maps and
  // fetch-name set. NOTE(review): presumably the op mappers populate these
  // through the context — confirm against OpMapperContext.
  ctx_ = std::make_unique<OpMapperContext>(*scope_,
                                           target_,
                                           builder_.get(),
                                           &var_map_,
                                           &var_model_to_program_map_,
                                           &fetch_var_names_);
}

// Registers feed information on `ctx` for every output of every "feed" op in
// `block_desc`. The feed info is derived from the matching VarDesc in the same
// block. Aborts (CHECK) if a feed output has no VarDesc in the block.
void PaddleModelConvertor::PrepareRun(const paddle::cpp::BlockDesc& block_desc,
                                      OpMapperContext* ctx) {
  // Index the block's variable descriptions by name to preserve var desc info
  // like shape and dtype for the lookups below.
  std::unordered_map<std::string, const paddle::cpp::VarDesc*> var_desc_map;
  for (int i = 0; i < block_desc.VarsSize(); i++) {
    const auto& var_desc = block_desc.GetConstVar<paddle::cpp::VarDesc>(i);
    var_desc_map[var_desc.Name()] = &var_desc;
  }

  for (int i = 0; i < block_desc.OpsSize(); i++) {
    const auto& op_desc = block_desc.GetConstOp<paddle::cpp::OpDesc>(i);
    if (op_desc.Type() != "feed") {
      continue;
    }
    for (const auto& var_name : op_desc.output_vars()) {
      // Single lookup instead of count() followed by operator[].
      const auto desc_it = var_desc_map.find(var_name);
      CHECK(desc_it != var_desc_map.end())
          << "Feed var [" << var_name << "] Not found in block";
      ctx->AddFeedInfo(var_name, utils::GetFeedInfoFromDesc(*desc_it->second));
    }
  }
}

// Looks up the op mapper registered for `op_desc`'s type and runs it with
// `ctx`. Aborts (CHECK) if no mapper is registered for that op type.
void PaddleModelConvertor::RunOp(const paddle::cpp::OpDesc& op_desc,
                                 const OpMapperContext& ctx) {
  const auto& op_type = op_desc.Type();
  auto kernel = OpMapperRegistry::Global()->Find(op_type);
  CHECK(kernel) << "Op [" << op_type << "] Not supported in OpMapper";
  VLOG(4) << "Running Op " << op_type;
  kernel->Run(op_desc, ctx);
}

std::unordered_map<std::string, Variable> PaddleModelConvertor::GetFetchList(
    const std::unordered_set<std::string>& fetch_name_list) const {
96 97
  // the return map's key is paddle variable name, the value is the cinn fetch
  // variable
98 99 100
  const std::unordered_set<std::string>* var_name_list = &fetch_name_list;
  if (fetch_name_list.empty()) {
    // if paddle var list is empty, fetch the program's fetch var instead
101 102
    CHECK(!fetch_var_names_.empty())
        << "Should not fetch empty variable in CINN.";
103 104 105 106 107 108 109
    var_name_list = &fetch_var_names_;
  }

  std::unordered_map<std::string, Variable> fetch_list;
  fetch_list.reserve(var_name_list->size());
  for (const auto& pd_name : *var_name_list) {
    CHECK(var_model_to_program_map_.count(pd_name))
110 111
        << "Cannot find cinn variable [" << pd_name
        << "] in var_model_to_program_map_";
112 113 114 115 116 117 118 119 120 121 122
    auto norm_pd_name = pd_name;
    // remove inplace output's suffix
    auto pos = pd_name.find(paddle::InplaceOutSuffix);
    if (pos != std::string::npos) {
      norm_pd_name.replace(pos, sizeof(paddle::InplaceOutSuffix), "");
    }
    fetch_list[pd_name] = var_map_.at(norm_pd_name);
  }
  return fetch_list;
}

// Loads the paddle model found in `model_dir`, then converts it into a CINN
// Program.
//
// Parameters are loaded into scope_. `is_combined` is forwarded to
// paddle::LoadModelPb. `feed` maps variable names to shapes that override the
// shapes declared in the model. Aborts (CHECK) unless the model has exactly one
// block, and also if any op in it has no registered op mapper.
Program PaddleModelConvertor::LoadModel(
    const std::string& model_dir,
    bool is_combined,
    const std::unordered_map<std::string, std::vector<int64_t>>& feed) {
  paddle::cpp::ProgramDesc program_desc;

  // Only the model/params file names depend on the model version. Versions
  // below 2.0 use "/__model__" and "/params"; newer ones use ".pdmodel" and
  // ".pdiparams". All other LoadModelPb arguments are shared.
  const bool is_legacy_format = FLAGS_cinn_infer_model_version < 2.0;
  const char* const model_file = is_legacy_format ? "/__model__" : ".pdmodel";
  const char* const params_file = is_legacy_format ? "/params" : ".pdiparams";
  paddle::LoadModelPb(model_dir,
                      model_file,
                      params_file,
                      scope_.get(),
                      &program_desc,
                      is_combined,
                      false,
                      target_);

  CHECK_EQ(program_desc.BlocksSize(), 1)
      << "CINN can only support the model with a single block";
  auto* block_desc = program_desc.GetBlock<paddle::cpp::BlockDesc>(0);

  // Override the shape of every block variable that the caller listed in `feed`.
  for (int i = 0; i < block_desc->VarsSize(); i++) {
    auto* var_desc = block_desc->GetVar<paddle::cpp::VarDesc>(i);
    const auto& var_name = var_desc->Name();
    const auto feed_it = feed.find(var_name);
    if (feed_it == feed.end()) {
      continue;
    }
    const auto& var_shape = feed_it->second;
    VLOG(4) << "Update var " << var_name
            << "'s shape to: " << cinn::utils::Join(var_shape, ", ");
    var_desc->SetShape(var_shape);
  }

  // The ops are translated with a context local to this call. It is built from
  // the same scope, target, builder and maps as the member context ctx_.
  OpMapperContext ctx(*scope_,
                      target_,
                      builder_.get(),
                      &var_map_,
                      &var_model_to_program_map_,
                      &fetch_var_names_);

  PrepareRun(*block_desc, &ctx);
  for (int i = 0; i < block_desc->OpsSize(); i++) {
    auto* op_desc = block_desc->GetOp<paddle::cpp::OpDesc>(i);
    RunOp(*op_desc, ctx);
  }
  return builder_->Build();
}

void SetOpDescAttr(const std::string& attr_name,
                   const Attribute& attr_value,
                   paddle::cpp::OpDesc* op_desc) {
181 182
  class Visitor {
   public:
183 184
    Visitor(paddle::cpp::OpDesc* op_desc, const std::string& attr_name)
        : op_desc_(op_desc), attr_name_(attr_name) {}
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209

#define VISITOR_EXPAND(TYPE) \
  void operator()(const TYPE& v) { op_desc_->SetAttr(attr_name_, v); }

    VISITOR_EXPAND(bool)
    VISITOR_EXPAND(float)
    VISITOR_EXPAND(int)
    VISITOR_EXPAND(std::string)
    VISITOR_EXPAND(std::vector<bool>)
    VISITOR_EXPAND(std::vector<int>)
    VISITOR_EXPAND(std::vector<float>)
    VISITOR_EXPAND(std::vector<std::string>)
    VISITOR_EXPAND(int64_t)
    VISITOR_EXPAND(double)
    VISITOR_EXPAND(std::vector<int64_t>)
    VISITOR_EXPAND(std::vector<double>)
#undef VISITOR_EXPAND

   private:
    paddle::cpp::OpDesc* op_desc_;
    const std::string& attr_name_;
  };
  absl::visit(Visitor{op_desc, attr_name}, attr_value);
}

void PaddleModelConvertor::RunOp(
    const std::string& op_type,
    const std::map<std::string, std::vector<std::string>>& inputs,
    const std::map<std::string, std::vector<std::string>>& outputs,
    const std::map<std::string, Attribute>& attrs,
    const OpMapperContext& ctx) {
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  paddle::cpp::OpDesc op_desc;
  op_desc.SetType(op_type);
  for (const auto& in_pair : inputs) {
    op_desc.SetInput(in_pair.first, in_pair.second);
  }
  for (const auto& out_pair : outputs) {
    op_desc.SetOutput(out_pair.first, out_pair.second);
  }
  for (const auto& attr_pair : attrs) {
    SetOpDescAttr(attr_pair.first, attr_pair.second, &op_desc);
  }

  RunOp(op_desc, ctx);
}

void PaddleModelConvertor::RunOp(
    const std::string& op_type,
    const std::map<std::string, std::vector<std::string>>& inputs,
    const std::map<std::string, std::vector<std::string>>& outputs,
    const std::map<std::string, Attribute>& attrs) {
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
  RunOp(op_type, inputs, outputs, attrs, *ctx_);
}

// Builds and returns a Program from the convertor's NetBuilder.
Program PaddleModelConvertor::operator()() {
  return builder_->Build();
}

// Declares an input variable `name` with the given `shape` and `dtype`.
//
// The feed info is registered on ctx_ first, then a "feed" op whose "Out" slot
// is `name` is run. NOTE(review): presumably the feed mapper reads the info
// registered here — confirm.
void PaddleModelConvertor::CreateInput(const std::string& dtype,
                                       const cinn::utils::ShapeType& shape,
                                       const std::string& name) {
  OpMapperContext::FeedInfo input_info{shape, common::Str2Type(dtype)};
  ctx_->AddFeedInfo(name, input_info);

  const std::map<std::string, std::vector<std::string>> feed_outputs{
      {"Out", {name}}};
  RunOp("feed", {}, feed_outputs, {});
}

}  // namespace frontend
}  // namespace cinn