ir_printer.cc 8.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <ostream>
#include <string>
#include <unordered_map>

20
#include "paddle/ir/core/block.h"
21 22 23
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
24
#include "paddle/ir/core/ir_printer.h"
25 26
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
27
#include "paddle/ir/core/utils.h"
28
#include "paddle/ir/core/value.h"
29 30 31 32 33 34 35

namespace ir {

namespace {
// Line terminator streamed by the IrPrinter methods in this file after each
// printed operation and before an operation's regions.
constexpr char newline[] = "\n";
}  // namespace

36 37 38 39 40
// Writes the textual form of `type` to `os`.
//
// - A null type is written as "<<NULL TYPE>>".
// - Built-in scalar types are written as a short mnemonic (e.g. "f32").
// - A VectorType is written as "vec[" + comma-separated element types + "]".
// - Anything else is handed to the PrintType method of `type.dialect()`.
void BasicIrPrinter::PrintType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Mnemonic of a built-in scalar type; stays null for every other type.
  const char* mnemonic = nullptr;
  if (type.isa<BFloat16Type>()) {
    mnemonic = "bf16";
  } else if (type.isa<Float16Type>()) {
    mnemonic = "f16";
  } else if (type.isa<Float32Type>()) {
    mnemonic = "f32";
  } else if (type.isa<Float64Type>()) {
    mnemonic = "f64";
  } else if (type.isa<BoolType>()) {
    mnemonic = "b";
  } else if (type.isa<Int8Type>()) {
    mnemonic = "i8";
  } else if (type.isa<UInt8Type>()) {
    mnemonic = "u8";
  } else if (type.isa<Int16Type>()) {
    mnemonic = "i16";
  } else if (type.isa<Int32Type>()) {
    mnemonic = "i32";
  } else if (type.isa<Int64Type>()) {
    mnemonic = "i64";
  } else if (type.isa<Complex64Type>()) {
    mnemonic = "c64";
  } else if (type.isa<Complex128Type>()) {
    mnemonic = "c128";
  }

  if (mnemonic != nullptr) {
    os << mnemonic;
    return;
  }

  if (type.isa<VectorType>()) {
    // Element types are printed recursively, separated by ",".
    auto element_types = type.dyn_cast<VectorType>().data();
    os << "vec[";
    PrintInterleave(
        element_types.begin(),
        element_types.end(),
        [this](Type element) { this->PrintType(element); },
        [this]() { this->os << ","; });
    os << "]";
    return;
  }

  // Not a built-in type: let the type's dialect print it.
  auto& owning_dialect = type.dialect();
  owning_dialect.PrintType(type, os);
}
80

81 82 83 84 85
void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
  if (!attr) {
    os << "<#AttrNull>";
    return;
  }
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
  if (auto s = attr.dyn_cast<StrAttribute>()) {
    os << s.data();
  } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
    os << b.data();
  } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
    os << f.data();
  } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
    os << d.data();
  } else if (auto i = attr.dyn_cast<Int32Attribute>()) {
    os << i.data();
  } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
    os << i.data();
  } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
    os << p.data();
  } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
    const auto& vec = arr.data();
    os << "array[";
    PrintInterleave(
        vec.begin(),
        vec.end(),
        [this](Attribute v) { this->PrintAttribute(v); },
        [this]() { this->os << ","; });
    os << "]";
W
Wilber 已提交
110 111
  } else if (auto type = attr.dyn_cast<TypeAttribute>()) {
    os << type.data();
112 113 114
  } else {
    auto& dialect = attr.dialect();
    dialect.PrintAttribute(attr, os);
115
  }
116
}
117

118 119 120
// Prints every region of the program's top-level module operation.
//
// Each region is rendered by PrintRegion, which emits every block as "{\n",
// one line per operation, then "}\n". Delegating keeps the block format
// defined in a single place and avoids the nested loops that previously
// declared two iterators both named `it`.
//
// @param program Program to print; dereferenced without a null check.
void IrPrinter::PrintProgram(Program* program) {
  auto top_level_op = program->module_op();
  for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
    PrintRegion(top_level_op->region(i));
  }
}
133

