ir_printer.cc 8.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/core/value.h"

namespace ir {

namespace {
// Line terminator written after each printed operation and before the
// regions of an operation. Kept file-local.
constexpr char newline[]{'\n', '\0'};  // NOLINT
}  // namespace

36 37 38 39 40
void BasicIrPrinter::PrintType(Type type) {
  // Null handles print a sentinel and never reach the dialect lookup below.
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Builtin scalar types are rendered as a short mnemonic.
  const char* mnemonic = nullptr;
  if (type.isa<BFloat16Type>()) {
    mnemonic = "bf16";
  } else if (type.isa<Float16Type>()) {
    mnemonic = "f16";
  } else if (type.isa<Float32Type>()) {
    mnemonic = "f32";
  } else if (type.isa<Float64Type>()) {
    mnemonic = "f64";
  } else if (type.isa<BoolType>()) {
    mnemonic = "b";
  } else if (type.isa<Int8Type>()) {
    mnemonic = "i8";
  } else if (type.isa<UInt8Type>()) {
    mnemonic = "u8";
  } else if (type.isa<Int16Type>()) {
    mnemonic = "i16";
  } else if (type.isa<Int32Type>()) {
    mnemonic = "i32";
  } else if (type.isa<Int64Type>()) {
    mnemonic = "i64";
  } else if (type.isa<Complex64Type>()) {
    mnemonic = "c64";
  } else if (type.isa<Complex128Type>()) {
    mnemonic = "c128";
  }
  if (mnemonic != nullptr) {
    os << mnemonic;
    return;
  }

  // VectorType prints its element types recursively as vec[t0,t1,...].
  if (type.isa<VectorType>()) {
    auto element_types = type.dyn_cast<VectorType>().data();
    os << "vec[";
    PrintInterleave(
        element_types.begin(),
        element_types.end(),
        [this](Type element) { this->PrintType(element); },
        [this]() { this->os << ","; });
    os << "]";
    return;
  }

  // Any other type belongs to a dialect, which is responsible for printing it.
  type.dialect().PrintType(type, os);
}
80

X
xingmingyyj 已提交
81
void BasicIrPrinter::PrintAttribute(Attribute attr) {
  // Null handles print a sentinel instead of being dispatched.
  if (!attr) {
    os << "<#AttrNull>";
    return;
  }

  // Builtin attributes are handled inline, each with an early return; the
  // first successful dyn_cast wins.
  if (auto str_attr = attr.dyn_cast<StrAttribute>()) {
    os << str_attr.AsString();
    return;
  }
  if (auto bool_attr = attr.dyn_cast<BoolAttribute>()) {
    os << bool_attr.data();
    return;
  }
  if (auto float_attr = attr.dyn_cast<FloatAttribute>()) {
    os << float_attr.data();
    return;
  }
  if (auto double_attr = attr.dyn_cast<DoubleAttribute>()) {
    os << double_attr.data();
    return;
  }
  if (auto i32_attr = attr.dyn_cast<Int32Attribute>()) {
    os << i32_attr.data();
    return;
  }
  if (auto i64_attr = attr.dyn_cast<Int64Attribute>()) {
    os << i64_attr.data();
    return;
  }
  if (auto ptr_attr = attr.dyn_cast<PointerAttribute>()) {
    os << ptr_attr.data();
    return;
  }
  if (auto array_attr = attr.dyn_cast<ArrayAttribute>()) {
    // Arrays print their elements recursively as array[a0,a1,...].
    const auto& elements = array_attr.AsVector();
    os << "array[";
    PrintInterleave(
        elements.begin(),
        elements.end(),
        [this](Attribute element) { this->PrintAttribute(element); },
        [this]() { this->os << ","; });
    os << "]";
    return;
  }
  if (auto type_attr = attr.dyn_cast<TypeAttribute>()) {
    os << type_attr.data();
    return;
  }

  // Anything else is owned by a dialect, which knows how to print it.
  attr.dialect().PrintAttribute(attr, os);
}
117

118
void IrPrinter::PrintProgram(const Program* program) {
119 120
  auto top_level_op = program->module_op();
  for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
121
    auto& region = top_level_op->region(i);
X
xingmingyyj 已提交
122
    PrintRegion(region);
123
  }
124
}
125

126
void IrPrinter::PrintOperation(const Operation* op) {
127 128 129 130 131 132 133 134
  if (auto* dialect = op->dialect()) {
    dialect->PrintOperation(op, *this);
    return;
  }

  PrintGeneralOperation(op);
}

135
void IrPrinter::PrintGeneralOperation(const Operation* op) {
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
  // TODO(lyk): add API to get opresults directly
  PrintOpResult(op);
  os << " =";

  os << " \"" << op->name() << "\"";

  // TODO(lyk): add API to get operands directly
  PrintOpOperands(op);

  PrintAttributeMap(op);
  os << " :";

  // PrintOpSingature
  PrintOperandsType(op);
  os << " -> ";

  // TODO(lyk): add API to get opresults directly
  PrintOpReturnType(op);
}

156
void IrPrinter::PrintFullOperation(const Operation* op) {
157 158 159 160 161
  PrintOperation(op);
  if (op->num_regions() > 0) {
    os << newline;
  }
  for (size_t i = 0; i < op->num_regions(); ++i) {
162
    auto& region = op->region(i);
163 164 165
    PrintRegion(region);
  }
}
166

