// ir_printer.cc
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/core/value.h"

namespace ir {

namespace {
// Line separator written after each operation inside a block and between an
// operation and its first region. Two chars: '\n' plus the terminating NUL.
constexpr char newline[2] = "\n";
}  // namespace

36 37 38 39 40
void BasicIrPrinter::PrintType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }
K
kangguangli 已提交
41

K
kangguangli 已提交
42 43 44
  if (type.isa<BFloat16Type>()) {
    os << "bf16";
  } else if (type.isa<Float16Type>()) {
45 46 47 48 49
    os << "f16";
  } else if (type.isa<Float32Type>()) {
    os << "f32";
  } else if (type.isa<Float64Type>()) {
    os << "f64";
K
kangguangli 已提交
50 51 52 53 54 55
  } else if (type.isa<BoolType>()) {
    os << "b";
  } else if (type.isa<Int8Type>()) {
    os << "i8";
  } else if (type.isa<UInt8Type>()) {
    os << "u8";
56 57 58 59 60 61
  } else if (type.isa<Int16Type>()) {
    os << "i16";
  } else if (type.isa<Int32Type>()) {
    os << "i32";
  } else if (type.isa<Int64Type>()) {
    os << "i64";
K
kangguangli 已提交
62 63 64 65
  } else if (type.isa<Complex64Type>()) {
    os << "c64";
  } else if (type.isa<Complex128Type>()) {
    os << "c128";
66 67 68 69 70 71 72 73 74 75 76 77
  } else if (type.isa<VectorType>()) {
    os << "vec[";
    auto inner_types = type.dyn_cast<VectorType>().data();
    PrintInterleave(
        inner_types.begin(),
        inner_types.end(),
        [this](Type v) { this->PrintType(v); },
        [this]() { this->os << ","; });
    os << "]";
  } else {
    auto& dialect = type.dialect();
    dialect.PrintType(type, os);
78
  }
79
}
80

X
xingmingyyj 已提交
81
void BasicIrPrinter::PrintAttribute(Attribute attr) {
82 83 84 85
  if (!attr) {
    os << "<#AttrNull>";
    return;
  }
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
  if (auto s = attr.dyn_cast<StrAttribute>()) {
    os << s.data();
  } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
    os << b.data();
  } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
    os << f.data();
  } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
    os << d.data();
  } else if (auto i = attr.dyn_cast<Int32Attribute>()) {
    os << i.data();
  } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
    os << i.data();
  } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
    os << p.data();
  } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
    const auto& vec = arr.data();
    os << "array[";
    PrintInterleave(
        vec.begin(),
        vec.end(),
        [this](Attribute v) { this->PrintAttribute(v); },
        [this]() { this->os << ","; });
    os << "]";
W
Wilber 已提交
110 111
  } else if (auto type = attr.dyn_cast<TypeAttribute>()) {
    os << type.data();
112 113 114
  } else {
    auto& dialect = attr.dialect();
    dialect.PrintAttribute(attr, os);
115
  }
116
}
117

118
void IrPrinter::PrintProgram(const Program* program) {
119 120
  auto top_level_op = program->module_op();
  for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
121
    auto& region = top_level_op->region(i);
X
xingmingyyj 已提交
122
    PrintRegion(region);
123
  }
124
}
125

126
void IrPrinter::PrintOperation(const Operation* op) {
127 128 129 130 131 132 133 134
  if (auto* dialect = op->dialect()) {
    dialect->PrintOperation(op, *this);
    return;
  }

  PrintGeneralOperation(op);
}

135
void IrPrinter::PrintGeneralOperation(const Operation* op) {
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
  // TODO(lyk): add API to get opresults directly
  PrintOpResult(op);
  os << " =";

  os << " \"" << op->name() << "\"";

  // TODO(lyk): add API to get operands directly
  PrintOpOperands(op);

  PrintAttributeMap(op);
  os << " :";

  // PrintOpSingature
  PrintOperandsType(op);
  os << " -> ";

  // TODO(lyk): add API to get opresults directly
  PrintOpReturnType(op);
}

156
void IrPrinter::PrintFullOperation(const Operation* op) {
157 158 159 160 161
  PrintOperation(op);
  if (op->num_regions() > 0) {
    os << newline;
  }
  for (size_t i = 0; i < op->num_regions(); ++i) {
162
    auto& region = op->region(i);
163 164 165
    PrintRegion(region);
  }
}
166