134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
// Prints a single operation (without its regions).
//
// If the operation reports a dialect, that dialect's PrintOperation is used;
// otherwise the generic format from PrintGeneralOperation is emitted.
void IrPrinter::PrintOperation(Operation* op) {
  auto* owning_dialect = op->dialect();
  if (!owning_dialect) {
    PrintGeneralOperation(op);
    return;
  }
  owning_dialect->PrintOperation(op, *this);
}

// Prints `op` in the generic form:
//
//    (results) = "name" (operands) {attributes} : (operand types) -> result types
//
// Each piece is produced by the corresponding Print* helper of this class.
void IrPrinter::PrintGeneralOperation(Operation* op) {
  // Results on the left-hand side.
  // TODO(lyk): add API to get opresults directly
  PrintOpResult(op);

  // Quoted operation name.
  os << " =" << " \"" << op->name() << "\"";

  // Operands followed by the attribute dictionary.
  // TODO(lyk): add API to get operands directly
  PrintOpOperands(op);
  PrintAttributeMap(op);

  // Operation signature: operand types -> result types.
  os << " :";
  PrintOperandsType(op);
  os << " -> ";
  // TODO(lyk): add API to get opresults directly
  PrintOpReturnType(op);
}

// Prints `op` via PrintOperation and then, if it owns any regions, a newline
// followed by each of those regions in index order.
void IrPrinter::PrintFullOperation(Operation* op) {
  PrintOperation(op);

  const size_t region_count = op->num_regions();
  if (region_count == 0) {
    return;
  }

  os << newline;
  for (size_t idx = 0; idx != region_count; ++idx) {
    PrintRegion(op->region(idx));
  }
}
174

175 176 177 178
void IrPrinter::PrintRegion(const Region& region) {
  for (auto it = region.begin(); it != region.end(); ++it) {
    auto* block = *it;
    PrintBlock(block);
179
  }
180
}
181

182 183 184 185 186
// Prints `block` as "{\n", then each contained operation followed by a
// newline, then "}\n".
void IrPrinter::PrintBlock(Block* block) {
  os << "{\n";
  for (Operation* op : *block) {
    PrintOperation(op);
    os << newline;
  }
  os << "}\n";
}
190

191 192 193 194 195 196 197 198 199 200 201
// Prints the printer-local name of `v`.
//
// A null value is written as "<<NULL VALUE>>". Otherwise the value's impl
// pointer is used as the key into `aliases_`: a known key prints its stored
// name, and an unknown key is assigned the next "%<n>" name (n taken from
// `cur_var_number_`, which is then incremented) before being printed.
void IrPrinter::PrintValue(Value v) {
  if (!v) {
    os << "<<NULL VALUE>>";
    return;
  }

  const void* key = static_cast<const void*>(v.impl());
  auto found = aliases_.find(key);
  if (found == aliases_.end()) {
    // First time this value is seen: mint a fresh name and remember it.
    std::string fresh_name = "%" + std::to_string(cur_var_number_++);
    aliases_[key] = fresh_name;
    os << fresh_name;
    return;
  }

  os << found->second;
}
208

209 210 211 212 213 214
// Prints the results of `op` as " (" + names separated by ", " + ")".
// Each result name comes from PrintValue.
void IrPrinter::PrintOpResult(Operation* op) {
  os << " (";

  // Gather the results into a contiguous range for PrintInterleave.
  const auto result_count = op->num_results();
  std::vector<OpResult> results;
  results.reserve(result_count);
  for (size_t i = 0; i < result_count; ++i) {
    results.push_back(op->result(i));
  }

  auto print_result = [this](Value v) { this->PrintValue(v); };
  auto print_separator = [this]() { this->os << ", "; };
  PrintInterleave(
      results.begin(), results.end(), print_result, print_separator);

  os << ")";
}
224

225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
// Prints the attributes of `op` as " {" + "name:value" entries separated by
// "," + "}". Values are rendered with PrintAttribute; entries appear in the
// iteration order of the container returned by op->attributes().
void IrPrinter::PrintAttributeMap(Operation* op) {
  os << " {";

  // Bind the container once so that begin() and end() are guaranteed to
  // refer to the same object, whatever attributes() returns.
  const auto& attributes = op->attributes();
  PrintInterleave(
      attributes.begin(),
      attributes.end(),
      // Taken by const reference so the callback does not require its own
      // by-value copy of each entry.
      [this](const std::pair<std::string, Attribute>& entry) {
        this->os << entry.first;
        this->os << ":";
        this->PrintAttribute(entry.second);
      },
      [this]() { this->os << ","; });

  os << "}";
}

// Prints the operands of `op` as " (" + names separated by ", " + ")".
// Each operand name comes from PrintValue.
void IrPrinter::PrintOpOperands(Operation* op) {
  os << " (";

  // Gather the operands into a contiguous range for PrintInterleave.
  const auto operand_count = op->num_operands();
  std::vector<Value> operands;
  operands.reserve(operand_count);
  for (size_t i = 0; i < operand_count; ++i) {
    operands.push_back(op->operand(i));
  }

  auto print_operand = [this](Value v) { this->PrintValue(v); };
  auto print_separator = [this]() { this->os << ", "; };
  PrintInterleave(
      operands.begin(), operands.end(), print_operand, print_separator);

  os << ")";
}
256

257 258 259 260 261
// Prints the operand types of `op` as " (" + types separated by ", " + ")".
// An operand that evaluates to false contributes a default-constructed Type,
// which PrintType then renders.
void IrPrinter::PrintOperandsType(Operation* op) {
  const auto operand_count = op->num_operands();
  std::vector<Type> operand_types;
  operand_types.reserve(operand_count);
  for (size_t i = 0; i < operand_count; ++i) {
    auto operand = op->op_operand(i);
    operand_types.push_back(operand ? operand.type() : Type());
  }

  auto print_type = [this](Type t) { this->PrintType(t); };
  auto print_separator = [this]() { this->os << ", "; };

  os << " (";
  PrintInterleave(operand_types.begin(),
                  operand_types.end(),
                  print_type,
                  print_separator);
  os << ")";
}
277

278 279 280 281 282
// Prints the result types of `op`, separated by ", " and without surrounding
// parentheses. A result that evaluates to false contributes Type(nullptr),
// which PrintType then renders.
void IrPrinter::PrintOpReturnType(Operation* op) {
  const auto result_count = op->num_results();
  std::vector<Type> result_types;
  result_types.reserve(result_count);
  for (size_t i = 0; i < result_count; ++i) {
    auto result = op->result(i);
    result_types.push_back(result ? result.type() : Type(nullptr));
  }

  auto print_type = [this](Type t) { this->PrintType(t); };
  auto print_separator = [this]() { this->os << ", "; };
  PrintInterleave(
      result_types.begin(), result_types.end(), print_type, print_separator);
}
296

297 298 299
// Prints `op` through `printer` using the generic operation format
// (IrPrinter::PrintGeneralOperation). IrPrinter::PrintOperation calls this
// for any operation that reports a dialect.
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
  printer.PrintGeneralOperation(op);
}
300

Y
Yuanle Liu 已提交
301
void Program::Print(std::ostream& os) {
302
  IrPrinter printer(os);
303 304 305 306
  printer.PrintProgram(this);
}

void Operation::Print(std::ostream& os) {
307 308
  IrPrinter printer(os);
  printer.PrintFullOperation(this);
309 310
}

Y
Yuanle Liu 已提交
311
void Type::Print(std::ostream& os) const {
312
  BasicIrPrinter printer(os);
313
  printer.PrintType(*this);
314 315
}

316
void Attribute::Print(std::ostream& os) const {
317
  BasicIrPrinter printer(os);
318 319 320 321 322 323 324 325 326 327 328 329 330
  printer.PrintAttribute(*this);
}

// Stream insertion for Type: prints `type` via Type::Print and returns `os`
// so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Type type) {
  type.Print(os);
  return os;
}

// Stream insertion for Attribute: prints `attr` via Attribute::Print and
// returns `os` so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Attribute attr) {
  attr.Print(os);
  return os;
}

331
}  // namespace ir