ir_printer.cc 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>

20
#include "paddle/ir/core/block.h"
21 22 23
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
24
#include "paddle/ir/core/ir_printer.h"
25 26
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
27
#include "paddle/ir/core/utils.h"
28
#include "paddle/ir/core/value.h"
29 30 31 32

namespace ir {

namespace {

// Line terminator the IrPrinter writes after every operation of a block and
// between an operation and its nested regions.
constexpr const char newline[] = "\n";  // NOLINT

}  // namespace

36 37 38 39 40
// Writes `type` to `os`. Builtin scalar types print as short mnemonics
// ("f32", "i64", ...), vector types print as vec[...] with their element types,
// and every other type is delegated to the dialect that owns it.
void BasicIrPrinter::PrintType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Select the mnemonic for a builtin scalar type; the first matching check
  // wins, in the same order as the type list below.
  const char* mnemonic = nullptr;
  if (type.isa<BFloat16Type>()) {
    mnemonic = "bf16";
  } else if (type.isa<Float16Type>()) {
    mnemonic = "f16";
  } else if (type.isa<Float32Type>()) {
    mnemonic = "f32";
  } else if (type.isa<Float64Type>()) {
    mnemonic = "f64";
  } else if (type.isa<BoolType>()) {
    mnemonic = "b";
  } else if (type.isa<Int8Type>()) {
    mnemonic = "i8";
  } else if (type.isa<UInt8Type>()) {
    mnemonic = "u8";
  } else if (type.isa<Int16Type>()) {
    mnemonic = "i16";
  } else if (type.isa<Int32Type>()) {
    mnemonic = "i32";
  } else if (type.isa<Int64Type>()) {
    mnemonic = "i64";
  } else if (type.isa<IndexType>()) {
    mnemonic = "index";
  } else if (type.isa<Complex64Type>()) {
    mnemonic = "c64";
  } else if (type.isa<Complex128Type>()) {
    mnemonic = "c128";
  }
  if (mnemonic != nullptr) {
    os << mnemonic;
    return;
  }

  // Vector types recurse into their element types: vec[t0,t1,...].
  if (type.isa<VectorType>()) {
    os << "vec[";
    auto element_types = type.dyn_cast<VectorType>().data();
    PrintInterleave(
        element_types.begin(),
        element_types.end(),
        [this](Type element) { this->PrintType(element); },
        [this]() { this->os << ","; });
    os << "]";
    return;
  }

  // Not a builtin type: the owning dialect knows how to render it.
  type.dialect().PrintType(type, os);
}
82

X
xingmingyyj 已提交
83
// Writes `attr` to `os`. Builtin attributes stream their payload directly,
// array attributes recurse element by element, and any other attribute is
// handed to the dialect that owns it.
void BasicIrPrinter::PrintAttribute(Attribute attr) {
  if (!attr) {
    os << "<#AttrNull>";
    return;
  }

  if (auto str_attr = attr.dyn_cast<StrAttribute>()) {
    os << str_attr.AsString();
    return;
  }
  if (auto bool_attr = attr.dyn_cast<BoolAttribute>()) {
    os << bool_attr.data();
    return;
  }
  if (auto float_attr = attr.dyn_cast<FloatAttribute>()) {
    os << float_attr.data();
    return;
  }
  if (auto double_attr = attr.dyn_cast<DoubleAttribute>()) {
    os << double_attr.data();
    return;
  }
  if (auto int32_attr = attr.dyn_cast<Int32Attribute>()) {
    os << int32_attr.data();
    return;
  }
  if (auto int64_attr = attr.dyn_cast<Int64Attribute>()) {
    os << int64_attr.data();
    return;
  }
  if (auto ptr_attr = attr.dyn_cast<PointerAttribute>()) {
    os << ptr_attr.data();
    return;
  }
  if (auto array_attr = attr.dyn_cast<ArrayAttribute>()) {
    // Elements are printed recursively: array[e0,e1,...].
    const auto& elements = array_attr.AsVector();
    os << "array[";
    PrintInterleave(
        elements.begin(),
        elements.end(),
        [this](Attribute element) { this->PrintAttribute(element); },
        [this]() { this->os << ","; });
    os << "]";
    return;
  }
  if (auto type_attr = attr.dyn_cast<TypeAttribute>()) {
    os << type_attr.data();
    return;
  }

  // Not a builtin attribute: defer to the owning dialect's printer.
  attr.dialect().PrintAttribute(attr, os);
}
119

120
void IrPrinter::PrintProgram(const Program* program) {
121 122
  auto top_level_op = program->module_op();
  for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
123
    auto& region = top_level_op->region(i);
X
xingmingyyj 已提交
124
    PrintRegion(region);
125
  }
126
}
127

128
void IrPrinter::PrintOperation(const Operation* op) {
129 130 131 132 133 134 135 136
  if (auto* dialect = op->dialect()) {
    dialect->PrintOperation(op, *this);
    return;
  }

  PrintGeneralOperation(op);
}

137
void IrPrinter::PrintGeneralOperation(const Operation* op) {
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
  // TODO(lyk): add API to get opresults directly
  PrintOpResult(op);
  os << " =";

  os << " \"" << op->name() << "\"";

  // TODO(lyk): add API to get operands directly
  PrintOpOperands(op);

  PrintAttributeMap(op);
  os << " :";

  // PrintOpSingature
  PrintOperandsType(op);
  os << " -> ";

  // TODO(lyk): add API to get opresults directly
  PrintOpReturnType(op);
}

158
void IrPrinter::PrintFullOperation(const Operation* op) {
159 160 161 162 163
  PrintOperation(op);
  if (op->num_regions() > 0) {
    os << newline;
  }
  for (size_t i = 0; i < op->num_regions(); ++i) {
164
    auto& region = op->region(i);
165 166 167
    PrintRegion(region);
  }
}
168

169
void IrPrinter::PrintRegion(const Region& region) {
170
  for (auto block : region) {
171
    PrintBlock(block);
172
  }
173
}
174

175
void IrPrinter::PrintBlock(const Block* block) {
176
  os << "{\n";
177 178
  for (auto item : *block) {
    PrintOperation(item);
179
    os << newline;
180
  }
181 182
  os << "}\n";
}
183

184
// Prints the alias of `v`. Aliases are keyed by the value's impl pointer, so
// every occurrence of the same value in this printer's output prints the same
// name. The first time a value is seen it is assigned "%" followed by the
// current cur_var_number_, and the counter is advanced.
void IrPrinter::PrintValue(const Value& v) {
  if (!v) {
    os << "<<NULL VALUE>>";
    return;
  }

  const void* key = static_cast<const void*>(v.impl());
  auto known = aliases_.find(key);
  if (known != aliases_.end()) {
    os << known->second;
    return;
  }

  const std::string alias = "%" + std::to_string(cur_var_number_++);
  aliases_[key] = alias;
  os << alias;
}
201

202
void IrPrinter::PrintOpResult(const Operation* op) {
203 204 205 206 207
  os << " (";
  auto num_op_result = op->num_results();
  std::vector<OpResult> op_results;
  op_results.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
208
    op_results.push_back(op->result(idx));
209
  }
210 211 212 213 214 215 216
  PrintInterleave(
      op_results.begin(),
      op_results.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
217

218
void IrPrinter::PrintAttributeMap(const Operation* op) {
X
xingmingyyj 已提交
219 220 221
  AttributeMap attributes = op->attributes();
  std::map<std::string, Attribute, std::less<std::string>> order_attributes(
      attributes.begin(), attributes.end());
222 223 224
  os << " {";

  PrintInterleave(
X
xingmingyyj 已提交
225 226
      order_attributes.begin(),
      order_attributes.end(),
227 228 229 230 231 232 233 234 235 236
      [this](std::pair<std::string, Attribute> it) {
        this->os << it.first;
        this->os << ":";
        this->PrintAttribute(it.second);
      },
      [this]() { this->os << ","; });

  os << "}";
}

237
void IrPrinter::PrintOpOperands(const Operation* op) {
238 239 240 241 242
  os << " (";
  auto num_op_operands = op->num_operands();
  std::vector<Value> op_operands;
  op_operands.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
243
    op_operands.push_back(op->operand_source(idx));
244
  }
245 246 247 248 249 250 251
  PrintInterleave(
      op_operands.begin(),
      op_operands.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
252

253
void IrPrinter::PrintOperandsType(const Operation* op) {
254 255 256 257
  auto num_op_operands = op->num_operands();
  std::vector<Type> op_operand_types;
  op_operand_types.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
258
    auto op_operand = op->operand(idx);
259
    if (op_operand) {
260
      op_operand_types.push_back(op_operand.type());
261
    } else {
262
      op_operand_types.emplace_back();
263 264
    }
  }
265 266 267 268 269 270 271 272
  os << " (";
  PrintInterleave(
      op_operand_types.begin(),
      op_operand_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
  os << ")";
}
273

274
void IrPrinter::PrintOpReturnType(const Operation* op) {
275 276 277 278
  auto num_op_result = op->num_results();
  std::vector<Type> op_result_types;
  op_result_types.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
279
    auto op_result = op->result(idx);
280 281 282
    if (op_result) {
      op_result_types.push_back(op_result.type());
    } else {
283
      op_result_types.emplace_back(nullptr);
284 285
    }
  }
286 287 288 289 290 291
  PrintInterleave(
      op_result_types.begin(),
      op_result_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
}
292

293
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
294 295
  printer.PrintGeneralOperation(op);
}
296

297
void Program::Print(std::ostream& os) const {
298
  IrPrinter printer(os);
299 300 301
  printer.PrintProgram(this);
}

302
void Operation::Print(std::ostream& os) const {
303 304
  IrPrinter printer(os);
  printer.PrintFullOperation(this);
305 306
}

Y
Yuanle Liu 已提交
307
void Type::Print(std::ostream& os) const {
308
  BasicIrPrinter printer(os);
309
  printer.PrintType(*this);
310 311
}

312
void Attribute::Print(std::ostream& os) const {
313
  BasicIrPrinter printer(os);
314 315 316 317 318 319 320 321 322 323 324 325 326
  printer.PrintAttribute(*this);
}

// Stream insertion for Type: forwards to Type::Print and returns `os` so the
// call can be chained.
std::ostream& operator<<(std::ostream& os, Type type) {
  type.Print(os);
  return os;
}

// Stream insertion for Attribute: forwards to Attribute::Print and returns
// `os` so the call can be chained.
std::ostream& operator<<(std::ostream& os, Attribute attr) {
  attr.Print(os);
  return os;
}

327 328 329 330 331
// Stream insertion for Program: forwards to Program::Print, which prints the
// regions of the program's module op, and returns `os` for chaining.
std::ostream& operator<<(std::ostream& os, const Program& prog) {
  prog.Print(os);
  return os;
}

332
}  // namespace ir