167 168 169 170
void IrPrinter::PrintRegion(const Region& region) {
  for (auto it = region.begin(); it != region.end(); ++it) {
    auto* block = *it;
    PrintBlock(block);
171
  }
172
}
173

174
void IrPrinter::PrintBlock(const Block* block) {
175 176 177 178
  os << "{\n";
  for (auto it = block->begin(); it != block->end(); ++it) {
    PrintOperation(*it);
    os << newline;
179
  }
180 181
  os << "}\n";
}
182

183
void IrPrinter::PrintValue(const Value& v) {
184 185 186 187 188 189 190 191 192 193
  if (!v) {
    os << "<<NULL VALUE>>";
    return;
  }
  const void* key = static_cast<const void*>(v.impl());
  auto ret = aliases_.find(key);
  if (ret != aliases_.end()) {
    os << ret->second;
    return;
  }
194

195 196 197 198 199
  std::string new_name = "%" + std::to_string(cur_var_number_);
  cur_var_number_++;
  aliases_[key] = new_name;
  os << new_name;
}
200

201
void IrPrinter::PrintOpResult(const Operation* op) {
202 203 204 205 206
  os << " (";
  auto num_op_result = op->num_results();
  std::vector<OpResult> op_results;
  op_results.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
207
    op_results.push_back(op->result(idx));
208
  }
209 210 211 212 213 214 215
  PrintInterleave(
      op_results.begin(),
      op_results.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
216

217
void IrPrinter::PrintAttributeMap(const Operation* op) {
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
  os << " {";

  PrintInterleave(
      op->attributes().begin(),
      op->attributes().end(),
      [this](std::pair<std::string, Attribute> it) {
        this->os << it.first;
        this->os << ":";
        this->PrintAttribute(it.second);
      },
      [this]() { this->os << ","; });

  os << "}";
}

233
void IrPrinter::PrintOpOperands(const Operation* op) {
234 235 236 237 238
  os << " (";
  auto num_op_operands = op->num_operands();
  std::vector<Value> op_operands;
  op_operands.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
239
    op_operands.push_back(op->operand(idx));
240
  }
241 242 243 244 245 246 247
  PrintInterleave(
      op_operands.begin(),
      op_operands.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
248

249
void IrPrinter::PrintOperandsType(const Operation* op) {
250 251 252 253
  auto num_op_operands = op->num_operands();
  std::vector<Type> op_operand_types;
  op_operand_types.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
254
    auto op_operand = op->op_operand(idx);
255
    if (op_operand) {
256
      op_operand_types.push_back(op_operand.type());
257
    } else {
258
      op_operand_types.push_back(Type());
259 260
    }
  }
261 262 263 264 265 266 267 268
  os << " (";
  PrintInterleave(
      op_operand_types.begin(),
      op_operand_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
  os << ")";
}
269

270
void IrPrinter::PrintOpReturnType(const Operation* op) {
271 272 273 274
  auto num_op_result = op->num_results();
  std::vector<Type> op_result_types;
  op_result_types.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
275
    auto op_result = op->result(idx);
276 277 278 279
    if (op_result) {
      op_result_types.push_back(op_result.type());
    } else {
      op_result_types.push_back(Type(nullptr));
280 281
    }
  }
282 283 284 285 286 287
  PrintInterleave(
      op_result_types.begin(),
      op_result_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
}
288

289
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
290 291
  printer.PrintGeneralOperation(op);
}
292

293
void Program::Print(std::ostream& os) const {
294
  IrPrinter printer(os);
295 296 297
  printer.PrintProgram(this);
}

298
void Operation::Print(std::ostream& os) const {
299 300
  IrPrinter printer(os);
  printer.PrintFullOperation(this);
301 302
}

Y
Yuanle Liu 已提交
303
void Type::Print(std::ostream& os) const {
304
  BasicIrPrinter printer(os);
305
  printer.PrintType(*this);
306 307
}

308
void Attribute::Print(std::ostream& os) const {
309
  BasicIrPrinter printer(os);
310 311 312 313 314 315 316 317 318 319 320 321 322
  printer.PrintAttribute(*this);
}

// Stream insertion for IR types: `os << type` prints via Type::Print and
// returns `os` so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Type type) {
  type.Print(os);
  return os;
}

// Stream insertion for IR attributes: `os << attr` prints via
// Attribute::Print and returns `os` so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Attribute attr) {
  attr.Print(os);
  return os;
}
}  // namespace ir