未验证 提交 68b0cf92 编写于 作者: 傅剑寒 提交者: GitHub

【CINN】refactor codegen for cinn (#55955)

* refactor codegen for cinn

* add to_string for types that can't be appended to a string with +=

* fix multi-thread bug caused by static var

* delete dead code and comment
上级 db96ae58
...@@ -43,6 +43,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) { ...@@ -43,6 +43,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
if (!outputs.c_header_name.empty()) { if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader); auto source = Compile(module, OutputKind::CHeader);
str_ = "";
std::ofstream file(outputs.c_header_name); std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name; CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source; file << source;
...@@ -52,6 +53,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) { ...@@ -52,6 +53,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
if (!outputs.c_source_name.empty()) { if (!outputs.c_source_name.empty()) {
auto source = Compile(module, OutputKind::CImpl); auto source = Compile(module, OutputKind::CImpl);
str_ = "";
std::ofstream file(outputs.c_source_name); std::ofstream file(outputs.c_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name; CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name;
file << source; file << source;
...@@ -71,24 +73,20 @@ std::string CodeGenC::Compile(const ir::Module &module, ...@@ -71,24 +73,20 @@ std::string CodeGenC::Compile(const ir::Module &module,
if (inline_builtin_codes_) PrintBuiltinCodes(); if (inline_builtin_codes_) PrintBuiltinCodes();
std::vector<ir::Buffer> buffers;
for (auto &buffer : module->buffers) {
buffers.emplace_back(buffer.as_buffer_ref());
}
for (auto &func : module.functions()) { for (auto &func : module.functions()) {
Compile(func); Compile(func);
} }
} else { } else {
LOG(FATAL) << "Not supported OutputKind"; LOG(FATAL) << "Not supported OutputKind";
} }
return ss_.str(); return str_;
} }
std::string CodeGenC::Compile(const ir::LoweredFunc &function) {
// TODO(LiuYang): The return type here seems unnecessary
void CodeGenC::Compile(const ir::LoweredFunc &function) {
CHECK(function.defined()); CHECK(function.defined());
Print(function); IrPrinter::Visit(function);
os() << "\n\n"; str_ += "\n\n";
return ss_.str();
} }
std::string CodeGenC::GetTypeName(Type type) { std::string CodeGenC::GetTypeName(Type type) {
...@@ -164,11 +162,11 @@ void CodeGenC::Visit(const ir::Mod *op) { ...@@ -164,11 +162,11 @@ void CodeGenC::Visit(const ir::Mod *op) {
if (copied.is_constant()) { if (copied.is_constant()) {
int temp = static_cast<int>(copied.get_constant()); int temp = static_cast<int>(copied.get_constant());
if ((temp & (temp - 1)) == 0) { if ((temp & (temp - 1)) == 0) {
os() << "("; str_ += "(";
Print(op->a()); IrPrinter::Visit(op->a());
os() << " & "; str_ += " & ";
os() << std::to_string(temp - 1); str_ += std::to_string(temp - 1);
os() << ")"; str_ += ")";
return; return;
} }
} }
...@@ -186,9 +184,9 @@ void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); } ...@@ -186,9 +184,9 @@ void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Not *op) { void CodeGenC::Visit(const ir::Not *op) {
os() << "(!"; str_ += "(!";
IrPrinter::Print(op->v()); IrPrinter::Visit(op->v());
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); } void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::For *op) { void CodeGenC::Visit(const ir::For *op) {
...@@ -196,17 +194,17 @@ void CodeGenC::Visit(const ir::For *op) { ...@@ -196,17 +194,17 @@ void CodeGenC::Visit(const ir::For *op) {
Expr min = op->min; Expr min = op->min;
int num_task = 1; int num_task = 1;
if (op->is_parallel()) { if (op->is_parallel()) {
os() << "int num_task = max_concurrency();\n"; str_ += "int num_task = max_concurrency();\n";
DoIndent(); DoIndent();
os() << "omp_set_num_threads(num_task);\n"; str_ += "omp_set_num_threads(num_task);\n";
DoIndent(); DoIndent();
os() << "auto flambda = [=](int task_id, int num_task) -> int {\n"; str_ += "auto flambda = [=](int task_id, int num_task) -> int {\n";
IncIndent(); IncIndent();
DoIndent(); DoIndent();
os() << "int n_per_task = "; str_ += "int n_per_task = ";
Expr num_task_var = Var("num_task"); Expr num_task_var = Var("num_task");
Print((op->extent + num_task_var - 1) / num_task_var); IrPrinter::Visit((op->extent + num_task_var - 1) / num_task_var);
os() << ";\n"; str_ += ";\n";
CHECK_EQ(min.as_int32(), 0); CHECK_EQ(min.as_int32(), 0);
auto task_id = Var("task_id"); auto task_id = Var("task_id");
auto n_per_task = Var("n_per_task"); auto n_per_task = Var("n_per_task");
...@@ -214,125 +212,127 @@ void CodeGenC::Visit(const ir::For *op) { ...@@ -214,125 +212,127 @@ void CodeGenC::Visit(const ir::For *op) {
extent = (task_id + 1) * n_per_task; extent = (task_id + 1) * n_per_task;
DoIndent(); DoIndent();
} }
os() << "for ("; str_ += "for (";
os() << GetTypeRepr(Int(32)); str_ += GetTypeRepr(Int(32));
os() << " " << op->loop_var->name; str_ += " ";
os() << " = "; str_ += op->loop_var->name;
Print(min); str_ += " = ";
os() << "; "; IrPrinter::Visit(min);
os() << op->loop_var->name; str_ += "; ";
os() << " < "; str_ += op->loop_var->name;
Print(op->extent); str_ += " < ";
IrPrinter::Visit(op->extent);
if (op->is_parallel()) { if (op->is_parallel()) {
os() << " && "; str_ += " && ";
os() << op->loop_var->name; str_ += op->loop_var->name;
os() << " < "; str_ += " < ";
Print(extent); IrPrinter::Visit(extent);
} }
os() << "; "; str_ += "; ";
os() << op->loop_var->name; str_ += op->loop_var->name;
os() << " += 1"; str_ += " += 1";
os() << ") "; str_ += ") ";
Print(op->body); IrPrinter::Visit(op->body);
if (op->is_parallel()) { if (op->is_parallel()) {
os() << "\n"; str_ += "\n";
DoIndent(); DoIndent();
os() << "return 0;\n"; str_ += "return 0;\n";
DecIndent(); DecIndent();
DoIndent(); DoIndent();
os() << "};\n"; str_ += "};\n";
os() << "#pragma omp parallel num_threads(num_task)\n"; str_ += "#pragma omp parallel num_threads(num_task)\n";
DoIndent(); DoIndent();
os() << "{\n"; str_ += "{\n";
IncIndent(); IncIndent();
DoIndent(); DoIndent();
os() << "int task_id = omp_get_thread_num();\n"; str_ += "int task_id = omp_get_thread_num();\n";
DoIndent(); DoIndent();
os() << "flambda(task_id, num_task);\n"; str_ += "flambda(task_id, num_task);\n";
DecIndent(); DecIndent();
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
} }
} }
void CodeGenC::Visit(const ir::PolyFor *op) { void CodeGenC::Visit(const ir::PolyFor *op) {
os() << "for ("; str_ += "for (";
os() << GetTypeRepr(Int(32)); str_ += GetTypeRepr(Int(32));
os() << " " << op->iterator->name; str_ += " ";
os() << " = "; str_ += op->iterator->name;
Print(op->init); str_ += " = ";
os() << "; "; IrPrinter::Visit(op->init);
Print(op->condition); str_ += "; ";
os() << "; "; IrPrinter::Visit(op->condition);
str_ += "; ";
os() << op->iterator->name;
os() << " += "; str_ += op->iterator->name;
Print(op->inc); str_ += " += ";
os() << ") "; IrPrinter::Visit(op->inc);
str_ += ") ";
Print(op->body);
IrPrinter::Visit(op->body);
} }
void CodeGenC::Visit(const ir::Select *op) { void CodeGenC::Visit(const ir::Select *op) {
os() << "("; str_ += "(";
os() << "("; str_ += "(";
Print(op->condition); IrPrinter::Visit(op->condition);
os() << ") ? "; str_ += ") ? ";
Print(op->true_value); IrPrinter::Visit(op->true_value);
os() << " : "; str_ += " : ";
Print(op->false_value); IrPrinter::Visit(op->false_value);
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::IfThenElse *op) { void CodeGenC::Visit(const ir::IfThenElse *op) {
os() << "if ("; str_ += "if (";
Print(op->condition); IrPrinter::Visit(op->condition);
os() << ") {\n"; str_ += ") {\n";
if (!op->true_case.As<ir::Block>()) IncIndent(); if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent(); DoIndent();
Print(op->true_case); IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) os() << ";"; if (!op->true_case.As<ir::Block>()) str_ += ";";
os() << "\n"; str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent(); if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
if (op->false_case.defined()) { if (op->false_case.defined()) {
os() << " else {\n"; str_ += " else {\n";
if (!op->true_case.As<ir::Block>()) IncIndent(); if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent(); DoIndent();
Print(op->false_case); IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) os() << ";"; if (!op->false_case.As<ir::Block>()) str_ += ";";
os() << "\n"; str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent(); if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
} }
} }
void CodeGenC::Visit(const ir::Block *op) { void CodeGenC::Visit(const ir::Block *op) {
os() << "{\n"; str_ += "{\n";
IncIndent(); IncIndent();
for (int i = 0; i < op->stmts.size() - 1; i++) { for (int i = 0; i < op->stmts.size() - 1; i++) {
DoIndent(); DoIndent();
Print(op->stmts[i]); IrPrinter::Visit(op->stmts[i]);
os() << ";\n"; str_ += ";\n";
} }
if (op->stmts.size() >= 1) { if (op->stmts.size() >= 1) {
DoIndent(); DoIndent();
Print(op->stmts.back()); IrPrinter::Visit(op->stmts.back());
os() << ";"; str_ += ";";
} }
DecIndent(); DecIndent();
os() << "\n"; str_ += "\n";
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
} }
void CodeGenC::Visit(const ir::Call *op) { void CodeGenC::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::buffer_malloc) { if (op->name == runtime::intrinsic::buffer_malloc) {
...@@ -340,13 +340,15 @@ void CodeGenC::Visit(const ir::Call *op) { ...@@ -340,13 +340,15 @@ void CodeGenC::Visit(const ir::Call *op) {
} else if (op->name == runtime::intrinsic::pod_values_to_array_repr) { } else if (op->name == runtime::intrinsic::pod_values_to_array_repr) {
PrintCall_pod_values_to_array(op); PrintCall_pod_values_to_array(op);
} else if (op->is_intrinsic_call()) { } else if (op->is_intrinsic_call()) {
os() << op->name << "("; str_ += op->name;
str_ += "(";
PrintCallArgs(op); PrintCallArgs(op);
os() << ")"; str_ += ")";
} else if (op->is_cinn_call()) { // call CINN LoweredFunc } else if (op->is_cinn_call()) { // call CINN LoweredFunc
os() << op->name << "("; str_ += op->name;
str_ += "(";
PrintCallArgs(op); PrintCallArgs(op);
os() << ")"; str_ += ")";
} else if (op->is_extern_call()) { } else if (op->is_extern_call()) {
const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup( const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(
ExternFuncID{backend_C, op->name.c_str()}); ExternFuncID{backend_C, op->name.c_str()});
...@@ -356,9 +358,10 @@ void CodeGenC::Visit(const ir::Call *op) { ...@@ -356,9 +358,10 @@ void CodeGenC::Visit(const ir::Call *op) {
emitter.Emit(op); emitter.Emit(op);
} else { } else {
CHECK(!op->read_args.empty() || !op->write_args.empty()); CHECK(!op->read_args.empty() || !op->write_args.empty());
os() << op->name << "("; str_ += op->name;
str_ += "(";
PrintCallArgs(op); PrintCallArgs(op);
os() << ")"; str_ += ")";
} }
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
...@@ -367,38 +370,40 @@ void CodeGenC::Visit(const ir::Call *op) { ...@@ -367,38 +370,40 @@ void CodeGenC::Visit(const ir::Call *op) {
void CodeGenC::PrintCallArgs(const ir::Call *op) { void CodeGenC::PrintCallArgs(const ir::Call *op) {
if (!op->read_args.empty()) { if (!op->read_args.empty()) {
for (int i = 0; i < op->read_args.size() - 1; i++) { for (int i = 0; i < op->read_args.size() - 1; i++) {
Print(op->read_args[i]); IrPrinter::Visit(op->read_args[i]);
os() << ", "; str_ += ", ";
} }
Print(op->read_args.back()); IrPrinter::Visit(op->read_args.back());
} }
if (!op->write_args.empty()) { if (!op->write_args.empty()) {
if (!op->read_args.empty()) os() << ", "; if (!op->read_args.empty()) str_ += ", ";
for (int i = 0; i < op->write_args.size() - 1; i++) { for (int i = 0; i < op->write_args.size() - 1; i++) {
Print(op->write_args[i]); IrPrinter::Visit(op->write_args[i]);
os() << ", "; str_ += ", ";
} }
Print(op->write_args.back()); IrPrinter::Visit(op->write_args.back());
} }
} }
void CodeGenC::PrintCall_buffer_malloc(const ir::Call *op) { void CodeGenC::PrintCall_buffer_malloc(const ir::Call *op) {
CHECK_EQ(op->read_args.size(), 2UL); CHECK_EQ(op->read_args.size(), 2UL);
os() << op->name << "("; str_ += op->name;
str_ += "(";
PrintCastExpr("void*", op->read_args[0]); PrintCastExpr("void*", op->read_args[0]);
os() << ", "; str_ += ", ";
os() << op->read_args[1]; IrPrinter::Visit(op->read_args[1]);
os() << ")"; str_ += ")";
} }
void CodeGenC::PrintCall_cinn_pod_value_to_(const ir::Call *op) { void CodeGenC::PrintCall_cinn_pod_value_to_(const ir::Call *op) {
CHECK_EQ(op->read_args.size(), 1UL); CHECK_EQ(op->read_args.size(), 1UL);
os() << op->name << "("; str_ += op->name;
os() << "&("; str_ += "(";
Print(op->read_args[0]); str_ += "&(";
os() << ")"; IrPrinter::Visit(op->read_args[0]);
os() << ")"; str_ += ")";
str_ += ")";
} }
void CodeGenC::PrintCall_get_address(const ir::Call *op) { void CodeGenC::PrintCall_get_address(const ir::Call *op) {
...@@ -409,11 +414,11 @@ void CodeGenC::PrintCall_get_address(const ir::Call *op) { ...@@ -409,11 +414,11 @@ void CodeGenC::PrintCall_get_address(const ir::Call *op) {
CHECK(read_var || read_buf) << "Only Var or Buffer can get address"; CHECK(read_var || read_buf) << "Only Var or Buffer can get address";
if (read_var) { if (read_var) {
if (read_var->type().lanes() <= 1) os() << "&"; if (read_var->type().lanes() <= 1) str_ += "&";
os() << read_var->name; str_ += read_var->name;
} else if (read_buf) { } else if (read_buf) {
if (read_buf->type().lanes() <= 1) os() << "&"; if (read_buf->type().lanes() <= 1) str_ += "&";
os() << read_buf->name; str_ += read_buf->name;
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
...@@ -432,42 +437,45 @@ void CodeGenC::PrintCall_pod_values_to_array(const ir::Call *op) { ...@@ -432,42 +437,45 @@ void CodeGenC::PrintCall_pod_values_to_array(const ir::Call *op) {
arg_names.push_back(arg_var->name); arg_names.push_back(arg_var->name);
} }
os() << "cinn_pod_value_t " << output_var->name << "[] = "; str_ += "cinn_pod_value_t ";
os() << "{ "; str_ += output_var->name;
str_ += "[] = ";
str_ += "{ ";
os() << utils::Join(arg_names, ", "); str_ += utils::Join(arg_names, ", ");
os() << " }"; str_ += " }";
} }
void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED } void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; } void CodeGenC::Visit(const ir::_Var_ *op) { str_ += op->name; }
void CodeGenC::Visit(const ir::Load *op) { void CodeGenC::Visit(const ir::Load *op) {
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address. if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
CHECK(op->type().is_vector()); CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::" str_ += "::";
<< "Load("; str_ += "Load(";
os() << op->tensor.As<ir::_Tensor_>()->name; str_ += op->tensor.As<ir::_Tensor_>()->name;
os() << ","; str_ += ",";
Print(dense_strided_ramp); IrPrinter::Visit(dense_strided_ramp);
os() << ")"; str_ += ")";
} else if (op->index().type().is_vector()) { } else if (op->index().type().is_vector()) {
// gather // gather
CHECK(op->type().is_vector()); CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::Load("; str_ += "::Load(";
os() << op->tensor.As<ir::_Tensor_>()->name; str_ += op->tensor.As<ir::_Tensor_>()->name;
os() << ","; str_ += ",";
Print(op->index()); IrPrinter::Visit(op->index());
os() << ")"; str_ += ")";
} else if (op->is_addr_tensor()) { } else if (op->is_addr_tensor()) {
auto *tensor = op->tensor.As<ir::_Tensor_>(); auto *tensor = op->tensor.As<ir::_Tensor_>();
os() << tensor->name << "["; str_ += tensor->name;
Print(op->index()); str_ += "[";
os() << "]"; IrPrinter::Visit(op->index());
str_ += "]";
} else { } else {
IrPrinter::Visit(op); IrPrinter::Visit(op);
} }
...@@ -478,56 +486,59 @@ void CodeGenC::Visit(const ir::Store *op) { ...@@ -478,56 +486,59 @@ void CodeGenC::Visit(const ir::Store *op) {
auto *tensor = op->tensor.As<ir::_Tensor_>(); auto *tensor = op->tensor.As<ir::_Tensor_>();
CHECK(tensor); CHECK(tensor);
os() << tensor->name << "["; str_ += tensor->name;
Print(op->index()); str_ += "[";
os() << "]"; IrPrinter::Visit(op->index());
os() << " = "; str_ += "]";
Print(op->value); str_ += " = ";
IrPrinter::Visit(op->value);
} }
void CodeGenC::Visit(const ir::Alloc *op) { void CodeGenC::Visit(const ir::Alloc *op) {
os() << runtime::intrinsic::buffer_malloc; str_ += runtime::intrinsic::buffer_malloc;
os() << "("; str_ += "(";
os() << "(void*)(0), "; str_ += "(void*)(0), ";
auto *buffer = op->destination.As<ir::_Buffer_>(); auto *buffer = op->destination.As<ir::_Buffer_>();
os() << buffer->name; str_ += buffer->name;
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::Free *op) { void CodeGenC::Visit(const ir::Free *op) {
os() << runtime::intrinsic::buffer_free; str_ += runtime::intrinsic::buffer_free;
os() << "("; str_ += "(";
os() << "(void*)(0), "; str_ += "(void*)(0), ";
auto *buffer = op->destination.As<ir::_Buffer_>(); auto *buffer = op->destination.As<ir::_Buffer_>();
os() << buffer->name; str_ += buffer->name;
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::_Buffer_ *op) { os() << op->name; } void CodeGenC::Visit(const ir::_Buffer_ *op) { str_ += op->name; }
void CodeGenC::Visit(const ir::_Tensor_ *op) { os() << op->buffer->name; } void CodeGenC::Visit(const ir::_Tensor_ *op) { str_ += op->buffer->name; }
void CodeGenC::Visit(const ir::Let *op) { void CodeGenC::Visit(const ir::Let *op) {
bool is_vec = false; bool is_vec = false;
CHECK(op->type().valid()); CHECK(op->type().valid());
if (op->body.defined() && op->body.As<ir::Broadcast>()) { if (op->body.defined() && op->body.As<ir::Broadcast>()) {
// broadcast's type is hard to print, so use c++11 auto instead. // broadcast's type is hard to print, so use c++11 auto instead.
os() << "auto"; str_ += "auto";
is_vec = true; is_vec = true;
} else { } else {
os() << GetTypeRepr(op->type()); str_ += GetTypeRepr(op->type());
} }
os() << " "; str_ += " ";
Print(op->symbol); IrPrinter::Visit(op->symbol);
// native C array. // native C array.
if (op->type().lanes() > 1 && !is_vec) { if (op->type().lanes() > 1 && !is_vec) {
os() << "[" << op->type().lanes() << "]"; str_ += "[";
str_ += std::to_string(op->type().lanes());
str_ += "]";
} }
if (op->body.defined()) { if (op->body.defined()) {
os() << " = "; str_ += " = ";
Print(op->body); IrPrinter::Visit(op->body);
} }
} }
...@@ -537,22 +548,29 @@ void CodeGenC::Visit(const ir::Reduce *op) { ...@@ -537,22 +548,29 @@ void CodeGenC::Visit(const ir::Reduce *op) {
} }
void CodeGenC::Visit(const ir::Ramp *op) { void CodeGenC::Visit(const ir::Ramp *op) {
os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) str_ += "StackVec<";
<< ">::Ramp("; str_ += std::to_string(op->lanes);
Print(op->base); str_ += ",";
os() << ", "; str_ += GetTypeRepr(op->type().ElementOf());
Print(op->stride); str_ += ">::Ramp(";
os() << ", "; IrPrinter::Visit(op->base);
os() << op->lanes; str_ += ", ";
os() << ")"; IrPrinter::Visit(op->stride);
str_ += ", ";
str_ += std::to_string(op->lanes);
str_ += ")";
} }
void CodeGenC::Visit(const ir::Broadcast *op) { void CodeGenC::Visit(const ir::Broadcast *op) {
os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) str_ += "StackVec<";
<< ">::Broadcast("; str_ += std::to_string(op->lanes);
Print(op->value); str_ += ",";
os() << ", "; str_ += GetTypeRepr(op->type().ElementOf());
os() << op->lanes << ")"; str_ += ">::Broadcast(";
IrPrinter::Visit(op->value);
str_ += ", ";
str_ += std::to_string(op->lanes);
str_ += ")";
} }
void CodeGenC::Visit(const ir::FracOp *op) { ir::IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::FracOp *op) { ir::IrPrinter::Visit(op); }
...@@ -560,35 +578,41 @@ void CodeGenC::Visit(const ir::Sum *op) { ir::IrPrinter::Visit(op); } ...@@ -560,35 +578,41 @@ void CodeGenC::Visit(const ir::Sum *op) { ir::IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Product *op) { ir::IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::Product *op) { ir::IrPrinter::Visit(op); }
void CodeGenC::PrintCastExpr(const Type &type, Expr e) { void CodeGenC::PrintCastExpr(const Type &type, Expr e) {
os() << "((" << GetTypeRepr(type) << ")"; str_ += "((";
os() << "("; str_ += GetTypeRepr(type);
Print(e); str_ += ")";
os() << "))"; str_ += "(";
IrPrinter::Visit(e);
str_ += "))";
} }
void CodeGenC::PrintCastExpr(const std::string &type, Expr e) { void CodeGenC::PrintCastExpr(const std::string &type, Expr e) {
os() << "(" << type << ")"; str_ += "(";
os() << "("; str_ += type;
Print(e); str_ += ")";
os() << ")"; str_ += "(";
IrPrinter::Visit(e);
str_ += ")";
} }
void CodeGenC::PrintShape(const std::vector<Expr> &shape, void CodeGenC::PrintShape(const std::vector<Expr> &shape,
char leftb, char leftb,
char rightb) { char rightb) {
os() << leftb << " "; str_ += leftb;
str_ += " ";
for (int i = 0; i < shape.size() - 1; i++) { for (int i = 0; i < shape.size() - 1; i++) {
Print(shape[i]); IrPrinter::Visit(shape[i]);
os() << ", "; str_ += ", ";
} }
if (shape.size() > 1) Print(shape.back()); if (shape.size() > 1) IrPrinter::Visit(shape.back());
os() << " " << rightb; str_ += " ";
str_ += rightb;
} }
void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
PrintFunctionDeclaration(op); PrintFunctionDeclaration(op);
os() << "\n"; str_ += "\n";
DoIndent(); DoIndent();
...@@ -623,21 +647,21 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { ...@@ -623,21 +647,21 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
optim::RemoveNestedBlock(&func_body); optim::RemoveNestedBlock(&func_body);
Print(func_body); IrPrinter::Visit(func_body);
} }
void CodeGenC::PrintIncludes() { void CodeGenC::PrintIncludes() {
os() << "#include <cinn_runtime.h>\n"; str_ += "#include <cinn_runtime.h>\n";
os() << "#include <stdio.h>\n"; str_ += "#include <stdio.h>\n";
os() << "\n"; str_ += "\n";
} }
void CodeGenC::PrintFileGuardOpen(const std::string &name) { void CodeGenC::PrintFileGuardOpen(const std::string &name) {
os() << utils::StringFormat("#ifndef _%s_CINN_H_\n", Uppercase(name).c_str()); str_ += utils::StringFormat("#ifndef _%s_CINN_H_\n", Uppercase(name).c_str());
os() << utils::StringFormat("#define _%s_CINN_H_\n", Uppercase(name).c_str()); str_ += utils::StringFormat("#define _%s_CINN_H_\n", Uppercase(name).c_str());
os() << "\n"; str_ += "\n";
} }
void CodeGenC::PrintFileGuardClose(const std::string &module_name) { void CodeGenC::PrintFileGuardClose(const std::string &module_name) {
os() << utils::StringFormat("#endif // _%s_CINN_H_\n", str_ += utils::StringFormat("#endif // _%s_CINN_H_\n",
Uppercase(module_name).c_str()); Uppercase(module_name).c_str());
} }
...@@ -653,16 +677,16 @@ void CodeGenC::PrintBufferCreation(const std::vector<ir::Buffer> &buffers) { ...@@ -653,16 +677,16 @@ void CodeGenC::PrintBufferCreation(const std::vector<ir::Buffer> &buffers) {
Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type); Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type);
auto expr = ir::intrinsics::BufferCreate::Make(buffer); auto expr = ir::intrinsics::BufferCreate::Make(buffer);
expr = ir::Let::Make(variable, expr); expr = ir::Let::Make(variable, expr);
Print(expr); IrPrinter::Visit(expr);
os() << ";\n"; str_ += ";\n";
} }
} }
void CodeGenC::PrintBufferDestroy(const std::vector<ir::Buffer> &buffers) { void CodeGenC::PrintBufferDestroy(const std::vector<ir::Buffer> &buffers) {
for (auto &buffer : buffers) { for (auto &buffer : buffers) {
DoIndent(); DoIndent();
Print(buffer.DestroyExpr()); IrPrinter::Visit(buffer.DestroyExpr());
os() << ";\n"; str_ += ";\n";
} }
} }
...@@ -672,8 +696,8 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) { ...@@ -672,8 +696,8 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) {
for (auto &func : module.functions()) { for (auto &func : module.functions()) {
PrintFunctionDeclaration(func.As<ir::_LoweredFunc_>()); PrintFunctionDeclaration(func.As<ir::_LoweredFunc_>());
os() << ";\n"; str_ += ";\n";
os() << "\n\n"; str_ += "\n\n";
} }
PrintFileGuardClose(module.name()); PrintFileGuardClose(module.name());
...@@ -682,53 +706,58 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) { ...@@ -682,53 +706,58 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) {
void CodeGenC::PrintFuncArg(const ir::Argument &arg) { void CodeGenC::PrintFuncArg(const ir::Argument &arg) {
if (arg.is_buffer()) { if (arg.is_buffer()) {
if (arg.is_input()) { if (arg.is_input()) {
os() << "const struct cinn_buffer_t *"; str_ += "const struct cinn_buffer_t *";
} else { } else {
os() << "struct cinn_buffer_t *"; str_ += "struct cinn_buffer_t *";
} }
} else if (arg.is_var()) { } else if (arg.is_var()) {
os() << GetTypeRepr(arg.type()) << " "; str_ += GetTypeRepr(arg.type());
os() << arg.name(); str_ += " ";
str_ += arg.name();
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
os() << arg.name(); str_ += arg.name();
} }
void CodeGenC::PrintRuntimeType(const cinn_type_t &type) { void CodeGenC::PrintRuntimeType(const cinn_type_t &type) {
if (type == cinn_bool_t()) { if (type == cinn_bool_t()) {
os() << "cinn_bool_t()"; str_ += "cinn_bool_t()";
} else if (type == cinn_int8_t()) { } else if (type == cinn_int8_t()) {
os() << "cinn_int8_t()"; str_ += "cinn_int8_t()";
} else if (type == cinn_int16_t()) { } else if (type == cinn_int16_t()) {
os() << "cinn_int16_t()"; str_ += "cinn_int16_t()";
} else if (type == cinn_int32_t()) { } else if (type == cinn_int32_t()) {
os() << "cinn_int32_t()"; str_ += "cinn_int32_t()";
} else if (type == cinn_int64_t()) { } else if (type == cinn_int64_t()) {
os() << "cinn_int64_t()"; str_ += "cinn_int64_t()";
} else if (type == cinn_uint8_t()) { } else if (type == cinn_uint8_t()) {
os() << "cinn_uint8_t()"; str_ += "cinn_uint8_t()";
} else if (type == cinn_uint16_t()) { } else if (type == cinn_uint16_t()) {
os() << "cinn_uint16_t()"; str_ += "cinn_uint16_t()";
} else if (type == cinn_uint32_t()) { } else if (type == cinn_uint32_t()) {
os() << "cinn_uint32_t()"; str_ += "cinn_uint32_t()";
} else if (type == cinn_uint64_t()) { } else if (type == cinn_uint64_t()) {
os() << "cinn_uint64_t()"; str_ += "cinn_uint64_t()";
} else if (type == cinn_bfloat16_t()) { } else if (type == cinn_bfloat16_t()) {
os() << "cinn_bfloat16_t()"; str_ += "cinn_bfloat16_t()";
} else if (type == cinn_float16_t()) { } else if (type == cinn_float16_t()) {
os() << "cinn_float16_t()"; str_ += "cinn_float16_t()";
} else if (type == cinn_float32_t()) { } else if (type == cinn_float32_t()) {
os() << "cinn_float32_t()"; str_ += "cinn_float32_t()";
} else if (type == cinn_float64_t()) { } else if (type == cinn_float64_t()) {
os() << "cinn_float64_t()"; str_ += "cinn_float64_t()";
} else { } else {
LOG(FATAL) << "Unknown type is not supported to print"; LOG(FATAL) << "Unknown type is not supported to print";
} }
} }
void CodeGenC::PrintStackVecType(Type type, int lanes) { void CodeGenC::PrintStackVecType(Type type, int lanes) {
os() << "StackedVec<" << GetTypeRepr(type) << "," << lanes << ">"; str_ += "StackedVec<";
str_ += GetTypeRepr(type);
str_ += ",";
str_ += std::to_string(lanes);
str_ += ">";
} }
void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED } void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED }
...@@ -751,109 +780,117 @@ void CodeGenC::Visit(const ir::IntrinsicOp *op) { ...@@ -751,109 +780,117 @@ void CodeGenC::Visit(const ir::IntrinsicOp *op) {
} }
void CodeGenC::Visit(const ir::intrinsics::BufferGetDataHandle *op) { void CodeGenC::Visit(const ir::intrinsics::BufferGetDataHandle *op) {
os() << op->buffer.as_buffer()->name; str_ += op->buffer.as_buffer()->name;
os() << "->"; str_ += "->";
os() << "memory"; str_ += "memory";
} }
void CodeGenC::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) { void CodeGenC::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) {
os() << op->buffer.as_buffer()->name; str_ += op->buffer.as_buffer()->name;
os() << "->"; str_ += "->";
os() << "memory"; str_ += "memory";
} }
void CodeGenC::Visit(const ir::intrinsics::PodValueToX *op) { void CodeGenC::Visit(const ir::intrinsics::PodValueToX *op) {
auto to_type = op->GetOutputType(0); auto to_type = op->GetOutputType(0);
if (to_type == type_of<float>()) { if (to_type == type_of<float>()) {
os() << runtime::intrinsic::pod_value_to_float; str_ += runtime::intrinsic::pod_value_to_float;
} else if (to_type == type_of<double>()) { } else if (to_type == type_of<double>()) {
os() << runtime::intrinsic::pod_value_to_double; str_ += runtime::intrinsic::pod_value_to_double;
} else if (to_type == type_of<float16>()) { } else if (to_type == type_of<float16>()) {
os() << runtime::intrinsic::pod_value_to_float16; str_ += runtime::intrinsic::pod_value_to_float16;
} else if (to_type == type_of<bool>()) { } else if (to_type == type_of<bool>()) {
os() << runtime::intrinsic::pod_value_to_bool; str_ += runtime::intrinsic::pod_value_to_bool;
} else if (to_type == type_of<int8_t>()) { } else if (to_type == type_of<int8_t>()) {
os() << runtime::intrinsic::pod_value_to_int8; str_ += runtime::intrinsic::pod_value_to_int8;
} else if (to_type == type_of<int16_t>()) { } else if (to_type == type_of<int16_t>()) {
os() << runtime::intrinsic::pod_value_to_int16; str_ += runtime::intrinsic::pod_value_to_int16;
} else if (to_type == type_of<int32_t>()) { } else if (to_type == type_of<int32_t>()) {
os() << runtime::intrinsic::pod_value_to_int32; str_ += runtime::intrinsic::pod_value_to_int32;
} else if (to_type == type_of<int64_t>()) { } else if (to_type == type_of<int64_t>()) {
os() << runtime::intrinsic::pod_value_to_int64; str_ += runtime::intrinsic::pod_value_to_int64;
} else if (to_type == type_of<uint8_t>()) { } else if (to_type == type_of<uint8_t>()) {
os() << runtime::intrinsic::pod_value_to_uint8; str_ += runtime::intrinsic::pod_value_to_uint8;
} else if (to_type == type_of<uint16_t>()) { } else if (to_type == type_of<uint16_t>()) {
os() << runtime::intrinsic::pod_value_to_uint16; str_ += runtime::intrinsic::pod_value_to_uint16;
} else if (to_type == type_of<uint32_t>()) { } else if (to_type == type_of<uint32_t>()) {
os() << runtime::intrinsic::pod_value_to_uint32; str_ += runtime::intrinsic::pod_value_to_uint32;
} else if (to_type == type_of<uint64_t>()) { } else if (to_type == type_of<uint64_t>()) {
os() << runtime::intrinsic::pod_value_to_uint64; str_ += runtime::intrinsic::pod_value_to_uint64;
} else if (to_type == type_of<void *>()) { } else if (to_type == type_of<void *>()) {
os() << runtime::intrinsic::pod_value_to_void_p; str_ += runtime::intrinsic::pod_value_to_void_p;
} else if (to_type == type_of<cinn_buffer_t *>()) { } else if (to_type == type_of<cinn_buffer_t *>()) {
os() << runtime::intrinsic::pod_value_to_buffer_p; str_ += runtime::intrinsic::pod_value_to_buffer_p;
} else { } else {
LOG(FATAL) << "Not supported type: " << to_type; LOG(FATAL) << "Not supported type: " << to_type;
} }
os() << "("; str_ += "(";
Print(op->pod_value_ptr); IrPrinter::Visit(op->pod_value_ptr);
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::intrinsics::BufferCreate *op) { void CodeGenC::Visit(const ir::intrinsics::BufferCreate *op) {
const ir::_Buffer_ *buffer_arg = op->buffer.as_buffer(); const ir::_Buffer_ *buffer_arg = op->buffer.as_buffer();
CHECK(buffer_arg); CHECK(buffer_arg);
os() << runtime::intrinsic::buffer_create; str_ += runtime::intrinsic::buffer_create;
os() << "("; str_ += "(";
PrintCastExpr("cinn_device_kind_t", Expr(buffer_arg->target.runtime_arch())); PrintCastExpr("cinn_device_kind_t", Expr(buffer_arg->target.runtime_arch()));
os() << "/*target*/, "; str_ += "/*target*/, ";
PrintRuntimeType(runtime::ToRuntimeType(buffer_arg->dtype.ElementOf())); PrintRuntimeType(runtime::ToRuntimeType(buffer_arg->dtype.ElementOf()));
os() << ", "; str_ += ", ";
PrintShape(buffer_arg->shape); PrintShape(buffer_arg->shape);
if (buffer_arg->data_alignment > 0) { if (buffer_arg->data_alignment > 0) {
os() << ", " << buffer_arg->data_alignment << "/*align*/"; str_ += ", ";
str_ += std::to_string(buffer_arg->data_alignment);
str_ += "/*align*/";
} }
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::intrinsics::GetAddr *op) { void CodeGenC::Visit(const ir::intrinsics::GetAddr *op) {
if (op->data.as_buffer()) { if (op->data.as_buffer()) {
os() << "&" << op->data.as_buffer()->name; str_ += "&";
str_ += op->data.as_buffer()->name;
} else if (op->data.as_var()) { } else if (op->data.as_var()) {
os() << "&" << op->data.as_var()->name; str_ += "&";
str_ += op->data.as_var()->name;
} else { } else {
os() << "&("; str_ += "&(";
Print(op->data); IrPrinter::Visit(op->data);
os() << ")"; str_ += ")";
} }
} }
void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) { void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) {
os() << runtime::intrinsic::args_construct_repr << "("; str_ += runtime::intrinsic::args_construct_repr;
os() << op->var->name << ", "; str_ += "(";
os() << op->args.size() << ", "; str_ += op->var->name;
str_ += ", ";
str_ += std::to_string(op->args.size());
str_ += ", ";
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
Print(op->args[i]); IrPrinter::Visit(op->args[i]);
os() << ", "; str_ += ", ";
} }
if (!op->args.empty()) { if (!op->args.empty()) {
Print(op->args.back()); IrPrinter::Visit(op->args.back());
} }
os() << ")"; str_ += ")";
} }
void CodeGenC::Visit(const ir::intrinsics::BuiltinIntrin *op) { void CodeGenC::Visit(const ir::intrinsics::BuiltinIntrin *op) {
os() << op->name << "("; str_ += op->name;
str_ += "(";
if (!op->args.empty()) { if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
Print(op->args[i]); IrPrinter::Visit(op->args[i]);
os() << ", "; str_ += ", ";
} }
Print(op->args.back()); IrPrinter::Visit(op->args.back());
} }
os() << ")"; str_ += ")";
} }
std::string ReadWholeFile(const std::string &path) { std::string ReadWholeFile(const std::string &path) {
...@@ -874,7 +911,8 @@ void CodeGenC::PrintBuiltinCodes() { ...@@ -874,7 +911,8 @@ void CodeGenC::PrintBuiltinCodes() {
auto source = auto source =
ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file); ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file);
os() << source << "\n"; str_ += source;
str_ += "\n";
} }
namespace detail { namespace detail {
......
...@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter { ...@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter {
void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; } void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }
protected: protected:
std::string Compile(const ir::LoweredFunc& function); void Compile(const ir::LoweredFunc& function);
std::string Compile(const ir::Buffer& buffer);
void GenerateHeaderFile(const ir::Module& module); void GenerateHeaderFile(const ir::Module& module);
...@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter { ...@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter {
// @} // @}
void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) { void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) {
os() << "void " << op->name << "("; str_ += "void ";
os() << "void* _args, int32_t num_args"; str_ += op->name;
os() << ")"; str_ += "(";
str_ += "void* _args, int32_t num_args";
str_ += ")";
} }
void PrintShape(const std::vector<Expr>& shape, void PrintShape(const std::vector<Expr>& shape,
......
...@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) { ...@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) {
int bits = op->type().bits() * op->type().lanes(); int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) { if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_load("; str_ += "cinn_avx512_load(";
PrintAbsAddr(op); PrintAbsAddr(op);
os() << ")"; str_ += ")";
} else if (SupportsAVX256() && bits == 256) { } else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_load("; str_ += "cinn_avx256_load(";
PrintAbsAddr(op); PrintAbsAddr(op);
os() << ")"; str_ += ")";
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
...@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) { ...@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
int bits = op->type().bits() * op->type().lanes(); int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) { if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_set1("; str_ += "cinn_avx512_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value); PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")"; str_ += ")";
} else if (SupportsAVX256() && bits == 256) { } else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_set1("; str_ += "cinn_avx256_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value); PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")"; str_ += ")";
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
...@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) { ...@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) {
int bits = op->type().bits() * op->type().lanes(); int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) { if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_store("; str_ += "cinn_avx512_store(";
PrintAbsAddr(op); PrintAbsAddr(op);
os() << ", "; str_ += ", ";
Print(op->value); IrPrinter::Visit(op->value);
os() << ")"; str_ += ")";
} else if (SupportsAVX256() && bits == 256) { } else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_store("; str_ += "cinn_avx256_store(";
PrintAbsAddr(op); PrintAbsAddr(op);
os() << ", "; str_ += ", ";
Print(op->value); IrPrinter::Visit(op->value);
os() << ")"; str_ += ")";
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
...@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) { ...@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value; Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value;
if (SupportsAVX512()) { if (SupportsAVX512()) {
os() << "cinn_avx512_set1("; str_ += "cinn_avx512_set1(";
Print(value); IrPrinter::Visit(value);
os() << ")"; str_ += ")";
} else if (SupportsAVX256()) { } else if (SupportsAVX256()) {
os() << "cinn_avx256_set1("; str_ += "cinn_avx256_set1(";
Print(value); IrPrinter::Visit(value);
os() << ")"; str_ += ")";
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
} else { } else {
Print(*op); IrPrinter::Visit(*op);
} }
} }
...@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) { ...@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) {
} }
int bits = op->type().bits() * op->type().lanes(); int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) { if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_" << op->name << "("; str_ += "cinn_avx512_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) { if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]); PrintVecInputArgument(&op->args[i]);
os() << ", "; str_ += ", ";
} }
Print(op->args.back()); IrPrinter::Visit(op->args.back());
} }
os() << ")"; str_ += ")";
} else if (SupportsAVX256() && bits == 256) { } else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_" << op->name << "("; str_ += "cinn_avx256_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) { if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]); PrintVecInputArgument(&op->args[i]);
os() << ", "; str_ += ", ";
} }
PrintVecInputArgument(&op->args.back()); PrintVecInputArgument(&op->args.back());
} }
os() << ")"; str_ += ")";
} else if (bits == 128) { } else if (bits == 128) {
os() << "cinn_avx128_" << op->name << "("; str_ += "cinn_avx128_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) { if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]); PrintVecInputArgument(&op->args[i]);
os() << ", "; str_ += ", ";
} }
PrintVecInputArgument(&op->args.back()); PrintVecInputArgument(&op->args.back());
} }
os() << ")"; str_ += ")";
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
......
...@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC { ...@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC {
template <typename Op> template <typename Op>
void PrintAbsAddr(const Op *op) { void PrintAbsAddr(const Op *op) {
os() << op->tensor.template As<ir::_Tensor_>()->name << " + "; str_ += op->tensor.template As<ir::_Tensor_>()->name;
str_ += " + ";
auto index = op->index(); auto index = op->index();
auto *ramp_n = index.template As<ir::Ramp>(); auto *ramp_n = index.template As<ir::Ramp>();
if (ramp_n) { if (ramp_n) {
CHECK(!ramp_n->base.template As<ir::Ramp>()) CHECK(!ramp_n->base.template As<ir::Ramp>())
<< "base of a Ramp node should not be Ramp type"; << "base of a Ramp node should not be Ramp type";
Print(ramp_n->base); IrPrinter::Visit(ramp_n->base);
} else { } else {
Print(op->index()); IrPrinter::Visit(op->index());
} }
} }
...@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op, ...@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op,
// TODO(Superjomn) Consider support BLAS. // TODO(Superjomn) Consider support BLAS.
int bits = a.type().bits() * a.type().lanes(); int bits = a.type().bits() * a.type().lanes();
if (SupportsAVX512() && bits == 512) { if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_" << op_repr << "("; str_ += "cinn_avx512_";
str_ += op_repr;
str_ += "(";
PrintVecInputArgument(&a); PrintVecInputArgument(&a);
os() << ", "; str_ += ", ";
PrintVecInputArgument(&b); PrintVecInputArgument(&b);
os() << ")"; str_ += ")";
} else if (SupportsAVX256() && bits == 256) { } else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_" << op_repr << "("; str_ += "cinn_avx256_";
str_ += op_repr;
str_ += "(";
PrintVecInputArgument(&a); PrintVecInputArgument(&a);
os() << ", "; str_ += ", ";
PrintVecInputArgument(&b); PrintVecInputArgument(&b);
os() << ")"; str_ += ")";
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
......
...@@ -62,6 +62,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module, ...@@ -62,6 +62,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
CodeGenC::inline_builtin_codes_ = false; CodeGenC::inline_builtin_codes_ = false;
if (!outputs.c_header_name.empty()) { if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader); auto source = Compile(module, OutputKind::CHeader);
str_ = "";
std::ofstream file(outputs.c_header_name); std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name; CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source; file << source;
...@@ -71,6 +72,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module, ...@@ -71,6 +72,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
if (!outputs.cuda_source_name.empty()) { if (!outputs.cuda_source_name.empty()) {
auto source = Compile(module, OutputKind::CImpl); auto source = Compile(module, OutputKind::CImpl);
str_ = "";
std::ofstream file(outputs.cuda_source_name); std::ofstream file(outputs.cuda_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name; CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name;
file << source; file << source;
...@@ -79,9 +81,8 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module, ...@@ -79,9 +81,8 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
} }
} }
std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) { void CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
Print(Expr(func)); IrPrinter::Visit(Expr(func));
return ss_.str();
} }
std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs( std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
...@@ -117,10 +118,10 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs( ...@@ -117,10 +118,10 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
// clear names valid within scope when enter a new function // clear names valid within scope when enter a new function
vectorized_tensor_names_.clear(); vectorized_tensor_names_.clear();
os() << "__global__\n"; str_ += "__global__\n";
PrintFunctionDeclaration(op); PrintFunctionDeclaration(op);
os() << "\n"; str_ += "\n";
DoIndent(); DoIndent();
...@@ -145,15 +146,16 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { ...@@ -145,15 +146,16 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
if (!func_body.As<ir::Block>()) { if (!func_body.As<ir::Block>()) {
func_body = ir::Block::Make({func_body}); func_body = ir::Block::Make({func_body});
} }
Print(func_body); IrPrinter::Visit(func_body);
} }
void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) { void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) {
if (utils::Startswith(op->name, "threadIdx") || if (utils::Startswith(op->name, "threadIdx") ||
utils::Startswith(op->name, "blockIdx")) { utils::Startswith(op->name, "blockIdx")) {
os() << "(int)" + op->name; str_ += "(int)";
str_ += op->name;
} else { } else {
os() << op->name; str_ += op->name;
} }
} }
...@@ -163,58 +165,63 @@ void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) { ...@@ -163,58 +165,63 @@ void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) {
} }
void CodeGenCUDA_Dev::Visit(const ir::Min *op) { void CodeGenCUDA_Dev::Visit(const ir::Min *op) {
os() << "min("; str_ += "min(";
Print(op->a()); IrPrinter::Visit(op->a());
os() << ", "; str_ += ", ";
Print(op->b()); IrPrinter::Visit(op->b());
os() << ")"; str_ += ")";
} }
void CodeGenCUDA_Dev::Visit(const ir::Max *op) { void CodeGenCUDA_Dev::Visit(const ir::Max *op) {
os() << "max("; str_ += "max(";
Print(op->a()); IrPrinter::Visit(op->a());
os() << ", "; str_ += ", ";
Print(op->b()); IrPrinter::Visit(op->b());
os() << ")"; str_ += ")";
} }
void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) { void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
os() << "void "; str_ += "void ";
if (op->cuda_axis_info.valid()) { if (op->cuda_axis_info.valid()) {
int thread_num = 1; int thread_num = 1;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
thread_num *= op->cuda_axis_info.block_dim(i); thread_num *= op->cuda_axis_info.block_dim(i);
} }
os() << "__launch_bounds__(" << thread_num << ") "; str_ += "__launch_bounds__(";
str_ += std::to_string(thread_num);
str_ += ") ";
} }
os() << op->name << "("; str_ += op->name;
str_ += "(";
for (int i = 0; i < op->args.size() - 1; i++) { for (int i = 0; i < op->args.size() - 1; i++) {
auto &arg = op->args[i]; auto &arg = op->args[i];
PrintFuncArg(arg); PrintFuncArg(arg);
os() << ", "; str_ += ", ";
} }
if (!op->args.empty()) { if (!op->args.empty()) {
PrintFuncArg(op->args.back()); PrintFuncArg(op->args.back());
} }
os() << ")"; str_ += ")";
} }
void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) { void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) {
if (arg.is_buffer()) { if (arg.is_buffer()) {
// In CUDA kernel, only primitive type is supported, so we replace the // In CUDA kernel, only primitive type is supported, so we replace the
// buffer with T*j // buffer with T*j
if (arg.is_input()) os() << "const "; if (arg.is_input()) str_ += "const ";
os() << GetTypeRepr(arg.buffer_arg()->dtype); str_ += GetTypeRepr(arg.buffer_arg()->dtype);
os() << "* "; str_ += "* ";
os() << kCKeywordRestrict << " "; str_ += kCKeywordRestrict;
os() << ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>()); str_ += " ";
str_ += ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
} else if (arg.is_var()) { } else if (arg.is_var()) {
if (arg.var_arg()->type().is_cpp_handle()) { if (arg.var_arg()->type().is_cpp_handle()) {
os() << kCKeywordRestrict; str_ += kCKeywordRestrict;
} }
os() << GetTypeRepr(arg.type()) << " "; str_ += GetTypeRepr(arg.type());
os() << arg.name(); str_ += " ";
str_ += arg.name();
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
...@@ -230,7 +237,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, ...@@ -230,7 +237,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
PrintIncludes(); PrintIncludes();
if (for_nvrtc_) { if (for_nvrtc_) {
os() << "\nextern \"C\" {\n\n"; str_ += "\nextern \"C\" {\n\n";
} }
PrintBuiltinCodes(); PrintBuiltinCodes();
...@@ -243,27 +250,30 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, ...@@ -243,27 +250,30 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
} }
if (for_nvrtc_) { if (for_nvrtc_) {
os() << "\n\n}"; str_ += "\n\n}";
} }
return str_;
return ss_.str();
} }
void CodeGenCUDA_Dev::PrintIncludes() { os() << GetSourceHeader(); } void CodeGenCUDA_Dev::PrintIncludes() { str_ += GetSourceHeader(); }
void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) { void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
CHECK_NE(buffer->type(), Void()); CHECK_NE(buffer->type(), Void());
auto print_gpu_memory = [&](const std::string &mark) { auto print_gpu_memory = [&](const std::string &mark) {
os() << mark << GetTypeRepr(buffer->dtype) << " " << buffer->name << " "; str_ += mark;
str_ += GetTypeRepr(buffer->dtype);
str_ += " ";
str_ += buffer->name;
str_ += " ";
os() << "[ "; str_ += "[ ";
Expr buffer_size(1); Expr buffer_size(1);
for (int i = 0; i < buffer->shape.size(); i++) { for (int i = 0; i < buffer->shape.size(); i++) {
buffer_size = buffer_size * buffer->shape[i]; buffer_size = buffer_size * buffer->shape[i];
} }
optim::Simplify(&buffer_size); optim::Simplify(&buffer_size);
Print(buffer_size); IrPrinter::Visit(buffer_size);
os() << " ]"; str_ += " ]";
}; };
switch (buffer->memory_type) { switch (buffer->memory_type) {
case ir::MemoryType::GPUShared: case ir::MemoryType::GPUShared:
...@@ -281,46 +291,47 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) { ...@@ -281,46 +291,47 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
} }
void CodeGenCUDA_Dev::Visit(const ir::Call *op) { void CodeGenCUDA_Dev::Visit(const ir::Call *op) {
os() << op->name + "("; str_ += op->name;
str_ += "(";
if (!op->read_args.empty()) { if (!op->read_args.empty()) {
for (int i = 0; i < op->read_args.size() - 1; i++) { for (int i = 0; i < op->read_args.size() - 1; i++) {
auto &arg = op->read_args[i]; auto &arg = op->read_args[i];
if (arg.as_tensor()) { if (arg.as_tensor()) {
os() << arg.as_tensor()->name; str_ += arg.as_tensor()->name;
os() << ", "; str_ += ", ";
} else { } else {
Print(arg); IrPrinter::Visit(arg);
os() << ", "; str_ += ", ";
} }
} }
if (op->read_args.back().as_tensor()) { if (op->read_args.back().as_tensor()) {
os() << op->read_args.back().as_tensor()->name; str_ += op->read_args.back().as_tensor()->name;
} else { } else {
Print(op->read_args.back()); IrPrinter::Visit(op->read_args.back());
} }
} }
if (!op->write_args.empty()) { if (!op->write_args.empty()) {
os() << ", "; str_ += ", ";
for (int i = 0; i < op->write_args.size() - 1; i++) { for (int i = 0; i < op->write_args.size() - 1; i++) {
auto &arg = op->write_args[i]; auto &arg = op->write_args[i];
if (arg.as_tensor()) { if (arg.as_tensor()) {
os() << arg.as_tensor()->name; str_ += arg.as_tensor()->name;
os() << ", "; str_ += ", ";
} else { } else {
Print(arg); IrPrinter::Visit(arg);
os() << ", "; str_ += ", ";
} }
} }
if (op->write_args.back().as_tensor()) { if (op->write_args.back().as_tensor()) {
os() << op->write_args.back().as_tensor()->name; str_ += op->write_args.back().as_tensor()->name;
} else { } else {
Print(op->write_args.back()); IrPrinter::Visit(op->write_args.back());
} }
} }
os() << ")"; str_ += ")";
} }
void CodeGenCUDA_Dev::Visit(const ir::Let *op) { void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
...@@ -331,20 +342,21 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { ...@@ -331,20 +342,21 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
if (op->type().is_customized() && if (op->type().is_customized() &&
utils::Startswith(op->type().customized_type(), utils::Startswith(op->type().customized_type(),
common::customized_type::kcuda_builtin_vector_t)) { common::customized_type::kcuda_builtin_vector_t)) {
os() << GetTypeRepr(op->type()); str_ += GetTypeRepr(op->type());
if (op->type().is_cpp_handle()) { if (op->type().is_cpp_handle()) {
os() << " " << kCKeywordRestrict; str_ += " ";
str_ += kCKeywordRestrict;
} }
os() << " "; str_ += " ";
Print(op->symbol); IrPrinter::Visit(op->symbol);
vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol)); vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol));
// skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not // skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not
// overloaded. // overloaded.
if (op->body.As<ir::IntImm>() && op->body.As<ir::IntImm>()->value == 0) { if (op->body.As<ir::IntImm>() && op->body.As<ir::IntImm>()->value == 0) {
return; return;
} }
os() << " = "; str_ += " = ";
Print(op->body); IrPrinter::Visit(op->body);
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
...@@ -374,10 +386,14 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ...@@ -374,10 +386,14 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op,
return false; return false;
} }
if (is_store && tensor->type().is_cpp_handle()) { if (is_store && tensor->type().is_cpp_handle()) {
os() << tensor->name << "[" << index << "]"; str_ += tensor->name;
str_ += "[";
str_ += std::to_string(index);
str_ += "]";
} else { } else {
os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") str_ += tensor->name;
<< index2suffix[index]; str_ += (tensor->type().is_cpp_handle() ? "->" : ".");
str_ += index2suffix[index];
} }
return true; return true;
} }
...@@ -396,8 +412,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) { ...@@ -396,8 +412,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) {
// accesses element at a cuda built-in vector, others still resolve to // accesses element at a cuda built-in vector, others still resolve to
// CodeGenC // CodeGenC
if (PrintBuiltinVectorAccess(op, op->index(), true)) { if (PrintBuiltinVectorAccess(op, op->index(), true)) {
os() << " = "; str_ += " = ";
Print(op->value); IrPrinter::Visit(op->value);
} else { } else {
CodeGenC::Visit(op); CodeGenC::Visit(op);
} }
......
...@@ -54,7 +54,7 @@ class CodeGenCUDA_Dev : public CodeGenC { ...@@ -54,7 +54,7 @@ class CodeGenCUDA_Dev : public CodeGenC {
//! Compile on NVRTC. //! Compile on NVRTC.
std::string Compile(const ir::Module& module, bool for_nvrtc = true); std::string Compile(const ir::Module& module, bool for_nvrtc = true);
std::string Compile(const ir::LoweredFunc& func); void Compile(const ir::LoweredFunc& func);
/** /**
* \brief Print a function argument in CUDA syntax. Currently, just some * \brief Print a function argument in CUDA syntax. Currently, just some
......
...@@ -44,14 +44,24 @@ struct Type::Storage { ...@@ -44,14 +44,24 @@ struct Type::Storage {
Type::~Type() {} Type::~Type() {}
std::ostream &operator<<(std::ostream &os, const Type &t) { std::string Type::to_string() const {
if (t.is_cpp_const()) os << "const "; std::string ret = "";
os << Type2Str(t); if (is_cpp_const()) ret += "const ";
ret += Type2Str(*this);
if (lanes() > 1) {
ret += "<";
ret += std::to_string(lanes());
ret += ">";
}
if (is_cpp_handle()) ret += "*";
if (is_cpp_handle2()) ret += "**";
if (t.lanes() > 1) os << "<" << t.lanes() << ">"; return ret;
if (t.is_cpp_handle()) os << "*"; }
if (t.is_cpp_handle2()) os << "**";
std::ostream &operator<<(std::ostream &os, const Type &t) {
os << t.to_string();
return os; return os;
} }
......
...@@ -149,6 +149,8 @@ struct Type { ...@@ -149,6 +149,8 @@ struct Type {
//! Check if a dtype is supported in CINN yet. //! Check if a dtype is supported in CINN yet.
bool is_supported() const; bool is_supported() const;
std::string to_string() const;
friend std::ostream& operator<<(std::ostream& os, const Type& t); friend std::ostream& operator<<(std::ostream& os, const Type& t);
~Type(); ~Type();
......
...@@ -32,83 +32,104 @@ namespace ir { ...@@ -32,83 +32,104 @@ namespace ir {
using common::bfloat16; using common::bfloat16;
using common::float16; using common::float16;
void IrPrinter::Print(Expr e) { IRVisitorRequireReImpl::Visit(&e); } void IrPrinter::Print(const Expr &e) {
IRVisitorRequireReImpl::Visit(&e);
os_ << str_;
str_ = "";
}
void IrPrinter::Print(const std::vector<Expr> &exprs, void IrPrinter::Print(const std::vector<Expr> &exprs,
const std::string &splitter) { const std::string &splitter) {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) { for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
Print(exprs[i]); Visit(exprs[i]);
os_ << splitter; str_ += splitter;
} }
if (!exprs.empty()) Print(exprs.back()); if (!exprs.empty()) Visit(exprs.back());
os_ << str_;
str_ = "";
} }
void IrPrinter::Visit(const IntImm *x) { void IrPrinter::Visit(const IntImm *x) {
if (x->type().is_int(64)) { if (x->type().is_int(64)) {
os_ << x->value << "ll"; str_ += std::to_string(x->value);
str_ += "ll";
} else if (x->type().is_int(32)) { } else if (x->type().is_int(32)) {
os_ << x->value; str_ += std::to_string(x->value);
} else if (x->type().is_int(16)) { } else if (x->type().is_int(16)) {
os_ << "(int16_t)" << x->value; str_ += "(int16_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_int(8)) { } else if (x->type().is_int(8)) {
os_ << "(int8_t)" << x->value; str_ += "(int8_t)";
str_ += std::to_string(x->value);
} else { } else {
LOG(FATAL) << "Not support int type: " << x->type(); LOG(FATAL) << "Not support int type: " << x->type();
} }
} }
void IrPrinter::Visit(const UIntImm *x) { void IrPrinter::Visit(const UIntImm *x) {
if (x->type().is_uint(64)) { if (x->type().is_uint(64)) {
os_ << x->value << "ull"; str_ += std::to_string(x->value);
str_ += "ull";
} else if (x->type().is_uint(32)) { } else if (x->type().is_uint(32)) {
os_ << x->value; str_ += std::to_string(x->value);
} else if (x->type().is_uint(16)) { } else if (x->type().is_uint(16)) {
os_ << "(uint16_t)" << x->value; str_ += "(uint16_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_uint(8)) { } else if (x->type().is_uint(8)) {
os_ << "(uint8_t)" << x->value; str_ += "(uint8_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_uint(1)) { } else if (x->type().is_uint(1)) {
if (x->value) { if (x->value) {
os_ << "true"; str_ += "true";
} else { } else {
os_ << "false"; str_ += "false";
} }
} else { } else {
LOG(FATAL) << "Not support uint type: " << x->type(); LOG(FATAL) << "Not support uint type: " << x->type();
} }
} }
void IrPrinter::Visit(const FloatImm *x) { void IrPrinter::Visit(const FloatImm *x) {
std::ostringstream ss;
if (x->type().is_float16()) { if (x->type().is_float16()) {
if (std::isinf(x->value)) { if (std::isinf(x->value)) {
os_ << "cinn::common::raw_uint16_to_float16(0x7c00)"; ss << "cinn::common::raw_uint16_to_float16(0x7c00)";
} else if (std::isnan(x->value)) { } else if (std::isnan(x->value)) {
os_ << "cinn::common::raw_uint16_to_float16(0x7e00)"; ss << "cinn::common::raw_uint16_to_float16(0x7e00)";
} else { } else {
os_ << "(float16)" ss << "(float16)";
<< std::setprecision(std::numeric_limits<float16>::max_digits10) ss << std::setprecision(std::numeric_limits<float16>::max_digits10);
<< static_cast<float16>(x->value) << "f"; ss << static_cast<float16>(x->value) << "f";
} }
} else if (x->type().is_bfloat16()) { } else if (x->type().is_bfloat16()) {
if (std::isinf(x->value)) { if (std::isinf(x->value)) {
os_ << "cinn::common::raw_uint16_to_bfloat16(0x7F80)"; ss << "cinn::common::raw_uint16_to_bfloat16(0x7F80)";
} else if (std::isnan(x->value)) { } else if (std::isnan(x->value)) {
os_ << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)"; ss << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)";
} else { } else {
os_ << "(bfloat16)" ss << "(bfloat16)";
<< std::setprecision(std::numeric_limits<bfloat16>::max_digits10) ss << std::setprecision(std::numeric_limits<bfloat16>::max_digits10);
<< static_cast<bfloat16>(x->value) << "f"; ss << static_cast<bfloat16>(x->value) << "f";
} }
} else if (x->type().is_float(32)) { } else if (x->type().is_float(32)) {
os_ << std::setprecision(std::numeric_limits<float>::max_digits10) ss << std::setprecision(std::numeric_limits<float>::max_digits10);
<< std::showpoint << x->value; ss << std::showpoint;
ss << x->value;
if (std::isfinite(x->value)) { if (std::isfinite(x->value)) {
os_ << "f"; ss << "f";
} }
} else if (x->type().is_float(64)) { } else if (x->type().is_float(64)) {
os_ << std::setprecision(std::numeric_limits<double>::max_digits10) ss << std::setprecision(std::numeric_limits<double>::max_digits10);
<< std::showpoint << x->value; ss << std::showpoint;
ss << x->value;
} else { } else {
LOG(FATAL) << "Not support float type: " << x->type(); LOG(FATAL) << "Not support float type: " << x->type();
} }
str_ += ss.str();
}
void IrPrinter::Visit(const StringImm *x) {
str_ += "\"";
str_ += x->value;
str_ += "\"";
} }
void IrPrinter::Visit(const StringImm *x) { os_ << "\"" << x->value << "\""; }
void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); } void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); }
void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); } void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); }
void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); } void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); }
...@@ -123,36 +144,38 @@ void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); } ...@@ -123,36 +144,38 @@ void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); }
void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); } void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); }
void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); } void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); }
void IrPrinter::Visit(const Not *x) { void IrPrinter::Visit(const Not *x) {
os_ << "!"; str_ += "!";
Print(x->v()); Visit(x->v());
} }
void IrPrinter::Visit(const Min *x) { void IrPrinter::Visit(const Min *x) {
os_ << "cinn_min("; str_ += "cinn_min(";
Print(x->a()); Visit(x->a());
os_ << ", "; str_ += ", ";
Print(x->b()); Visit(x->b());
os_ << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Max *x) { void IrPrinter::Visit(const Max *x) {
os_ << "cinn_max("; str_ += "cinn_max(";
Print(x->a()); Visit(x->a());
os_ << ", "; str_ += ", ";
Print(x->b()); Visit(x->b());
os_ << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Minus *x) { void IrPrinter::Visit(const Minus *x) {
os_ << "-("; str_ += "-(";
Print(x->v()); Visit(x->v());
os_ << ")"; str_ += ")";
} }
void IrPrinter::Visit(const For *x) { void IrPrinter::Visit(const For *x) {
if (x->is_parallel()) { if (x->is_parallel()) {
os() << "parallel for ("; str_ += "parallel for (";
} else if (x->is_unrolled()) { } else if (x->is_unrolled()) {
os() << "unroll for ("; str_ += "unroll for (";
} else if (x->is_vectorized()) { } else if (x->is_vectorized()) {
int factor = x->vectorize_info().factor; int factor = x->vectorize_info().factor;
os() << "vectorize[" << factor << "] for ("; str_ += "vectorize[";
str_ += std::to_string(factor);
str_ += "] for (";
} else if (x->is_binded()) { } else if (x->is_binded()) {
auto &bind_info = x->bind_info(); auto &bind_info = x->bind_info();
if (bind_info.valid()) { if (bind_info.valid()) {
...@@ -160,181 +183,189 @@ void IrPrinter::Visit(const For *x) { ...@@ -160,181 +183,189 @@ void IrPrinter::Visit(const For *x) {
auto for_type = bind_info.for_type; auto for_type = bind_info.for_type;
std::string prefix = std::string prefix =
for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx."; for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx.";
os() << "thread_bind[" << prefix << axis_name << "] for ("; str_ += "thread_bind[";
str_ += prefix;
str_ += axis_name;
str_ += "] for (";
} else { } else {
os() << "thread_bind[invalid info] for ("; str_ += "thread_bind[invalid info] for (";
} }
} else if (x->is_serial()) { } else if (x->is_serial()) {
os() << "serial for ("; str_ += "serial for (";
} else if (x->is_default()) { } else if (x->is_default()) {
os() << "default for ("; str_ += "default for (";
} else { } else {
os() << "for ("; str_ += "for (";
} }
Print(x->loop_var); Visit(x->loop_var);
os_ << ", "; str_ += ", ";
Print(x->min); Visit(x->min);
os_ << ", "; str_ += ", ";
Print(x->extent); Visit(x->extent);
os_ << ")\n"; str_ += ")\n";
DoIndent(); DoIndent();
Print(x->body); Visit(x->body);
} }
void IrPrinter::Visit(const PolyFor *x) { void IrPrinter::Visit(const PolyFor *x) {
if (x->is_parallel()) { if (x->is_parallel()) {
os() << "parallel poly_for ("; str_ += "parallel poly_for (";
} else { } else {
os() << "poly_for ("; str_ += "poly_for (";
} }
Print(x->iterator); Visit(x->iterator);
os_ << ", "; str_ += ", ";
Print(x->init); Visit(x->init);
os_ << ", "; str_ += ", ";
Print(x->condition); Visit(x->condition);
os_ << ", "; str_ += ", ";
Print(x->inc); Visit(x->inc);
os_ << ")\n"; str_ += ")\n";
DoIndent(); DoIndent();
Print(x->body); Visit(x->body);
} }
void IrPrinter::Visit(const IfThenElse *x) { void IrPrinter::Visit(const IfThenElse *x) {
os_ << "if ("; str_ += "if (";
Print(x->condition); Visit(x->condition);
os_ << ") {\n"; str_ += ") {\n";
IncIndent(); IncIndent();
DoIndent(); DoIndent();
Print(x->true_case); Visit(x->true_case);
DecIndent(); DecIndent();
os() << "\n"; str_ += "\n";
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
if (x->false_case.defined()) { if (x->false_case.defined()) {
os_ << " else {\n"; str_ += " else {\n";
IncIndent(); IncIndent();
DoIndent(); DoIndent();
Print(x->false_case); Visit(x->false_case);
os() << "\n"; str_ += "\n";
DecIndent(); DecIndent();
DoIndent(); DoIndent();
os_ << "}"; str_ += "}";
} }
} }
void IrPrinter::Visit(const Block *x) { void IrPrinter::Visit(const Block *x) {
os_ << "{\n"; str_ += "{\n";
IncIndent(); IncIndent();
for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) { for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) {
DoIndent(); DoIndent();
Print(x->stmts[i]); Visit(x->stmts[i]);
os_ << "\n"; str_ += "\n";
} }
if (!x->stmts.empty()) { if (!x->stmts.empty()) {
DoIndent(); DoIndent();
Print(x->stmts.back()); Visit(x->stmts.back());
} }
DecIndent(); DecIndent();
os_ << "\n"; str_ += "\n";
DoIndent(); DoIndent();
os_ << "}"; str_ += "}";
} }
void IrPrinter::Visit(const Call *x) { void IrPrinter::Visit(const Call *x) {
os_ << x->name << "("; str_ += x->name;
str_ += "(";
if (!x->read_args.empty()) { if (!x->read_args.empty()) {
for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) { for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) {
Print(x->read_args[i]); Visit(x->read_args[i]);
os_ << ", "; str_ += ", ";
} }
Print(x->read_args.back()); Visit(x->read_args.back());
} }
if (!x->write_args.empty()) { if (!x->write_args.empty()) {
if (!x->read_args.empty()) os() << ", "; if (!x->read_args.empty()) str_ += ", ";
for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) { for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) {
Print(x->write_args[i]); Visit(x->write_args[i]);
os_ << ", "; str_ += ", ";
} }
Print(x->write_args.back()); Visit(x->write_args.back());
} }
os_ << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Cast *x) { void IrPrinter::Visit(const Cast *x) {
os() << x->type(); str_ += x->type().to_string();
os() << "("; str_ += "(";
os() << x->v(); Visit(x->v());
os() << ")"; str_ += ")";
} }
// Module nodes produce no output from this visitor.
void IrPrinter::Visit(const _Module_ *x) {
}
// A variable is printed as its bare name.
void IrPrinter::Visit(const _Var_ *x) {
  str_.append(x->name);
}
void IrPrinter::Visit(const Alloc *x) { void IrPrinter::Visit(const Alloc *x) {
auto *buffer = x->destination.As<ir::_Buffer_>(); auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer); CHECK(buffer);
os_ << "alloc(" << buffer->name << ", "; str_ += "alloc(";
Print(x->extents); str_ += buffer->name;
os_ << ")"; str_ += ", ";
Visit(x->extents);
str_ += ")";
} }
void IrPrinter::Visit(const Select *x) { void IrPrinter::Visit(const Select *x) {
os_ << "select("; str_ += "select(";
Print(x->condition); Visit(x->condition);
os_ << ", "; str_ += ", ";
Print(x->true_value); Visit(x->true_value);
os_ << ", "; str_ += ", ";
Print(x->false_value); Visit(x->false_value);
os_ << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Load *x) { void IrPrinter::Visit(const Load *x) {
if (x->is_addr_tensor()) { if (x->is_addr_tensor()) {
auto *tensor = x->tensor.As<ir::_Tensor_>(); auto *tensor = x->tensor.As<ir::_Tensor_>();
CHECK(tensor); CHECK(tensor);
os_ << tensor->name; str_ += tensor->name;
} else if (x->is_addr_scalar()) { } else if (x->is_addr_scalar()) {
Print(x->tensor); Visit(x->tensor);
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
os_ << "["; str_ += "[";
for (std::size_t i = 0; i + 1 < x->indices.size(); i++) { for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
Print(x->indices[i]); Visit(x->indices[i]);
os() << ", "; str_ += ", ";
} }
if (!x->indices.empty()) Print(x->indices.back()); if (!x->indices.empty()) Visit(x->indices.back());
os_ << "]"; str_ += "]";
} }
void IrPrinter::Visit(const Store *x) { void IrPrinter::Visit(const Store *x) {
if (x->is_addr_tensor()) { if (x->is_addr_tensor()) {
auto *tensor_node = x->tensor.As<ir::_Tensor_>(); auto *tensor_node = x->tensor.As<ir::_Tensor_>();
CHECK(tensor_node); CHECK(tensor_node);
os_ << tensor_node->name; str_ += tensor_node->name;
} else if (x->is_addr_scalar()) { } else if (x->is_addr_scalar()) {
Print(x->tensor); Visit(x->tensor);
} else { } else {
CINN_NOT_IMPLEMENTED CINN_NOT_IMPLEMENTED
} }
os_ << "["; str_ += "[";
for (std::size_t i = 0; i + 1 < x->indices.size(); i++) { for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
Print(x->indices[i]); Visit(x->indices[i]);
os() << ", "; str_ += ", ";
} }
if (!x->indices.empty()) Print(x->indices.back()); if (!x->indices.empty()) Visit(x->indices.back());
os_ << "] = "; str_ += "] = ";
Print(x->value); Visit(x->value);
} }
void IrPrinter::Visit(const Free *x) { void IrPrinter::Visit(const Free *x) {
auto *buffer = x->destination.As<ir::_Buffer_>(); auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer); CHECK(buffer);
os_ << "free(" << buffer->name << ")"; str_ += "free(";
str_ += buffer->name;
str_ += ")";
} }
// Appends `indent_` space characters to the output.
void IrPrinter::DoIndent() {
  str_.append(indent_, ' ');
}
// Increases the current indentation by one unit.
void IrPrinter::IncIndent() {
  indent_ += indent_unit;
}
// Decreases the current indentation by one unit.
void IrPrinter::DecIndent() {
  indent_ -= indent_unit;
}
...@@ -345,126 +376,139 @@ void IrPrinter::Visit(const _Buffer_ *x) { ...@@ -345,126 +376,139 @@ void IrPrinter::Visit(const _Buffer_ *x) {
std::back_inserter(dim_names), std::back_inserter(dim_names),
[&](const Expr &x) { return utils::GetStreamCnt(x); }); [&](const Expr &x) { return utils::GetStreamCnt(x); });
os_ << "_Buffer_<" << x->type() << ": " << utils::Join(dim_names, ",") << ">(" str_ += "_Buffer_<";
<< x->name << ")";
str_ += x->type().to_string();
str_ += ": ";
str_ += utils::Join(dim_names, ",");
str_ += ">(";
str_ += x->name;
str_ += ")";
} }
void IrPrinter::Visit(const _Tensor_ *x) { void IrPrinter::Visit(const _Tensor_ *x) {
os_ << "Tensor("; str_ += "Tensor(";
os() << x->name << ", "; str_ += x->name;
os() << "["; str_ += ", ";
str_ += "[";
if (!x->shape.empty()) { if (!x->shape.empty()) {
for (std::size_t i = 0; i + 1 < x->shape.size(); i++) { for (std::size_t i = 0; i + 1 < x->shape.size(); i++) {
Print(x->shape[i]); Visit(x->shape[i]);
os() << ","; str_ += ",";
} }
Print(x->shape.back()); Visit(x->shape.back());
} }
os_ << "])"; str_ += "])";
} }
void IrPrinter::Visit(const _LoweredFunc_ *f) { void IrPrinter::Visit(const _LoweredFunc_ *f) {
os_ << "function " << f->name << " "; str_ += "function ";
str_ += f->name;
str_ += " ";
std::vector<std::string> arg_names; std::vector<std::string> arg_names;
for (auto &arg : f->args) { for (auto &arg : f->args) {
arg_names.push_back(arg.name()); arg_names.push_back(arg.name());
} }
os_ << "(" << utils::Join(arg_names, ", ") << ")\n"; str_ += "(";
str_ += utils::Join(arg_names, ", ");
str_ += ")\n";
Print(f->body); Visit(f->body);
} }
void IrPrinter::Visit(const Let *f) { void IrPrinter::Visit(const Let *f) {
CHECK(f->type().valid()); CHECK(f->type().valid());
os() << f->type() << " "; str_ += f->type().to_string();
Print(f->symbol); str_ += " ";
Visit(f->symbol);
if (f->body.defined()) { if (f->body.defined()) {
os() << " = "; str_ += " = ";
Print(f->body); Visit(f->body);
} }
} }
void IrPrinter::Visit(const Reduce *f) { void IrPrinter::Visit(const Reduce *f) {
os() << "Reduce("; str_ += "Reduce(";
switch (f->reduce_type) { switch (f->reduce_type) {
case Reduce::ReduceType::kSum: case Reduce::ReduceType::kSum:
os() << "sum"; str_ += "sum";
break; break;
case Reduce::ReduceType::kSub: case Reduce::ReduceType::kSub:
os() << "sub"; str_ += "sub";
break; break;
case Reduce::ReduceType::kDiv: case Reduce::ReduceType::kDiv:
os() << "Div"; str_ += "Div";
break; break;
case Reduce::ReduceType::kMul: case Reduce::ReduceType::kMul:
os() << "Mul"; str_ += "Mul";
break; break;
case Reduce::ReduceType::kMax: case Reduce::ReduceType::kMax:
os() << "Max"; str_ += "Max";
break; break;
case Reduce::ReduceType::kMin: case Reduce::ReduceType::kMin:
os() << "Min"; str_ += "Min";
break; break;
case Reduce::ReduceType::kAll: case Reduce::ReduceType::kAll:
os() << "&&"; str_ += "&&";
break; break;
case Reduce::ReduceType::kAny: case Reduce::ReduceType::kAny:
os() << "||"; str_ += "||";
break; break;
} }
os() << ", "; str_ += ", ";
Print(f->body); Visit(f->body);
os() << ","; str_ += ",";
Print(f->init); Visit(f->init);
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Ramp *x) { void IrPrinter::Visit(const Ramp *x) {
os() << "Ramp("; str_ += "Ramp(";
Print(x->base); Visit(x->base);
os() << ","; str_ += ",";
Print(x->stride); Visit(x->stride);
os() << ","; str_ += ",";
os() << x->lanes; str_ += std::to_string(x->lanes);
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Broadcast *x) { void IrPrinter::Visit(const Broadcast *x) {
os() << "Broadcast("; str_ += "Broadcast(";
Print(x->value); Visit(x->value);
os() << ","; str_ += ",";
os() << x->lanes; str_ += std::to_string(x->lanes);
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const FracOp *x) { void IrPrinter::Visit(const FracOp *x) {
os() << "("; str_ += "(";
Print(x->a()); Visit(x->a());
os() << " / "; str_ += " / ";
Print(x->b()); Visit(x->b());
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Product *x) { void IrPrinter::Visit(const Product *x) {
os() << "("; str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) { for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Print(x->operand(i)); Visit(x->operand(i));
os() << " * "; str_ += " * ";
} }
if (!x->operands().empty()) Print(x->operands().back()); if (!x->operands().empty()) Visit(x->operands().back());
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const Sum *x) { void IrPrinter::Visit(const Sum *x) {
os() << "("; str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) { for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Print(x->operand(i)); Visit(x->operand(i));
os() << " + "; str_ += " + ";
} }
if (!x->operands().empty()) Print(x->operands().back()); if (!x->operands().empty()) Visit(x->operands().back());
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const PrimitiveNode *x) { void IrPrinter::Visit(const PrimitiveNode *x) {
os() << x->name << "("; str_ += x->name;
str_ += "(";
std::vector<std::string> args_repr; std::vector<std::string> args_repr;
for (auto &args : x->arguments) { for (auto &args : x->arguments) {
std::vector<std::string> arg_repr; std::vector<std::string> arg_repr;
...@@ -474,41 +518,46 @@ void IrPrinter::Visit(const PrimitiveNode *x) { ...@@ -474,41 +518,46 @@ void IrPrinter::Visit(const PrimitiveNode *x) {
args_repr.push_back(utils::Join(arg_repr, ",")); args_repr.push_back(utils::Join(arg_repr, ","));
} }
os() << utils::Join(args_repr, ","); str_ += utils::Join(args_repr, ",");
os() << ")"; str_ += ")";
} }
void IrPrinter::Visit(const _BufferRange_ *x) { void IrPrinter::Visit(const _BufferRange_ *x) {
auto *buffer = x->buffer.As<ir::_Buffer_>(); auto *buffer = x->buffer.As<ir::_Buffer_>();
CHECK(buffer); CHECK(buffer);
os() << buffer->name << "["; str_ += buffer->name;
str_ += "[";
for (std::size_t i = 0; i < x->ranges.size(); i++) { for (std::size_t i = 0; i < x->ranges.size(); i++) {
if (i) os() << ", "; if (i) str_ += ", ";
auto &range = x->ranges[i]; auto &range = x->ranges[i];
os() << range->name << "("; str_ += range->name;
str_ += "(";
if (range->lower_bound.defined()) { if (range->lower_bound.defined()) {
os() << range->lower_bound << ":"; Visit(range->lower_bound);
str_ += ":";
} else { } else {
os() << "undefined:"; str_ += "undefined:";
} }
if (range->upper_bound.defined()) { if (range->upper_bound.defined()) {
os() << range->upper_bound; Visit(range->upper_bound);
} else { } else {
os() << "undefined"; str_ += "undefined";
} }
os() << ")"; str_ += ")";
} }
os() << "]"; str_ += "]";
} }
// A bare ScheduleBlock produces no output from this overload.
void IrPrinter::Visit(const ScheduleBlock *x) {
}
void IrPrinter::Visit(const ScheduleBlockRealize *x) { void IrPrinter::Visit(const ScheduleBlockRealize *x) {
auto *schedule_block = x->schedule_block.As<ScheduleBlock>(); auto *schedule_block = x->schedule_block.As<ScheduleBlock>();
os() << "ScheduleBlock(" << schedule_block->name << ")\n"; str_ += "ScheduleBlock(";
str_ += schedule_block->name;
str_ += ")\n";
DoIndent(); DoIndent();
os() << "{\n"; str_ += "{\n";
// print block vars and bindings // print block vars and bindings
auto iter_vars = schedule_block->iter_vars; auto iter_vars = schedule_block->iter_vars;
auto iter_values = x->iter_values; auto iter_values = x->iter_values;
...@@ -516,54 +565,61 @@ void IrPrinter::Visit(const ScheduleBlockRealize *x) { ...@@ -516,54 +565,61 @@ void IrPrinter::Visit(const ScheduleBlockRealize *x) {
IncIndent(); IncIndent();
if (!iter_vars.empty()) DoIndent(); if (!iter_vars.empty()) DoIndent();
for (std::size_t i = 0; i < iter_vars.size(); i++) { for (std::size_t i = 0; i < iter_vars.size(); i++) {
if (i) os() << ", "; if (i) str_ += ", ";
os() << iter_vars[i]->name; str_ += iter_vars[i]->name;
} }
if (!iter_vars.empty()) os() << " = axis.bind("; if (!iter_vars.empty()) str_ += " = axis.bind(";
for (std::size_t i = 0; i < iter_values.size(); i++) { for (std::size_t i = 0; i < iter_values.size(); i++) {
if (i) os() << ", "; if (i) str_ += ", ";
os() << iter_values[i]; Visit(iter_values[i]);
} }
if (!iter_vars.empty()) os() << ")\n"; if (!iter_vars.empty()) str_ += ")\n";
// print block body // print block body
if (!schedule_block->read_buffers.empty()) { if (!schedule_block->read_buffers.empty()) {
DoIndent(); DoIndent();
os() << "read_buffers("; str_ += "read_buffers(";
auto &read_buffers = schedule_block->read_buffers; auto &read_buffers = schedule_block->read_buffers;
for (std::size_t i = 0; i < read_buffers.size(); i++) { for (std::size_t i = 0; i < read_buffers.size(); i++) {
if (i) os() << ", "; if (i) str_ += ", ";
Print(read_buffers[i]); Visit(read_buffers[i]);
} }
os() << ")\n"; str_ += ")\n";
} }
if (!schedule_block->write_buffers.empty()) { if (!schedule_block->write_buffers.empty()) {
DoIndent(); DoIndent();
os() << "write_buffers("; str_ += "write_buffers(";
auto &write_buffers = schedule_block->write_buffers; auto &write_buffers = schedule_block->write_buffers;
for (std::size_t i = 0; i < write_buffers.size(); i++) { for (std::size_t i = 0; i < write_buffers.size(); i++) {
if (i) os() << ", "; if (i) str_ += ", ";
Print(write_buffers[i]); Visit(write_buffers[i]);
} }
os() << ")\n"; str_ += ")\n";
} }
if (!schedule_block->attrs.empty()) { if (!schedule_block->attrs.empty()) {
DoIndent(); DoIndent();
os() << "attrs("; str_ += "attrs(";
bool comma = false; bool comma = false;
for (auto &&kv : schedule_block->attrs) { for (auto &&kv : schedule_block->attrs) {
if (comma) os() << ", "; if (comma) str_ += ", ";
os() << kv.first << ":"; str_ += kv.first;
absl::visit([this](auto &&arg) { this->os() << arg; }, kv.second); str_ += ":";
absl::visit(
[this](auto &&arg) {
std::ostringstream ss;
ss << arg;
this->str_ += ss.str();
},
kv.second);
comma = true; comma = true;
} }
os() << ")\n"; str_ += ")\n";
} }
DoIndent(); DoIndent();
Print(schedule_block->body); Visit(schedule_block->body);
os() << "\n"; str_ += "\n";
DecIndent(); DecIndent();
DoIndent(); DoIndent();
os() << "}"; str_ += "}";
} }
void IrPrinter::Visit(const IntrinsicOp *x) { void IrPrinter::Visit(const IntrinsicOp *x) {
...@@ -578,50 +634,52 @@ void IrPrinter::Visit(const IntrinsicOp *x) { ...@@ -578,50 +634,52 @@ void IrPrinter::Visit(const IntrinsicOp *x) {
} }
} }
// Prints a BufferGetDataHandle intrinsic as `<intrinsic name>(<buffer>)`.
// The opening parenthesis is emitted explicitly so that it pairs with the
// closing one; without it the printed text has unbalanced parentheses. This
// matches how the other intrinsic printers in this file (GetAddr,
// PodValueToX, ArgsConstruct, BuiltinIntrin) wrap their arguments.
void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_handle;
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Prints a BufferGetDataConstHandle intrinsic as
// `<intrinsic name>(<buffer>)`. The opening parenthesis is emitted explicitly
// so that it pairs with the closing one; without it the printed text has
// unbalanced parentheses. This matches how the other intrinsic printers in
// this file (GetAddr, PodValueToX, ArgsConstruct, BuiltinIntrin) wrap their
// arguments.
void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) {
  str_ += runtime::intrinsic::buffer_get_data_const_handle;
  str_ += "(";
  Visit(x->buffer);
  str_ += ")";
}
// Prints a PodValueToX intrinsic as `pod_value_to_<output type>(<pod ptr>)`,
// where the type is the intrinsic's first output type.
void IrPrinter::Visit(const intrinsics::PodValueToX *x) {
  str_.append("pod_value_to_");
  str_.append(x->GetOutputType(0).to_string());
  str_.push_back('(');
  Visit(x->pod_value_ptr);
  str_.push_back(')');
}
// Prints a BufferCreate intrinsic as `<intrinsic name>()`; no arguments are
// printed.
void IrPrinter::Visit(const intrinsics::BufferCreate *x) {
  str_.append(runtime::intrinsic::buffer_create);
  str_.append("()");
}
// Prints a GetAddr intrinsic as `get_addr(<data>)`.
void IrPrinter::Visit(const intrinsics::GetAddr *x) {
  str_.append("get_addr(");
  Visit(x->data);
  str_.push_back(')');
}
// Prints an ArgsConstruct intrinsic as `<repr name>(<arg0>, <arg1>, ...)`.
// The arguments are copied into a temporary vector of Expr so the list
// overload of Visit can join them with its default ", " separator.
void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) {
  str_.append(runtime::intrinsic::args_construct_repr);
  str_.push_back('(');
  Visit(std::vector<Expr>(x->args.begin(), x->args.end()));
  str_.push_back(')');
}
// Prints a builtin intrinsic as `<repr prefix>_<name>(<arg0>, <arg1>, ...)`.
void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) {
  str_.append(runtime::intrinsic::builtin_intrin_repr);
  str_.push_back('_');
  str_.append(x->name);
  str_.push_back('(');

  const std::size_t arg_count = x->args.size();
  for (std::size_t pos = 0; pos < arg_count; ++pos) {
    if (pos != 0) {
      str_.append(", ");
    }
    Visit(x->args[pos]);
  }

  str_.push_back(')');
}
std::ostream &operator<<(std::ostream &os, Expr a) { std::ostream &operator<<(std::ostream &os, Expr a) {
......
...@@ -30,10 +30,10 @@ namespace ir { ...@@ -30,10 +30,10 @@ namespace ir {
class Module; class Module;
struct IrPrinter : public IRVisitorRequireReImpl<void> { struct IrPrinter : public IRVisitorRequireReImpl<void> {
explicit IrPrinter(std::ostream &os) : os_(os) {} explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}
//! Emit an expression on the output stream. //! Emit an expression on the output stream.
void Print(Expr e); void Print(const Expr &e);
//! Emit a expression list with , splitted. //! Emit a expression list with , splitted.
void Print(const std::vector<Expr> &exprs, void Print(const std::vector<Expr> &exprs,
const std::string &splitter = ", "); const std::string &splitter = ", ");
...@@ -50,6 +50,17 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> { ...@@ -50,6 +50,17 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
std::ostream &os() { return os_; } std::ostream &os() { return os_; }
void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
void Visit(const std::vector<Expr> &exprs,
const std::string &splitter = ", ") {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
Visit(exprs[i]);
str_ += splitter;
}
if (!exprs.empty()) Visit(exprs.back());
}
#define __(op__) void Visit(const op__ *x) override; #define __(op__) void Visit(const op__ *x) override;
NODETY_FORALL(__) NODETY_FORALL(__)
#undef __ #undef __
...@@ -58,6 +69,9 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> { ...@@ -58,6 +69,9 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
INTRINSIC_KIND_FOR_EACH(__) INTRINSIC_KIND_FOR_EACH(__)
#undef __ #undef __
protected:
std::string str_;
private: private:
std::ostream &os_; std::ostream &os_;
uint16_t indent_{}; uint16_t indent_{};
...@@ -71,11 +85,13 @@ std::ostream &operator<<(std::ostream &os, const Module &m); ...@@ -71,11 +85,13 @@ std::ostream &operator<<(std::ostream &os, const Module &m);
template <typename IRN> template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op, void IrPrinter::PrintBinaryOp(const std::string &op,
const BinaryOpNode<IRN> *x) { const BinaryOpNode<IRN> *x) {
os_ << "("; str_ += "(";
Print(x->a()); Visit(x->a());
os_ << " " + op + " "; str_ += " ";
Print(x->b()); str_ += op;
os_ << ")"; str_ += " ";
Visit(x->b());
str_ += ")";
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册