// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

#include <cstdint>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "cinn/auto_schedule/auto_tuner.h"
#include "cinn/auto_schedule/tuning.h"
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h"
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"

DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);
namespace paddle {
namespace framework {
namespace paddle2cinn {

using ir::Graph;
using ir::Node;
using inference::analysis::Dot;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::auto_schedule::AutoTuner;
using ::cinn::hlir::framework::BuildScope;
using ::cinn::frontend::ProgramPass;
using ::cinn::hlir::framework::ApplyPass;

CinnCompiler* CinnCompiler::GetInstance() {
  static CinnCompiler instance;
  return &instance;
}

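// Compile a Paddle subgraph for the given target. Results are cached twice:
// first by graph address for fast repeat lookups, then by graph structure so
// that structurally identical graphs share one compiled object.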
const CinnCompiledObject& CinnCompiler::Compile(
    const Graph& graph,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const Target& target, void* stream) {
  VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph);
  CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors,
                                           target.arch_str());
  CinnCacheKeyByStructure cur_key_by_struct;

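  // Fast path: probe both caches under a read lock.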
  bool exist = false;
  {
    phi::AutoRDLock r_guard{&rwlock_};
    exist = cache_by_address_.count(cur_key_by_address) != 0;
    // if the graph cannot be found by address, check whether its structure
    // has been stored in the cache.
    if (!exist) {
      // generate the structure cache key
      cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());

      // if the graph structure can be found, store the graph address in the
      // cache for the next query.
      if (cache_by_struct_.count(cur_key_by_struct) != 0) {
        exist = true;
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    }
  }
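  // Slow path: compile without holding the lock, then publish the result under
  // a write lock. The re-check below guards against a concurrent thread having
  // already compiled and cached the same graph structure.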
  if (!exist) {
    std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
    auto compiled_res =
        CompileGraph(graph, input_tensors, target, compiled_num, stream);
    phi::AutoWRLock w_guard{&rwlock_};
    if (!cache_by_struct_.count(cur_key_by_struct)) {
      cache_by_address_[cur_key_by_address] = compiled_num;
      cache_by_struct_[cur_key_by_struct] = compiled_num;
      index2cache_.emplace(compiled_num, std::move(compiled_res));
    }
  }
  phi::AutoRDLock guard{&rwlock_};
  const auto& cached_obj = *index2cache_[cache_by_address_[cur_key_by_address]];
  return cached_obj;
}

const CinnCompiledObject& CinnCompiler::Compile(
    const std::string& compilation_key,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const Target& target, void* stream) {
  const auto& graph = FindGraph(compilation_key);
  return Compile(graph, input_tensors, target, stream);
}

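// Look up a previously compiled object by the index assigned at compile time
// (see CinnCompiledObject::cached_index).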
const CinnCompiledObject& CinnCompiler::GetCompiledObject(
    int64_t cached_index) const {
  auto res = index2cache_.find(cached_index);
  PADDLE_ENFORCE_NE(res, index2cache_.end(),
                    platform::errors::InvalidArgument(
                        "Index(%ld) not found in cache", cached_index));
  return *res->second;
}

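// Register a graph for later compilation. The serialized ProgramDesc bytes
// serve both as the cache key and as the compilation key returned to callers.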
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
  std::string graph_key;
  ProgramDesc program;
  GraphToProgram(*graph, &program);
  program.Proto()->SerializeToString(&graph_key);

  PADDLE_ENFORCE_EQ(
      graphs_.count(graph_key), 0,
      platform::errors::PreconditionNotMet(
          "The graph to be added is already in CinnCompiler, which is:\n%s",
          VizGraph(graph_key).c_str()));
  graphs_[graph_key] = std::move(graph);
  VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n"
          << VizGraph(graph_key);
  return graph_key;
}

const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
  PADDLE_ENFORCE_NE(
      graphs_.count(graph_key), 0,
      platform::errors::PreconditionNotMet(
          "Cannot find the target graph, of which the key is:\n%s",
          ReadableKey(graph_key).c_str()));
  return *graphs_.at(graph_key);
}

std::string CinnCompiler::VizGraph(const std::string& graph_key) const {
  const Graph& graph = FindGraph(graph_key);
  return VizGraph(graph);
}

std::string CinnCompiler::VizGraph(const Graph& graph) const {
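  // Render the graph in Graphviz DOT format: operator nodes are drawn as dark
  // rounded boxes, variable nodes are labeled with their tensor shapes, and
  // parameters get a distinct fill color.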
  Dot dot;
  std::unordered_map<const Node*, std::string> node2dot;
  int id = 0;
  // Create nodes
  for (const Node* n : graph.Nodes()) {
    std::string node_id = "Node" + std::to_string(id++);
    if (n->IsOp()) {
      dot.AddNode(
          node_id,
          {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
           Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
          n->Name(), true);
    } else if (n->IsVar()) {
      auto label = n->Name();
      if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
        auto shape = n->Var()->GetShape();
        std::vector<std::string> shape_str(shape.size());
        std::transform(shape.begin(), shape.end(), shape_str.begin(),
                       [](const auto& val) { return std::to_string(val); });
        label += "\n" + string::join_strings(shape_str, ',');
      }
      dot.AddNode(
          node_id,
          {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
           Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
           Dot::Attr("fontcolor",
                     n->Var()->IsParameter() ? "#ffffff" : "#000000")},
          label, true);
    }
    node2dot[n] = node_id;
  }
  // Create edges
  for (const Node* n : graph.Nodes()) {
    const auto& src_id = node2dot.at(n);
    for (auto* out : n->outputs) {
      const auto& dest_id = node2dot.at(out);
      dot.AddEdge(src_id, dest_id, {});
    }
  }
  return dot.Build();
}

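// A compilation key is a serialized ProgramDesc, so decoding it yields a
// human-readable dump of the original program.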
std::string CinnCompiler::ReadableKey(
    const std::string& compilation_key) const {
  proto::ProgramDesc desc;
  desc.ParseFromString(compilation_key);
  return desc.DebugString();
}

void CinnCompiler::Clear() {
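  // Drop all registered graphs and cached compiled objects under a write lock,
  // then reset the compilation counter.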
  {
    phi::AutoWRLock guard{&rwlock_};
    graphs_.clear();
    cache_by_address_.clear();
    cache_by_struct_.clear();
    index2cache_.clear();
  }
  real_compiled_num_.store(0);
}

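// Validate a freshly compiled object against the original graph: every output
// variable must be produced by the compiled program, and every input variable
// CINN actually uses must match the tensor Paddle will feed it.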
void CinnCompiler::CheckCompiledValid(
    const ir::Graph& graph,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const CinnCompiledObject& compiled_obj) const {
  const auto& input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
  const auto& output_var_names =
      graph.Get<std::vector<std::string>>(kOutputVars);
  auto* launch_context = compiled_obj.launch_context.get();
  // 1. check that all of the output variables will be assigned by the
  // compiled program
  for (auto&& var_name : output_var_names) {
    PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name), true,
                      platform::errors::PreconditionNotMet(
                          "Variable(%s) not applied in CINN", var_name));
  }
  // 2. check that all of the used input variables were correctly deduced by
  // CINN.
  for (const auto& var_name : input_var_names) {
    // some input variables are not used by CINN because they were eliminated
    // by its optimization passes or because some of its operators need fewer
    // inputs
    if (!launch_context->IsVariableUsed(var_name)) {
      VLOG(4) << "Input variable " << var_name << " is not used by CINN";
      continue;
    }
    launch_context->CheckTensorEquivalent(var_name,
                                          *input_tensors.at(var_name));
  }
}

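// Lower a Paddle subgraph to a CINN frontend program, run frontend and
// graph-level optimization passes, optionally auto-tune the schedules, and
// build the final runtime program.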
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
    const ir::Graph& graph,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const Target& target, std::int64_t compiled_num, void* stream) const {
  CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
  auto frontend_program = symbol();
  auto fetch_ids = symbol.GetFetchIds();
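  // Frontend passes: decompose composite ops, drop redundant identity ops,
  // fold transposes, and rewrite eligible matmuls as GEMM calls.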
  ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"});
  ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity");
  ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding");
  ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"});

  auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
      frontend_program, target);
  VLOG(1) << "-- The " << compiled_num << "-th compilation ("
          << target.arch_str() << "), and its related graph:\n"
          << cinn_graph->Visualize();
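  // Fuse adjacent operators into single kernels, then build the scope that
  // holds the compiled program's tensors.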
  ApplyPass(cinn_graph.get(), "OpFusion");
  auto scope = BuildScope(target, cinn_graph);

  VLOG(4) << "All fetch var ids in CINN: "
          << string::join_strings(fetch_ids, ',');

  auto graph_compiler =
      std::make_unique<GraphCompiler>(target, scope, cinn_graph);
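  // Paddle manages the variable buffers itself, so CINN is told not to
  // instantiate them; when launching through ParallelExecutor is disabled,
  // explicit buffer-handle instructions are inserted instead.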
  GraphCompiler::CompileOptions options;
  options.with_instantiate_variables = false;
  if (!FLAGS_enable_pe_launch_cinn) {
    options.with_buffer_handle_instruction_inserted = true;
  }
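  // Optional auto-schedule tuning; num_measure_trials = 0 means no on-device
  // measurement trials are run during tuning.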
  std::unique_ptr<AutoTuner> auto_tuner;
  if (FLAGS_enable_cinn_auto_tune) {
    VLOG(4) << "Compile with auto-tune";
    auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
    auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
    ::cinn::auto_schedule::TuningOptions tuning_options;
    tuning_options.num_measure_trials = 0;
    auto tuning_result = auto_tuner->Tune(tuning_options);
    options.Apply(tuning_result);
  }
  auto compiled_res =
      graph_compiler->Build(options, std::move(fetch_ids), stream);
  auto compiled_obj = std::make_unique<CinnCompiledObject>();
  *compiled_obj = {std::move(graph_compiler), std::move(auto_tuner),
                   std::move(compiled_res.runtime_program), scope,
                   symbol.var_model_to_program_map()};
  compiled_obj->cached_index = compiled_num;
  compiled_obj->launch_context =
      std::make_unique<operators::details::CinnLaunchContext>(graph,
                                                              *compiled_obj);
  CheckCompiledValid(graph, input_tensors, *compiled_obj);
  return compiled_obj;
}

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle