// collect_ir_nodes.cc
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/collect_ir_nodes.h"

#include <glog/logging.h>

#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn {
namespace ir {

namespace {

/**
 * IR visitor that walks an expression and reports every node accepted by a
 * caller-supplied predicate.
 *
 * - `teller` decides whether a node is a match.
 * - `handler` is invoked on each matching node.
 * - When `uniq_target_` is true the walk stops after the first match
 *   (`find_target_` is latched and every later Visit returns immediately).
 *
 * Nodes are de-duplicated by the raw pointer returned from `Expr::get()`, so
 * a node reachable through several paths is tested and traversed only once.
 */
struct IrNodesCollector : public IRVisitor {
  using teller_t = std::function<bool(const Expr*)>;
  using handler_t = std::function<void(const Expr*)>;

  teller_t teller;
  handler_t handler;
  bool uniq_target_;
  bool find_target_{false};

  /**
   * @param teller      predicate selecting the nodes of interest.
   * @param handler     callback run on every selected node.
   * @param uniq_target stop the traversal after the first selected node.
   *
   * The parameters are rvalue references, but inside the constructor they are
   * named (hence lvalues); std::move is required to actually steal the
   * callables instead of copying them.
   */
  IrNodesCollector(teller_t&& teller, handler_t&& handler, bool uniq_target)
      : teller(std::move(teller)),
        handler(std::move(handler)),
        uniq_target_(uniq_target) {}

  /**
   * Entry point of the traversal: tests `expr` with the teller, reports it if
   * accepted, then dispatches on the node type to visit its fields.
   */
  void Visit(const Expr* expr) override {
    // Nothing to do for undefined expressions or once the unique target
    // has already been found.
    if (!expr->defined() || find_target_) return;
    // Skip nodes that were already processed.
    if (visited_.count(expr->get())) return;

    if (teller(expr)) {
      handler(expr);
      if (uniq_target_) {
        // First match is enough: latch the flag so the whole walk unwinds.
        find_target_ = true;
        return;
      }
    }
    visited_.insert(expr->get());

    // Dispatch to the typed Visit overload matching the node type.
    switch (expr->node_type()) {
#define __(op__)                 \
  case ir::IrNodeTy::op__:       \
    Visit(expr->As<ir::op__>()); \
    break;

      NODETY_FORALL(__)

#undef __
      default:
        LOG(FATAL) << "not supported NodeTy";
    }
  }

  // One typed overload per IR node type: recurse into every defined
  // expression field of the node.
#define __m(t__)                       \
  void Visit(const t__* x) override {  \
    for (auto* n : x->expr_fields()) { \
      if (n->defined()) {              \
        Visit(n);                      \
      }                                \
    }                                  \
  }

  NODETY_FORALL(__m)
#undef __m

  // Raw node pointers that have already been tested and traversed.
  std::set<void*> visited_;
};

struct IrNodesWithoutTensorCollector : public IrNodesCollector {
81
  using teller_t = std::function<bool(const Expr*)>;
82
  using handler_t = std::function<void(const Expr*)>;
83 84 85
  IrNodesWithoutTensorCollector(teller_t teller,
                                handler_t handler,
                                bool uniq_target)
86 87 88 89 90 91 92 93 94 95 96 97
      : IrNodesCollector(std::move(teller), std::move(handler), uniq_target) {}

  void Visit(const _Tensor_* expr) override {
    for (auto& e : expr->shape) {
      IrNodesCollector::Visit(&e);
    }
  }
  void Visit(const Expr* expr) override { IrNodesCollector::Visit(expr); }
};

}  // namespace

98 99 100
/**
 * Gather the nodes inside `expr` that satisfy `teller`.
 *
 * @param expr        root expression to walk.
 * @param teller      predicate; a node is collected when it returns true.
 * @param uniq_target when true, stop after the first collected node.
 * @return the set of collected expressions.
 */
std::set<Expr> CollectIRNodes(Expr expr,
                              std::function<bool(const Expr*)>&& teller,
                              bool uniq_target) {
  std::set<Expr> matched;

  // Every node accepted by the teller is copied into the result set.
  IrNodesCollector::handler_t on_match = [&matched](const Expr* node) {
    matched.insert(*node);
  };

  IrNodesCollector walker(std::move(teller), std::move(on_match), uniq_target);
  walker.Visit(&expr);
  return matched;
}

111 112
/**
 * Gather the nodes inside `expr` that satisfy `teller`, keeping them in the
 * order in which the traversal reports them.
 *
 * Uses IrNodesWithoutTensorCollector, so tensor nodes only contribute their
 * shape expressions to the walk, and never stops early (uniq_target = false).
 *
 * @param expr   root expression to walk.
 * @param teller predicate; a node is collected when it returns true.
 * @return the collected expressions, in report order.
 */
std::vector<Expr> CollectIRNodesInOrder(
    Expr expr, std::function<bool(const Expr*)>&& teller) {
  std::vector<Expr> ordered;

  // Append each accepted node as soon as the walker reports it.
  IrNodesWithoutTensorCollector::handler_t on_match =
      [&ordered](const Expr* node) { ordered.push_back(*node); };

  IrNodesWithoutTensorCollector walker(
      std::move(teller), std::move(on_match), /*uniq_target=*/false);
  walker.Visit(&expr);
  return ordered;
}

123 124
/**
 * Same contract as CollectIRNodes, but the walk is performed with
 * IrNodesWithoutTensorCollector: for tensor nodes only the shape expressions
 * are traversed.
 *
 * @param expr        root expression to walk.
 * @param teller      predicate; a node is collected when it returns true.
 * @param uniq_target when true, stop after the first collected node.
 * @return the set of collected expressions.
 */
std::set<Expr> CollectIRNodesWithoutTensor(
    Expr expr, std::function<bool(const Expr*)>&& teller, bool uniq_target) {
  std::set<Expr> matched;

  // Every node accepted by the teller is copied into the result set.
  IrNodesWithoutTensorCollector::handler_t on_match =
      [&matched](const Expr* node) { matched.insert(*node); };

  IrNodesWithoutTensorCollector walker(
      std::move(teller), std::move(on_match), uniq_target);
  walker.Visit(&expr);
  return matched;
}

135 136
/**
 * Build a name -> tensor-expression map for the tensors found inside `x`.
 *
 * @param x            root expression to walk.
 * @param extra_teller additional filter applied to every tensor node; only
 *                     tensors for which it returns true are kept.
 * @return map keyed by tensor name. If several collected tensors share a
 *         name, the one iterated last in the collected set wins.
 */
std::map<std::string, Expr> CollectTensorMap(
    Expr x, std::function<bool(const Expr*)>&& extra_teller) {
  // A node qualifies when it is a tensor AND passes the caller's filter; the
  // filter is not consulted for non-tensor nodes.
  auto tensor_exprs =
      CollectIRNodes(x, [&extra_teller](const Expr* node) -> bool {
        if (!node->as_tensor()) return false;
        return extra_teller(node);
      });

  std::map<std::string, Expr> by_name;
  for (auto& tensor_expr : tensor_exprs) {
    by_name[tensor_expr.as_tensor()->name] = tensor_expr;
  }
  return by_name;
}

148 149
/**
 * Collect the tensors that are read through `Load` nodes inside `x`.
 *
 * @param x      root expression to walk; an undefined expression yields an
 *               empty set.
 * @param teller filter applied to the tensor of each Load; only accepted
 *               tensors are returned.
 * @return the set of accepted tensor expressions.
 */
std::set<Expr> CollectLoadTensors(Expr x,
                                  std::function<bool(const Expr*)>&& teller) {
  // Walker built on IRMutator that only customises the handling of Load.
  struct LoadTensorFinder : public ir::IRMutator<const Expr*> {
    std::function<bool(const Expr*)> accept;
    std::set<Expr> found;

    explicit LoadTensorFinder(std::function<bool(const Expr*)>&& accept)
        : accept(std::move(accept)) {}

    // Start the walk at `root`.
    void operator()(const Expr* root) {
      ir::IRMutator<const Expr*>::Visit(root, root);
    }

    // Record the loaded tensor when the filter accepts it. This override does
    // not call the base implementation, so the walk does not continue below
    // the Load node from here.
    void Visit(const Load* op, const Expr* expr) override {
      if (!accept(&op->tensor)) return;
      found.insert(op->tensor);
    }
  };

  if (!x.defined()) {
    return {};
  }

  LoadTensorFinder finder(std::move(teller));
  finder(&x);
  return finder.found;
}

173 174
/**
 * Collect the tensors that are written through `Store` nodes inside `x`.
 *
 * @param x      root expression to walk; an undefined expression yields an
 *               empty set (same guard as CollectLoadTensors).
 * @param teller filter applied to the tensor of each Store; only accepted
 *               tensors are returned.
 * @return the set of accepted tensor expressions.
 */
std::set<Expr> CollectStoreTensors(Expr x,
                                   std::function<bool(const Expr*)>&& teller) {
  // Keep this function consistent with CollectLoadTensors: do not hand an
  // undefined expression to the IR walker.
  if (!x.defined()) return std::set<Expr>();

  // Walker built on IRMutator that only customises the handling of Store.
  struct Mutator : public ir::IRMutator<const Expr*> {
    std::function<bool(const Expr*)> teller;
    std::set<Expr> exprs;

    explicit Mutator(std::function<bool(const Expr*)>&& teller)
        : teller(std::move(teller)) {}

    // Start the walk at `expr`.
    void operator()(const Expr* expr) {
      ir::IRMutator<const Expr*>::Visit(expr, expr);
    }

    // Record the stored-to tensor when the filter accepts it. This override
    // does not call the base implementation, so the walk does not continue
    // below the Store node from here.
    void Visit(const Store* op, const Expr* expr) override {
      if (teller(&op->tensor)) {
        exprs.insert(op->tensor);
      }
    }
  };

  Mutator mutator(std::move(teller));
  mutator(&x);
  return mutator.exprs;
}

197 198
/**
 * Collect every tensor referenced inside `x`, i.e. the union of the tensors
 * read through Load nodes and the tensors written through Store nodes.
 *
 * @param x      root expression to walk; an undefined expression yields an
 *               empty set.
 * @param teller filter applied to each candidate tensor.
 * @return the set of accepted tensor expressions.
 */
std::set<Expr> CollectReferencedTensors(
    Expr x, const std::function<bool(const Expr*)>& teller) {
  // Undefined input has no tensors; return before invoking either collector.
  if (!x.defined()) return std::set<Expr>();

  // Both collectors take the teller by rvalue reference, so each one gets its
  // own copy of the caller's (const) predicate.
  auto load_teller = teller;
  auto store_teller = teller;

  auto tensors = CollectLoadTensors(x, std::move(load_teller));
  // Previously this second pass also called CollectLoadTensors, which made it
  // a no-op duplicate and silently dropped tensors that are only stored to.
  auto stored = CollectStoreTensors(x, std::move(store_teller));

  tensors.insert(stored.begin(), stored.end());
  return tensors;
}

}  // namespace ir
}  // namespace cinn