// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "cinn/auto_schedule/auto_tuner.h"
#include "cinn/auto_schedule/tuning.h"
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/optimize.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"

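// Runtime flags defined in other translation units: whether compiled
// programs are launched through ParallelExecutor, and whether CINN's
// auto-tuner is enabled.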
DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);

namespace paddle {
namespace framework {
namespace paddle2cinn {

using ::cinn::auto_schedule::AutoTuner;
using ::cinn::common::Target;
using ::cinn::frontend::Optimize;
using ::cinn::frontend::paddle::InplaceOutSuffix;
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::GraphCompiler;
using inference::analysis::Dot;
using ir::Graph;
using ir::Node;

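// Returns the process-wide CinnCompiler singleton. The instance is allocated
// once on first use and never freed, so it stays valid for the whole
// lifetime of the process.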
CinnCompiler *CinnCompiler::GetInstance() {
  static CinnCompiler *instance = new CinnCompiler();
  return instance;
}

const CinnCompiledObject &CinnCompiler::Compile(
    const Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    void *stream) {
  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
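  // Lookup is keyed two ways: first by the graph's address (cheap), then by
  // its structure, so that structurally identical graphs at different
  // addresses share one compiled object.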
  CinnCacheKeyByAddress cur_key_by_address(
      graph, input_tensors, target.arch_str());
  CinnCacheKeyByStructure cur_key_by_struct;

  if (!cache_by_address_.count(cur_key_by_address)) {
    // generate the structure cache key
    cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
    if (!cache_by_struct_.count(cur_key_by_struct)) {
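      // not cached yet: compile outside the lock, then publish the result to
      // both caches under double-checked locking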
      std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
      auto compiled_res =
          CompileGraph(graph, input_tensors, target, compiled_num, stream);
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_struct_
      if (!cache_by_struct_.count(cur_key_by_struct)) {
        cache_by_struct_[cur_key_by_struct] = compiled_num;
        index2cache_.emplace(compiled_num, std::move(compiled_res));
      }
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    } else {
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    }
  }
  return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
}

const CinnCompiledObject &CinnCompiler::Compile(
    int64_t compilation_key,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    void *stream) {
  const auto &graph = FindGraph(compilation_key);
  return Compile(graph, input_tensors, target, stream);
}

const CinnCompiledObject &CinnCompiler::GetCompiledObject(
    int64_t cached_index) const {
  auto res = index2cache_.find(cached_index);
  PADDLE_ENFORCE_NE(res,
                    index2cache_.end(),
                    platform::errors::InvalidArgument(
                        "Index(%ld) not found in cache", cached_index));
  return *res->second;
}

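// Registers a graph and returns its compilation key (a hash of the graph's
// address). A typical flow (sketch):
//   auto *compiler = CinnCompiler::GetInstance();
//   int64_t key = compiler->AddGraph(std::move(graph));
//   const auto &obj = compiler->Compile(key, input_tensors, target, stream);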
int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
  int64_t graph_key = std::hash<Graph *>()(graph.get());
  PADDLE_ENFORCE_EQ(
      graphs_.count(graph_key),
      0,
      platform::errors::PreconditionNotMet(
          "The graph to be added is already in CinnCompiler, which is:\n%s",
          VizGraph(graph_key).c_str()));
  graphs_[graph_key] = std::move(graph);
  VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n"
          << VizGraph(graph_key);
  return graph_key;
}

const Graph &CinnCompiler::FindGraph(int64_t graph_key) const {
  auto it = graphs_.find(graph_key);
  PADDLE_ENFORCE_NE(
      it,
      graphs_.end(),
      platform::errors::PreconditionNotMet(
          "Cannot find the target graph whose key is: %lld", graph_key));
  return *it->second;
}

std::string CinnCompiler::VizGraph(int64_t graph_key) const {
  const Graph &graph = FindGraph(graph_key);
  return VizGraph(graph);
}

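// Renders the graph as Graphviz DOT text: operator nodes are drawn as dark
// rounded boxes, variable nodes as light boxes labeled with their shapes.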
std::string CinnCompiler::VizGraph(const Graph &graph) const {
  Dot dot;
  std::unordered_map<const Node *, std::string> node2dot;
  int id = 0;
  // Create nodes
  for (const Node *n : graph.Nodes()) {
    std::string node_id = "Node" + std::to_string(id++);
    if (n->IsOp()) {
      dot.AddNode(node_id,
                  {Dot::Attr("shape", "box"),
                   Dot::Attr("style", "rounded,filled,bold"),
                   Dot::Attr("color", "#303A3A"),
                   Dot::Attr("fontcolor", "#ffffff")},
                  n->Name(),
                  true);
    } else if (n->IsVar()) {
      auto label = n->Name();
      if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
        auto shape = n->Var()->GetShape();
        std::vector<std::string> shape_str(shape.size());
        std::transform(
            shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
              return std::to_string(val);
            });
        label += "\n" + string::join_strings(shape_str, ',');
      }
      dot.AddNode(
          node_id,
          {Dot::Attr("shape", "box"),
           Dot::Attr("style", "rounded,filled,bold"),
           Dot::Attr("color",
                     n->Var() && n->Var()->IsParameter() ? "#148b97"
                                                         : "#dddddd"),
           Dot::Attr("fontcolor",
                     n->Var() && n->Var()->IsParameter() ? "#ffffff"
                                                         : "#000000")},
          label,
          true);
    }
    node2dot[n] = node_id;
  }
  // Create edges
  for (const Node *n : graph.Nodes()) {
    const auto &src_id = node2dot.at(n);
    for (auto *out : n->outputs) {
      const auto &dest_id = node2dot.at(out);
      dot.AddEdge(src_id, dest_id, {});
    }
  }
  return dot.Build();
}

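// Converts the cached graph back into a ProgramDesc and serializes it to a
// binary proto string.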
std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
  const auto &graph = FindGraph(compilation_key);

  ProgramDesc program;
  GraphToProgram(graph, &program);

  std::string serial_graph;
  program.Proto()->SerializeToString(&serial_graph);
  return serial_graph;
}

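// Like SerializeKey, but returns the human-readable debug string of the
// proto instead of its binary serialization.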
std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
  const auto &graph = FindGraph(compilation_key);

  ProgramDesc program;
  GraphToProgram(graph, &program);

  return program.Proto()->DebugString();
}

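// Drops all cached graphs and compiled objects; the atomic compilation
// counter is reset outside the lock.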
void CinnCompiler::Clear() {
  {
    std::unique_lock<std::mutex> guard(lock_);
    graphs_.clear();
    cache_by_address_.clear();
    cache_by_struct_.clear();
    index2cache_.clear();
  }
  real_compiled_num_.store(0);
}

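// Validates a freshly compiled object: every output variable must be
// produced by the CINN program, and every input that CINN actually uses must
// match the tensor Paddle passed in.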
void CinnCompiler::CheckCompiledValid(
    const ir::Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const CinnCompiledObject &compiled_obj) const {
  const auto &input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
  const auto &inplace_var_names =
      graph.Get<std::unordered_set<std::string>>(kInplaceVarNames);
  const auto &output_var_names =
      graph.Get<std::vector<std::string>>(kOutputVars);
  auto *launch_context = compiled_obj.launch_context.get();
  // 1. check that all output variables will be assigned by the compiled
  // program
  for (auto var_name : output_var_names) {
    // inplace variables are renamed with a specified suffix
    if (inplace_var_names.count(var_name)) {
      var_name += InplaceOutSuffix;
    }
    PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Variable(%s) not applied in CINN", var_name));
  }
  // 2. check that all used input variables were correctly deduced by CINN.
  for (const auto &var_name : input_var_names) {
    // some input variables are not used by CINN because they were eliminated
    // by its optimization passes or because some of its operators need fewer
    // inputs
    if (!launch_context->IsVariableUsed(var_name)) {
      VLOG(4) << "Input variable " << var_name << " not used by CINN";
      continue;
    }
    launch_context->CheckTensorEquivalent(var_name,
                                          *input_tensors.at(var_name));
  }
}

std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
    const ir::Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    std::int64_t compiled_num,
    void *stream) const {
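  // Lower the Paddle ir::Graph to a CINN frontend program, then let CINN
  // optimize, (optionally) auto-tune, and compile it.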
  CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
  auto frontend_program = symbol();
  auto fetch_ids = symbol.GetFetchIds();
  VLOG(4) << "All fetch var ids in CINN: "
          << string::join_strings(fetch_ids, ',');

  auto cinn_graph = Optimize(&frontend_program, fetch_ids, target);
  VLOG(4) << "-- The " << compiled_num << "-th compilation ("
          << target.arch_str() << "), and its related graph:\n"
          << cinn_graph->Visualize();

  auto scope = BuildScope(target, cinn_graph);
  auto graph_compiler =
      std::make_unique<GraphCompiler>(target, scope, cinn_graph);
  GraphCompiler::CompileOptions options;
  options.with_instantiate_variables = false;
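  // Without ParallelExecutor launching the compiled program, CINN must insert
  // its own buffer-handling instructions into the generated program.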
  if (!FLAGS_enable_pe_launch_cinn) {
    options.with_buffer_handle_instruction_inserted = true;
  }
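  // When auto-tune is enabled, run CINN's auto-scheduler and fold its result
  // into the compile options; num_measure_trials = 0 means no on-device
  // measurement trials are run.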
  std::unique_ptr<AutoTuner> auto_tuner;
  if (FLAGS_enable_cinn_auto_tune) {
    VLOG(4) << "Compile with auto-tune";
    auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
    auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
    ::cinn::auto_schedule::TuningOptions tuning_options;
    tuning_options.num_measure_trials = 0;
    auto tuning_result = auto_tuner->Tune(tuning_options);
    options.Apply(tuning_result);
  }
  auto compiled_res =
      graph_compiler->Build(options, std::move(fetch_ids), stream);
  auto compiled_obj = std::make_unique<CinnCompiledObject>();
  *compiled_obj = {std::move(graph_compiler),
                   std::move(auto_tuner),
                   std::move(compiled_res.runtime_program),
                   scope,
                   symbol.var_model_to_program_map()};
  compiled_obj->cached_index = compiled_num;
  compiled_obj->launch_context =
      std::make_unique<operators::details::CinnLaunchContext>(graph,
                                                              *compiled_obj);
  CheckCompiledValid(graph, input_tensors, *compiled_obj);
  return compiled_obj;
}

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle