/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {
// Forward declarations. This header only refers to these types through
// pointers, references and std::shared_ptr, so their full definitions are not
// required here.
class MemOptVarInfo;
class Node;
}  // namespace ir

namespace paddle2cinn {

// Type name of the launch op; a compiled cluster is made up of ops of this
// type.
constexpr char kCinnLaunchOp[] = "cinn_launch";

// NOTE(review): the next four names presumably key the per-cluster variable
// lists (inputs, no-need-buffer feeds, internal vars, outputs) -- confirm
// against build_cinn_pass.cc.
constexpr char kInputVars[] = "InputVars";
constexpr char kNoNeedBufferFeeds[] = "NoNeedBufferFeeds";
constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars";

// NOTE(review): presumably the key under which memory-optimization variable
// info taken from the main graph is stored -- confirm.
constexpr char kMemOptVarInfoFromMainGraph[] =
    "mem_opt_var_info_from_main_graph";

// NOTE(review): presumably the key holding variable names that garbage
// collection must skip -- confirm.
constexpr char kSkipGcVarNames[] = "skip_gc_vars";
using Name2VarInfoMap =
    std::unordered_map<std::string,
                       std::shared_ptr<framework::ir::MemOptVarInfo>>;
46 47
using GraphNodeSet = std::unordered_set<ir::Node*>;

// OpTransInfo contains informations used to detect subgraphs
// supported by the CINN compiler.
class OpTransInfo {
  using DyOpCondT =
      std::unordered_map<std::string, std::function<bool(const ir::Node&)>>;
  using DeParamCondT =
      std::unordered_map<std::string, std::unordered_set<std::string>>;

 public:
  OpTransInfo();

  const DyOpCondT& dynamic_op_cond() const { return dynamic_op_cond_; }

  const DeParamCondT& deny_param_cond() const { return deny_param_cond_; }

  const std::unordered_set<std::string>& default_deny_ops() const {
    return default_deny_ops_;
  }
66 67 68

  std::unordered_set<std::string> GetDenyVarNames(
      const GraphNodeSet& cluster) const;
69 70 71 72 73 74 75 76 77 78

  static bool IsInplaceOp(const OpDesc& op_desc);

 private:
  DyOpCondT dynamic_op_cond_;

  DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
                                {"batch_norm_grad", {"ReserveSpace"}}};

  std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
79
};

// BuildCinnPass is a pass that does the following:
//
// a) Detect the subgraphs that can be compiled by the CINN compiler. We call a
// detected subgraph a cluster; a cluster consists of several op nodes.
//
// b) Call the CINN compiler to compile each original cluster and get the
// compiled cluster, which consists of several kCinnLaunchOp ops.
//
// c) Replace each original cluster with the corresponding compiled cluster in
// the original graph.
//
// This pass handles the following questions with care:
//
// a) How do we determine whether two op nodes can be grouped into a cluster?
// Firstly, both op nodes must be supported by the CINN compiler.
// Secondly, there must be a direct path between the two op nodes through a
// var node.
// Thirdly, there must be no extra path between the two op nodes through
// unsupported op nodes.
// Lastly, grouping is transitive: if op nodes a and b can be grouped into a
// cluster, and op nodes b and c can be grouped into a cluster, then a and c
// can also be grouped into a cluster.
// The implementation of cluster detection is encapsulated in the
// SubGraphDetector class.
//
// b) How do we handle the links between the var nodes in the global graph and
// the op nodes in a cluster?
// We first add links between the var nodes in the global graph and the op
// nodes in the compiled cluster. Then we remove the now-useless links between
// the var nodes in the global graph and the op nodes in the original cluster.
class BuildCinnPass
    : public framework::ir::Pass {
 protected:
  // Pass hook that overrides framework::ir::Pass::ApplyImpl and operates on
  // `graph`. Defined out of line.
  void ApplyImpl(
      framework::ir::Graph* graph) const override;
};

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle