ir_printer.cc 7.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/core/value.h"

namespace ir {

namespace {
// Line terminator emitted after each printed operation line.
constexpr char newline[]{"\n"};
}  // namespace

35
class BasicIRPrinter {
36
 public:
37
  explicit BasicIRPrinter(std::ostream& os) : os(os) {}
38

39
  void PrintType(Type type) {
K
kangguangli 已提交
40
    if (!type) {
41
      os << "<<NULL TYPE>>";
K
kangguangli 已提交
42 43 44
      return;
    }

45
    if (type.isa<Float16Type>()) {
46
      os << "f16";
47
    } else if (type.isa<Float32Type>()) {
48
      os << "f32";
49
    } else if (type.isa<Float64Type>()) {
50
      os << "f64";
51
    } else if (type.isa<Int16Type>()) {
52
      os << "i16";
53
    } else if (type.isa<Int32Type>()) {
54
      os << "i32";
55
    } else if (type.isa<Int64Type>()) {
56
      os << "i64";
57 58 59
    } else if (type.isa<VectorType>()) {
      os << "vec[";
      auto inner_types = type.dyn_cast<VectorType>().data();
60 61 62
      PrintInterleave(
          inner_types.begin(),
          inner_types.end(),
63 64 65
          [this](Type v) { this->PrintType(v); },
          [this]() { this->os << ","; });
      os << "]";
66 67 68 69 70 71
    } else {
      auto& dialect = type.dialect();
      dialect.PrintType(type, os);
    }
  }

72 73 74 75 76
  void PrintAttribute(const Attribute& attr) {
    if (!attr) {
      os << "<#AttrNull>";
      return;
    }
77

78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
    if (auto s = attr.dyn_cast<StrAttribute>()) {
      os << s.data();
    } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
      os << b.data();
    } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
      os << f.data();
    } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
      os << d.data();
    } else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
      os << i.data();
    } else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
      os << i.data();
    } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
      const auto& vec = arr.data();
      os << "array[";
      PrintInterleave(
          vec.begin(),
          vec.end(),
          [this](Attribute v) { this->PrintAttribute(v); },
          [this]() { this->os << ","; });
      os << "]";
    } else {
      auto& dialect = attr.dialect();
      dialect.PrintAttribute(attr, os);
    }
  }

 public:
106 107 108
  std::ostream& os;
};

109
class IRPrinter : public BasicIRPrinter {
110
 public:
111 112 113 114 115
  explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {}

  /// @brief print program
  /// @param program
  /// @example
116
  void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
117 118 119 120

  /// @brief print operation
  /// @param op
  /// @example
121
  void PrintOperation(Operation* op) {
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    for (size_t i = 0; i < op->num_regions(); ++i) {
      auto& region = op->GetRegion(i);
      for (auto it = region.begin(); it != region.end(); ++it) {
        auto* block = *it;
        os << "{\n";
        for (auto it = block->begin(); it != block->end(); ++it) {
          auto* op = *it;
          // TODO(lyk): add API to get opresults directly
          PrintOpResult(op);
          os << " =";

          os << " \"" << op->name() << "\"";

          // TODO(lyk): add API to get operands directly
          PrintOpOperands(op);

138
          PrintAttributeMap(op);
139 140 141 142 143 144 145 146 147 148 149 150 151
          os << " :";

          // PrintOpSingature
          PrintOperandsType(op);
          os << " -> ";

          // TODO(lyk): add API to get opresults directly
          PrintOpReturnType(op);

          os << newline;
        }
        os << "}\n";
      }
152 153 154
    }
  }

155
 private:
156
  void PrintValue(Value v) {
K
kangguangli 已提交
157 158 159 160
    if (!v) {
      os << "<<NULL VALUE>>";
      return;
    }
161
    const void* key = static_cast<const void*>(v.impl());
162 163
    auto ret = aliases_.find(key);
    if (ret != aliases_.end()) {
164 165 166 167
      os << ret->second;
      return;
    }

168 169 170
    std::string new_name = "%" + std::to_string(cur_var_number_);
    cur_var_number_++;
    aliases_[key] = new_name;
171 172 173
    os << new_name;
  }

174
  void PrintOpResult(Operation* op) {
175 176
    os << " (";
    auto num_op_result = op->num_results();
177
    std::vector<OpResult> op_results;
178 179 180 181 182 183 184
    op_results.reserve(num_op_result);
    for (size_t idx = 0; idx < num_op_result; idx++) {
      op_results.push_back(op->GetResultByIndex(idx));
    }
    PrintInterleave(
        op_results.begin(),
        op_results.end(),
185
        [this](Value v) { this->PrintValue(v); },
186 187
        [this]() { this->os << ", "; });
    os << ")";
188 189
  }

190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
  void PrintAttributeMap(Operation* op) {
    os << " {";

    PrintInterleave(
        op->attributes().begin(),
        op->attributes().end(),
        [this](std::pair<std::string, Attribute> it) {
          this->os << it.first;
          this->os << ":";
          this->PrintAttribute(it.second);
        },
        [this]() { this->os << ","; });

    os << "}";
  }

  void PrintOpOperands(Operation* op) {
207 208
    os << " (";
    auto num_op_operands = op->num_operands();
209
    std::vector<Value> op_operands;
210 211
    op_operands.reserve(num_op_operands);
    for (size_t idx = 0; idx < num_op_operands; idx++) {
212
      op_operands.push_back(op->GetOperandByIndex(idx).source());
213 214 215 216
    }
    PrintInterleave(
        op_operands.begin(),
        op_operands.end(),
217
        [this](Value v) { this->PrintValue(v); },
218 219
        [this]() { this->os << ", "; });
    os << ")";
220 221
  }

222
  void PrintOperandsType(Operation* op) {
223
    auto num_op_operands = op->num_operands();
224
    std::vector<Type> op_operand_types;
225 226
    op_operand_types.reserve(num_op_operands);
    for (size_t idx = 0; idx < num_op_operands; idx++) {
K
kangguangli 已提交
227 228 229 230
      auto op_operand = op->GetOperandByIndex(idx);
      if (op_operand) {
        op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
      } else {
231
        op_operand_types.push_back(Type(nullptr));
K
kangguangli 已提交
232
      }
233
    }
234
    os << " (";
235 236 237
    PrintInterleave(
        op_operand_types.begin(),
        op_operand_types.end(),
238
        [this](Type t) { this->PrintType(t); },
239 240
        [this]() { this->os << ", "; });
    os << ")";
241 242
  }

243
  void PrintOpReturnType(Operation* op) {
244
    auto num_op_result = op->num_results();
245
    std::vector<Type> op_result_types;
246 247
    op_result_types.reserve(num_op_result);
    for (size_t idx = 0; idx < num_op_result; idx++) {
K
kangguangli 已提交
248 249 250 251
      auto op_result = op->GetResultByIndex(idx);
      if (op_result) {
        op_result_types.push_back(op_result.type());
      } else {
252
        op_result_types.push_back(Type(nullptr));
K
kangguangli 已提交
253
      }
254 255 256 257
    }
    PrintInterleave(
        op_result_types.begin(),
        op_result_types.end(),
258
        [this](Type t) { this->PrintType(t); },
259
        [this]() { this->os << ", "; });
260 261 262
  }

 private:
263 264
  size_t cur_var_number_{0};
  std::unordered_map<const void*, std::string> aliases_;
265 266
};

Y
Yuanle Liu 已提交
267
void Program::Print(std::ostream& os) {
268 269 270 271 272 273 274 275 276
  IRPrinter printer(os);
  printer.PrintProgram(this);
}

/// Write this operation's textual IR to `os` using a fresh IRPrinter, so
/// value aliases start again from "%0" on every call.
void Operation::Print(std::ostream& os) { IRPrinter(os).PrintOperation(this); }

Y
Yuanle Liu 已提交
277
void Type::Print(std::ostream& os) const {
278 279
  BasicIRPrinter printer(os);
  printer.PrintType(*this);
280 281
}

282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
/// Write the textual form of this attribute to `os`.
void Attribute::Print(std::ostream& os) const {
  BasicIRPrinter(os).PrintAttribute(*this);
}

std::ostream& operator<<(std::ostream& os, Type type) {
  type.Print(os);
  return os;
}

std::ostream& operator<<(std::ostream& os, Attribute attr) {
  attr.Print(os);
  return os;
}

297
}  // namespace ir