167
void IrPrinter::PrintRegion(const Region& region) {
168
  for (auto block : region) {
169
    PrintBlock(block);
170
  }
171
}
172

173
void IrPrinter::PrintBlock(const Block* block) {
174
  os << "{\n";
175 176
  for (auto item : *block) {
    PrintOperation(item);
177
    os << newline;
178
  }
179 180
  os << "}\n";
}
181

182
void IrPrinter::PrintValue(const Value& v) {
  if (!v) {
    os << "<<NULL VALUE>>";
    return;
  }

  // Values are keyed by their impl pointer, so every use of the same value
  // prints the same alias within this printer.
  const void* impl_key = static_cast<const void*>(v.impl());
  auto cached = aliases_.find(impl_key);
  if (cached == aliases_.end()) {
    // First time this value is seen: assign the next sequential "%N" name.
    const std::string fresh_name = "%" + std::to_string(cur_var_number_++);
    aliases_[impl_key] = fresh_name;
    os << fresh_name;
    return;
  }
  os << cached->second;
}
199

200
void IrPrinter::PrintOpResult(const Operation* op) {
201 202 203 204 205
  os << " (";
  auto num_op_result = op->num_results();
  std::vector<OpResult> op_results;
  op_results.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
206
    op_results.push_back(op->result(idx));
207
  }
208 209 210 211 212 213 214
  PrintInterleave(
      op_results.begin(),
      op_results.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
215

216
void IrPrinter::PrintAttributeMap(const Operation* op) {
X
xingmingyyj 已提交
217 218 219
  AttributeMap attributes = op->attributes();
  std::map<std::string, Attribute, std::less<std::string>> order_attributes(
      attributes.begin(), attributes.end());
220 221 222
  os << " {";

  PrintInterleave(
X
xingmingyyj 已提交
223 224
      order_attributes.begin(),
      order_attributes.end(),
225 226 227 228 229 230 231 232 233 234
      [this](std::pair<std::string, Attribute> it) {
        this->os << it.first;
        this->os << ":";
        this->PrintAttribute(it.second);
      },
      [this]() { this->os << ","; });

  os << "}";
}

235
void IrPrinter::PrintOpOperands(const Operation* op) {
236 237 238 239 240
  os << " (";
  auto num_op_operands = op->num_operands();
  std::vector<Value> op_operands;
  op_operands.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
241
    op_operands.push_back(op->operand_source(idx));
242
  }
243 244 245 246 247 248 249
  PrintInterleave(
      op_operands.begin(),
      op_operands.end(),
      [this](Value v) { this->PrintValue(v); },
      [this]() { this->os << ", "; });
  os << ")";
}
250

251
void IrPrinter::PrintOperandsType(const Operation* op) {
252 253 254 255
  auto num_op_operands = op->num_operands();
  std::vector<Type> op_operand_types;
  op_operand_types.reserve(num_op_operands);
  for (size_t idx = 0; idx < num_op_operands; idx++) {
256
    auto op_operand = op->operand(idx);
257
    if (op_operand) {
258
      op_operand_types.push_back(op_operand.type());
259
    } else {
260
      op_operand_types.emplace_back();
261 262
    }
  }
263 264 265 266 267 268 269 270
  os << " (";
  PrintInterleave(
      op_operand_types.begin(),
      op_operand_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
  os << ")";
}
271

272
void IrPrinter::PrintOpReturnType(const Operation* op) {
273 274 275 276
  auto num_op_result = op->num_results();
  std::vector<Type> op_result_types;
  op_result_types.reserve(num_op_result);
  for (size_t idx = 0; idx < num_op_result; idx++) {
277
    auto op_result = op->result(idx);
278 279 280
    if (op_result) {
      op_result_types.push_back(op_result.type());
    } else {
281
      op_result_types.emplace_back(nullptr);
282 283
    }
  }
284 285 286 287 288 289
  PrintInterleave(
      op_result_types.begin(),
      op_result_types.end(),
      [this](Type t) { this->PrintType(t); },
      [this]() { this->os << ", "; });
}
290

291
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
292 293
  printer.PrintGeneralOperation(op);
}
294

295
void Program::Print(std::ostream& os) const {
296
  IrPrinter printer(os);
297 298 299
  printer.PrintProgram(this);
}

300
void Operation::Print(std::ostream& os) const {
301 302
  IrPrinter printer(os);
  printer.PrintFullOperation(this);
303 304
}

Y
Yuanle Liu 已提交
305
void Type::Print(std::ostream& os) const {
306
  BasicIrPrinter printer(os);
307
  printer.PrintType(*this);
308 309
}

310
void Attribute::Print(std::ostream& os) const {
311
  BasicIrPrinter printer(os);
312 313 314 315 316 317 318 319 320 321 322 323 324
  printer.PrintAttribute(*this);
}

// Stream insertion for Type: delegates to Type::Print and returns the stream
// so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Type type) {
  type.Print(os);
  return os;
}

// Stream insertion for Attribute: delegates to Attribute::Print and returns
// the stream so insertions can be chained.
std::ostream& operator<<(std::ostream& os, Attribute attr) {
  attr.Print(os);
  return os;
}

325
}  // namespace ir