未验证 提交 68b0cf92 编写于 作者: 傅剑寒 提交者: GitHub

【CINN】refactor codegen for cinn (#55955)

* refactor codegen for cinn

* add to_string to some type which can't be += with string

* fix multi-thread bug caused by static var

* delete dead code and comment
上级 db96ae58
......@@ -43,6 +43,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader);
str_ = "";
std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source;
......@@ -52,6 +53,7 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
if (!outputs.c_source_name.empty()) {
auto source = Compile(module, OutputKind::CImpl);
str_ = "";
std::ofstream file(outputs.c_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name;
file << source;
......@@ -71,24 +73,20 @@ std::string CodeGenC::Compile(const ir::Module &module,
if (inline_builtin_codes_) PrintBuiltinCodes();
std::vector<ir::Buffer> buffers;
for (auto &buffer : module->buffers) {
buffers.emplace_back(buffer.as_buffer_ref());
}
for (auto &func : module.functions()) {
Compile(func);
}
} else {
LOG(FATAL) << "Not supported OutputKind";
}
return ss_.str();
return str_;
}
std::string CodeGenC::Compile(const ir::LoweredFunc &function) {
// TODO(LiuYang): Here the Ret type seems unuseful
void CodeGenC::Compile(const ir::LoweredFunc &function) {
CHECK(function.defined());
Print(function);
os() << "\n\n";
return ss_.str();
IrPrinter::Visit(function);
str_ += "\n\n";
}
std::string CodeGenC::GetTypeName(Type type) {
......@@ -164,11 +162,11 @@ void CodeGenC::Visit(const ir::Mod *op) {
if (copied.is_constant()) {
int temp = static_cast<int>(copied.get_constant());
if ((temp & (temp - 1)) == 0) {
os() << "(";
Print(op->a());
os() << " & ";
os() << std::to_string(temp - 1);
os() << ")";
str_ += "(";
IrPrinter::Visit(op->a());
str_ += " & ";
str_ += std::to_string(temp - 1);
str_ += ")";
return;
}
}
......@@ -186,9 +184,9 @@ void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Not *op) {
os() << "(!";
IrPrinter::Print(op->v());
os() << ")";
str_ += "(!";
IrPrinter::Visit(op->v());
str_ += ")";
}
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::For *op) {
......@@ -196,17 +194,17 @@ void CodeGenC::Visit(const ir::For *op) {
Expr min = op->min;
int num_task = 1;
if (op->is_parallel()) {
os() << "int num_task = max_concurrency();\n";
str_ += "int num_task = max_concurrency();\n";
DoIndent();
os() << "omp_set_num_threads(num_task);\n";
str_ += "omp_set_num_threads(num_task);\n";
DoIndent();
os() << "auto flambda = [=](int task_id, int num_task) -> int {\n";
str_ += "auto flambda = [=](int task_id, int num_task) -> int {\n";
IncIndent();
DoIndent();
os() << "int n_per_task = ";
str_ += "int n_per_task = ";
Expr num_task_var = Var("num_task");
Print((op->extent + num_task_var - 1) / num_task_var);
os() << ";\n";
IrPrinter::Visit((op->extent + num_task_var - 1) / num_task_var);
str_ += ";\n";
CHECK_EQ(min.as_int32(), 0);
auto task_id = Var("task_id");
auto n_per_task = Var("n_per_task");
......@@ -214,125 +212,127 @@ void CodeGenC::Visit(const ir::For *op) {
extent = (task_id + 1) * n_per_task;
DoIndent();
}
os() << "for (";
os() << GetTypeRepr(Int(32));
os() << " " << op->loop_var->name;
os() << " = ";
Print(min);
os() << "; ";
os() << op->loop_var->name;
os() << " < ";
Print(op->extent);
str_ += "for (";
str_ += GetTypeRepr(Int(32));
str_ += " ";
str_ += op->loop_var->name;
str_ += " = ";
IrPrinter::Visit(min);
str_ += "; ";
str_ += op->loop_var->name;
str_ += " < ";
IrPrinter::Visit(op->extent);
if (op->is_parallel()) {
os() << " && ";
os() << op->loop_var->name;
os() << " < ";
Print(extent);
str_ += " && ";
str_ += op->loop_var->name;
str_ += " < ";
IrPrinter::Visit(extent);
}
os() << "; ";
str_ += "; ";
os() << op->loop_var->name;
os() << " += 1";
os() << ") ";
str_ += op->loop_var->name;
str_ += " += 1";
str_ += ") ";
Print(op->body);
IrPrinter::Visit(op->body);
if (op->is_parallel()) {
os() << "\n";
str_ += "\n";
DoIndent();
os() << "return 0;\n";
str_ += "return 0;\n";
DecIndent();
DoIndent();
os() << "};\n";
os() << "#pragma omp parallel num_threads(num_task)\n";
str_ += "};\n";
str_ += "#pragma omp parallel num_threads(num_task)\n";
DoIndent();
os() << "{\n";
str_ += "{\n";
IncIndent();
DoIndent();
os() << "int task_id = omp_get_thread_num();\n";
str_ += "int task_id = omp_get_thread_num();\n";
DoIndent();
os() << "flambda(task_id, num_task);\n";
str_ += "flambda(task_id, num_task);\n";
DecIndent();
DoIndent();
os() << "}";
str_ += "}";
}
}
void CodeGenC::Visit(const ir::PolyFor *op) {
os() << "for (";
os() << GetTypeRepr(Int(32));
os() << " " << op->iterator->name;
os() << " = ";
Print(op->init);
os() << "; ";
Print(op->condition);
os() << "; ";
os() << op->iterator->name;
os() << " += ";
Print(op->inc);
os() << ") ";
Print(op->body);
str_ += "for (";
str_ += GetTypeRepr(Int(32));
str_ += " ";
str_ += op->iterator->name;
str_ += " = ";
IrPrinter::Visit(op->init);
str_ += "; ";
IrPrinter::Visit(op->condition);
str_ += "; ";
str_ += op->iterator->name;
str_ += " += ";
IrPrinter::Visit(op->inc);
str_ += ") ";
IrPrinter::Visit(op->body);
}
void CodeGenC::Visit(const ir::Select *op) {
os() << "(";
os() << "(";
Print(op->condition);
os() << ") ? ";
Print(op->true_value);
os() << " : ";
Print(op->false_value);
os() << ")";
str_ += "(";
str_ += "(";
IrPrinter::Visit(op->condition);
str_ += ") ? ";
IrPrinter::Visit(op->true_value);
str_ += " : ";
IrPrinter::Visit(op->false_value);
str_ += ")";
}
void CodeGenC::Visit(const ir::IfThenElse *op) {
os() << "if (";
Print(op->condition);
os() << ") {\n";
str_ += "if (";
IrPrinter::Visit(op->condition);
str_ += ") {\n";
if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
Print(op->true_case);
if (!op->true_case.As<ir::Block>()) os() << ";";
os() << "\n";
IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent();
os() << "}";
str_ += "}";
if (op->false_case.defined()) {
os() << " else {\n";
str_ += " else {\n";
if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
Print(op->false_case);
if (!op->false_case.As<ir::Block>()) os() << ";";
os() << "\n";
IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();
DoIndent();
os() << "}";
str_ += "}";
}
}
void CodeGenC::Visit(const ir::Block *op) {
os() << "{\n";
str_ += "{\n";
IncIndent();
for (int i = 0; i < op->stmts.size() - 1; i++) {
DoIndent();
Print(op->stmts[i]);
os() << ";\n";
IrPrinter::Visit(op->stmts[i]);
str_ += ";\n";
}
if (op->stmts.size() >= 1) {
DoIndent();
Print(op->stmts.back());
os() << ";";
IrPrinter::Visit(op->stmts.back());
str_ += ";";
}
DecIndent();
os() << "\n";
str_ += "\n";
DoIndent();
os() << "}";
str_ += "}";
}
void CodeGenC::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::buffer_malloc) {
......@@ -340,13 +340,15 @@ void CodeGenC::Visit(const ir::Call *op) {
} else if (op->name == runtime::intrinsic::pod_values_to_array_repr) {
PrintCall_pod_values_to_array(op);
} else if (op->is_intrinsic_call()) {
os() << op->name << "(";
str_ += op->name;
str_ += "(";
PrintCallArgs(op);
os() << ")";
str_ += ")";
} else if (op->is_cinn_call()) { // call CINN LoweredFunc
os() << op->name << "(";
str_ += op->name;
str_ += "(";
PrintCallArgs(op);
os() << ")";
str_ += ")";
} else if (op->is_extern_call()) {
const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(
ExternFuncID{backend_C, op->name.c_str()});
......@@ -356,9 +358,10 @@ void CodeGenC::Visit(const ir::Call *op) {
emitter.Emit(op);
} else {
CHECK(!op->read_args.empty() || !op->write_args.empty());
os() << op->name << "(";
str_ += op->name;
str_ += "(";
PrintCallArgs(op);
os() << ")";
str_ += ")";
}
} else {
CINN_NOT_IMPLEMENTED
......@@ -367,38 +370,40 @@ void CodeGenC::Visit(const ir::Call *op) {
void CodeGenC::PrintCallArgs(const ir::Call *op) {
if (!op->read_args.empty()) {
for (int i = 0; i < op->read_args.size() - 1; i++) {
Print(op->read_args[i]);
os() << ", ";
IrPrinter::Visit(op->read_args[i]);
str_ += ", ";
}
Print(op->read_args.back());
IrPrinter::Visit(op->read_args.back());
}
if (!op->write_args.empty()) {
if (!op->read_args.empty()) os() << ", ";
if (!op->read_args.empty()) str_ += ", ";
for (int i = 0; i < op->write_args.size() - 1; i++) {
Print(op->write_args[i]);
os() << ", ";
IrPrinter::Visit(op->write_args[i]);
str_ += ", ";
}
Print(op->write_args.back());
IrPrinter::Visit(op->write_args.back());
}
}
void CodeGenC::PrintCall_buffer_malloc(const ir::Call *op) {
CHECK_EQ(op->read_args.size(), 2UL);
os() << op->name << "(";
str_ += op->name;
str_ += "(";
PrintCastExpr("void*", op->read_args[0]);
os() << ", ";
os() << op->read_args[1];
os() << ")";
str_ += ", ";
IrPrinter::Visit(op->read_args[1]);
str_ += ")";
}
void CodeGenC::PrintCall_cinn_pod_value_to_(const ir::Call *op) {
CHECK_EQ(op->read_args.size(), 1UL);
os() << op->name << "(";
os() << "&(";
Print(op->read_args[0]);
os() << ")";
os() << ")";
str_ += op->name;
str_ += "(";
str_ += "&(";
IrPrinter::Visit(op->read_args[0]);
str_ += ")";
str_ += ")";
}
void CodeGenC::PrintCall_get_address(const ir::Call *op) {
......@@ -409,11 +414,11 @@ void CodeGenC::PrintCall_get_address(const ir::Call *op) {
CHECK(read_var || read_buf) << "Only Var or Buffer can get address";
if (read_var) {
if (read_var->type().lanes() <= 1) os() << "&";
os() << read_var->name;
if (read_var->type().lanes() <= 1) str_ += "&";
str_ += read_var->name;
} else if (read_buf) {
if (read_buf->type().lanes() <= 1) os() << "&";
os() << read_buf->name;
if (read_buf->type().lanes() <= 1) str_ += "&";
str_ += read_buf->name;
} else {
CINN_NOT_IMPLEMENTED
}
......@@ -432,42 +437,45 @@ void CodeGenC::PrintCall_pod_values_to_array(const ir::Call *op) {
arg_names.push_back(arg_var->name);
}
os() << "cinn_pod_value_t " << output_var->name << "[] = ";
os() << "{ ";
str_ += "cinn_pod_value_t ";
str_ += output_var->name;
str_ += "[] = ";
str_ += "{ ";
os() << utils::Join(arg_names, ", ");
str_ += utils::Join(arg_names, ", ");
os() << " }";
str_ += " }";
}
void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; }
void CodeGenC::Visit(const ir::_Var_ *op) { str_ += op->name; }
void CodeGenC::Visit(const ir::Load *op) {
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::"
<< "Load(";
os() << op->tensor.As<ir::_Tensor_>()->name;
os() << ",";
Print(dense_strided_ramp);
os() << ")";
str_ += "::";
str_ += "Load(";
str_ += op->tensor.As<ir::_Tensor_>()->name;
str_ += ",";
IrPrinter::Visit(dense_strided_ramp);
str_ += ")";
} else if (op->index().type().is_vector()) {
// gather
CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::Load(";
os() << op->tensor.As<ir::_Tensor_>()->name;
os() << ",";
Print(op->index());
os() << ")";
str_ += "::Load(";
str_ += op->tensor.As<ir::_Tensor_>()->name;
str_ += ",";
IrPrinter::Visit(op->index());
str_ += ")";
} else if (op->is_addr_tensor()) {
auto *tensor = op->tensor.As<ir::_Tensor_>();
os() << tensor->name << "[";
Print(op->index());
os() << "]";
str_ += tensor->name;
str_ += "[";
IrPrinter::Visit(op->index());
str_ += "]";
} else {
IrPrinter::Visit(op);
}
......@@ -478,56 +486,59 @@ void CodeGenC::Visit(const ir::Store *op) {
auto *tensor = op->tensor.As<ir::_Tensor_>();
CHECK(tensor);
os() << tensor->name << "[";
Print(op->index());
os() << "]";
os() << " = ";
Print(op->value);
str_ += tensor->name;
str_ += "[";
IrPrinter::Visit(op->index());
str_ += "]";
str_ += " = ";
IrPrinter::Visit(op->value);
}
void CodeGenC::Visit(const ir::Alloc *op) {
os() << runtime::intrinsic::buffer_malloc;
os() << "(";
os() << "(void*)(0), ";
str_ += runtime::intrinsic::buffer_malloc;
str_ += "(";
str_ += "(void*)(0), ";
auto *buffer = op->destination.As<ir::_Buffer_>();
os() << buffer->name;
os() << ")";
str_ += buffer->name;
str_ += ")";
}
void CodeGenC::Visit(const ir::Free *op) {
os() << runtime::intrinsic::buffer_free;
os() << "(";
os() << "(void*)(0), ";
str_ += runtime::intrinsic::buffer_free;
str_ += "(";
str_ += "(void*)(0), ";
auto *buffer = op->destination.As<ir::_Buffer_>();
os() << buffer->name;
os() << ")";
str_ += buffer->name;
str_ += ")";
}
void CodeGenC::Visit(const ir::_Buffer_ *op) { os() << op->name; }
void CodeGenC::Visit(const ir::_Tensor_ *op) { os() << op->buffer->name; }
void CodeGenC::Visit(const ir::_Buffer_ *op) { str_ += op->name; }
void CodeGenC::Visit(const ir::_Tensor_ *op) { str_ += op->buffer->name; }
void CodeGenC::Visit(const ir::Let *op) {
bool is_vec = false;
CHECK(op->type().valid());
if (op->body.defined() && op->body.As<ir::Broadcast>()) {
// broadcast's type is hard to print, so use c++11 auto instead.
os() << "auto";
str_ += "auto";
is_vec = true;
} else {
os() << GetTypeRepr(op->type());
str_ += GetTypeRepr(op->type());
}
os() << " ";
Print(op->symbol);
str_ += " ";
IrPrinter::Visit(op->symbol);
// native C array.
if (op->type().lanes() > 1 && !is_vec) {
os() << "[" << op->type().lanes() << "]";
str_ += "[";
str_ += std::to_string(op->type().lanes());
str_ += "]";
}
if (op->body.defined()) {
os() << " = ";
Print(op->body);
str_ += " = ";
IrPrinter::Visit(op->body);
}
}
......@@ -537,22 +548,29 @@ void CodeGenC::Visit(const ir::Reduce *op) {
}
void CodeGenC::Visit(const ir::Ramp *op) {
os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf())
<< ">::Ramp(";
Print(op->base);
os() << ", ";
Print(op->stride);
os() << ", ";
os() << op->lanes;
os() << ")";
str_ += "StackVec<";
str_ += std::to_string(op->lanes);
str_ += ",";
str_ += GetTypeRepr(op->type().ElementOf());
str_ += ">::Ramp(";
IrPrinter::Visit(op->base);
str_ += ", ";
IrPrinter::Visit(op->stride);
str_ += ", ";
str_ += std::to_string(op->lanes);
str_ += ")";
}
void CodeGenC::Visit(const ir::Broadcast *op) {
os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf())
<< ">::Broadcast(";
Print(op->value);
os() << ", ";
os() << op->lanes << ")";
str_ += "StackVec<";
str_ += std::to_string(op->lanes);
str_ += ",";
str_ += GetTypeRepr(op->type().ElementOf());
str_ += ">::Broadcast(";
IrPrinter::Visit(op->value);
str_ += ", ";
str_ += std::to_string(op->lanes);
str_ += ")";
}
void CodeGenC::Visit(const ir::FracOp *op) { ir::IrPrinter::Visit(op); }
......@@ -560,35 +578,41 @@ void CodeGenC::Visit(const ir::Sum *op) { ir::IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Product *op) { ir::IrPrinter::Visit(op); }
void CodeGenC::PrintCastExpr(const Type &type, Expr e) {
os() << "((" << GetTypeRepr(type) << ")";
os() << "(";
Print(e);
os() << "))";
str_ += "((";
str_ += GetTypeRepr(type);
str_ += ")";
str_ += "(";
IrPrinter::Visit(e);
str_ += "))";
}
void CodeGenC::PrintCastExpr(const std::string &type, Expr e) {
os() << "(" << type << ")";
os() << "(";
Print(e);
os() << ")";
str_ += "(";
str_ += type;
str_ += ")";
str_ += "(";
IrPrinter::Visit(e);
str_ += ")";
}
void CodeGenC::PrintShape(const std::vector<Expr> &shape,
char leftb,
char rightb) {
os() << leftb << " ";
str_ += leftb;
str_ += " ";
for (int i = 0; i < shape.size() - 1; i++) {
Print(shape[i]);
os() << ", ";
IrPrinter::Visit(shape[i]);
str_ += ", ";
}
if (shape.size() > 1) Print(shape.back());
if (shape.size() > 1) IrPrinter::Visit(shape.back());
os() << " " << rightb;
str_ += " ";
str_ += rightb;
}
void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
PrintFunctionDeclaration(op);
os() << "\n";
str_ += "\n";
DoIndent();
......@@ -623,21 +647,21 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
optim::RemoveNestedBlock(&func_body);
Print(func_body);
IrPrinter::Visit(func_body);
}
void CodeGenC::PrintIncludes() {
os() << "#include <cinn_runtime.h>\n";
os() << "#include <stdio.h>\n";
os() << "\n";
str_ += "#include <cinn_runtime.h>\n";
str_ += "#include <stdio.h>\n";
str_ += "\n";
}
void CodeGenC::PrintFileGuardOpen(const std::string &name) {
os() << utils::StringFormat("#ifndef _%s_CINN_H_\n", Uppercase(name).c_str());
os() << utils::StringFormat("#define _%s_CINN_H_\n", Uppercase(name).c_str());
os() << "\n";
str_ += utils::StringFormat("#ifndef _%s_CINN_H_\n", Uppercase(name).c_str());
str_ += utils::StringFormat("#define _%s_CINN_H_\n", Uppercase(name).c_str());
str_ += "\n";
}
void CodeGenC::PrintFileGuardClose(const std::string &module_name) {
os() << utils::StringFormat("#endif // _%s_CINN_H_\n",
str_ += utils::StringFormat("#endif // _%s_CINN_H_\n",
Uppercase(module_name).c_str());
}
......@@ -653,16 +677,16 @@ void CodeGenC::PrintBufferCreation(const std::vector<ir::Buffer> &buffers) {
Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type);
auto expr = ir::intrinsics::BufferCreate::Make(buffer);
expr = ir::Let::Make(variable, expr);
Print(expr);
os() << ";\n";
IrPrinter::Visit(expr);
str_ += ";\n";
}
}
void CodeGenC::PrintBufferDestroy(const std::vector<ir::Buffer> &buffers) {
for (auto &buffer : buffers) {
DoIndent();
Print(buffer.DestroyExpr());
os() << ";\n";
IrPrinter::Visit(buffer.DestroyExpr());
str_ += ";\n";
}
}
......@@ -672,8 +696,8 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) {
for (auto &func : module.functions()) {
PrintFunctionDeclaration(func.As<ir::_LoweredFunc_>());
os() << ";\n";
os() << "\n\n";
str_ += ";\n";
str_ += "\n\n";
}
PrintFileGuardClose(module.name());
......@@ -682,53 +706,58 @@ void CodeGenC::GenerateHeaderFile(const ir::Module &module) {
void CodeGenC::PrintFuncArg(const ir::Argument &arg) {
if (arg.is_buffer()) {
if (arg.is_input()) {
os() << "const struct cinn_buffer_t *";
str_ += "const struct cinn_buffer_t *";
} else {
os() << "struct cinn_buffer_t *";
str_ += "struct cinn_buffer_t *";
}
} else if (arg.is_var()) {
os() << GetTypeRepr(arg.type()) << " ";
os() << arg.name();
str_ += GetTypeRepr(arg.type());
str_ += " ";
str_ += arg.name();
} else {
CINN_NOT_IMPLEMENTED
}
os() << arg.name();
str_ += arg.name();
}
void CodeGenC::PrintRuntimeType(const cinn_type_t &type) {
if (type == cinn_bool_t()) {
os() << "cinn_bool_t()";
str_ += "cinn_bool_t()";
} else if (type == cinn_int8_t()) {
os() << "cinn_int8_t()";
str_ += "cinn_int8_t()";
} else if (type == cinn_int16_t()) {
os() << "cinn_int16_t()";
str_ += "cinn_int16_t()";
} else if (type == cinn_int32_t()) {
os() << "cinn_int32_t()";
str_ += "cinn_int32_t()";
} else if (type == cinn_int64_t()) {
os() << "cinn_int64_t()";
str_ += "cinn_int64_t()";
} else if (type == cinn_uint8_t()) {
os() << "cinn_uint8_t()";
str_ += "cinn_uint8_t()";
} else if (type == cinn_uint16_t()) {
os() << "cinn_uint16_t()";
str_ += "cinn_uint16_t()";
} else if (type == cinn_uint32_t()) {
os() << "cinn_uint32_t()";
str_ += "cinn_uint32_t()";
} else if (type == cinn_uint64_t()) {
os() << "cinn_uint64_t()";
str_ += "cinn_uint64_t()";
} else if (type == cinn_bfloat16_t()) {
os() << "cinn_bfloat16_t()";
str_ += "cinn_bfloat16_t()";
} else if (type == cinn_float16_t()) {
os() << "cinn_float16_t()";
str_ += "cinn_float16_t()";
} else if (type == cinn_float32_t()) {
os() << "cinn_float32_t()";
str_ += "cinn_float32_t()";
} else if (type == cinn_float64_t()) {
os() << "cinn_float64_t()";
str_ += "cinn_float64_t()";
} else {
LOG(FATAL) << "Unknown type is not supported to print";
}
}
void CodeGenC::PrintStackVecType(Type type, int lanes) {
os() << "StackedVec<" << GetTypeRepr(type) << "," << lanes << ">";
str_ += "StackedVec<";
str_ += GetTypeRepr(type);
str_ += ",";
str_ += std::to_string(lanes);
str_ += ">";
}
void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED }
......@@ -751,109 +780,117 @@ void CodeGenC::Visit(const ir::IntrinsicOp *op) {
}
void CodeGenC::Visit(const ir::intrinsics::BufferGetDataHandle *op) {
os() << op->buffer.as_buffer()->name;
os() << "->";
os() << "memory";
str_ += op->buffer.as_buffer()->name;
str_ += "->";
str_ += "memory";
}
void CodeGenC::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) {
os() << op->buffer.as_buffer()->name;
os() << "->";
os() << "memory";
str_ += op->buffer.as_buffer()->name;
str_ += "->";
str_ += "memory";
}
void CodeGenC::Visit(const ir::intrinsics::PodValueToX *op) {
auto to_type = op->GetOutputType(0);
if (to_type == type_of<float>()) {
os() << runtime::intrinsic::pod_value_to_float;
str_ += runtime::intrinsic::pod_value_to_float;
} else if (to_type == type_of<double>()) {
os() << runtime::intrinsic::pod_value_to_double;
str_ += runtime::intrinsic::pod_value_to_double;
} else if (to_type == type_of<float16>()) {
os() << runtime::intrinsic::pod_value_to_float16;
str_ += runtime::intrinsic::pod_value_to_float16;
} else if (to_type == type_of<bool>()) {
os() << runtime::intrinsic::pod_value_to_bool;
str_ += runtime::intrinsic::pod_value_to_bool;
} else if (to_type == type_of<int8_t>()) {
os() << runtime::intrinsic::pod_value_to_int8;
str_ += runtime::intrinsic::pod_value_to_int8;
} else if (to_type == type_of<int16_t>()) {
os() << runtime::intrinsic::pod_value_to_int16;
str_ += runtime::intrinsic::pod_value_to_int16;
} else if (to_type == type_of<int32_t>()) {
os() << runtime::intrinsic::pod_value_to_int32;
str_ += runtime::intrinsic::pod_value_to_int32;
} else if (to_type == type_of<int64_t>()) {
os() << runtime::intrinsic::pod_value_to_int64;
str_ += runtime::intrinsic::pod_value_to_int64;
} else if (to_type == type_of<uint8_t>()) {
os() << runtime::intrinsic::pod_value_to_uint8;
str_ += runtime::intrinsic::pod_value_to_uint8;
} else if (to_type == type_of<uint16_t>()) {
os() << runtime::intrinsic::pod_value_to_uint16;
str_ += runtime::intrinsic::pod_value_to_uint16;
} else if (to_type == type_of<uint32_t>()) {
os() << runtime::intrinsic::pod_value_to_uint32;
str_ += runtime::intrinsic::pod_value_to_uint32;
} else if (to_type == type_of<uint64_t>()) {
os() << runtime::intrinsic::pod_value_to_uint64;
str_ += runtime::intrinsic::pod_value_to_uint64;
} else if (to_type == type_of<void *>()) {
os() << runtime::intrinsic::pod_value_to_void_p;
str_ += runtime::intrinsic::pod_value_to_void_p;
} else if (to_type == type_of<cinn_buffer_t *>()) {
os() << runtime::intrinsic::pod_value_to_buffer_p;
str_ += runtime::intrinsic::pod_value_to_buffer_p;
} else {
LOG(FATAL) << "Not supported type: " << to_type;
}
os() << "(";
Print(op->pod_value_ptr);
os() << ")";
str_ += "(";
IrPrinter::Visit(op->pod_value_ptr);
str_ += ")";
}
void CodeGenC::Visit(const ir::intrinsics::BufferCreate *op) {
const ir::_Buffer_ *buffer_arg = op->buffer.as_buffer();
CHECK(buffer_arg);
os() << runtime::intrinsic::buffer_create;
os() << "(";
str_ += runtime::intrinsic::buffer_create;
str_ += "(";
PrintCastExpr("cinn_device_kind_t", Expr(buffer_arg->target.runtime_arch()));
os() << "/*target*/, ";
str_ += "/*target*/, ";
PrintRuntimeType(runtime::ToRuntimeType(buffer_arg->dtype.ElementOf()));
os() << ", ";
str_ += ", ";
PrintShape(buffer_arg->shape);
if (buffer_arg->data_alignment > 0) {
os() << ", " << buffer_arg->data_alignment << "/*align*/";
str_ += ", ";
str_ += std::to_string(buffer_arg->data_alignment);
str_ += "/*align*/";
}
os() << ")";
str_ += ")";
}
void CodeGenC::Visit(const ir::intrinsics::GetAddr *op) {
if (op->data.as_buffer()) {
os() << "&" << op->data.as_buffer()->name;
str_ += "&";
str_ += op->data.as_buffer()->name;
} else if (op->data.as_var()) {
os() << "&" << op->data.as_var()->name;
str_ += "&";
str_ += op->data.as_var()->name;
} else {
os() << "&(";
Print(op->data);
os() << ")";
str_ += "&(";
IrPrinter::Visit(op->data);
str_ += ")";
}
}
void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) {
os() << runtime::intrinsic::args_construct_repr << "(";
os() << op->var->name << ", ";
os() << op->args.size() << ", ";
str_ += runtime::intrinsic::args_construct_repr;
str_ += "(";
str_ += op->var->name;
str_ += ", ";
str_ += std::to_string(op->args.size());
str_ += ", ";
for (int i = 0; i < op->args.size() - 1; i++) {
Print(op->args[i]);
os() << ", ";
IrPrinter::Visit(op->args[i]);
str_ += ", ";
}
if (!op->args.empty()) {
Print(op->args.back());
IrPrinter::Visit(op->args.back());
}
os() << ")";
str_ += ")";
}
void CodeGenC::Visit(const ir::intrinsics::BuiltinIntrin *op) {
os() << op->name << "(";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) {
Print(op->args[i]);
os() << ", ";
IrPrinter::Visit(op->args[i]);
str_ += ", ";
}
Print(op->args.back());
IrPrinter::Visit(op->args.back());
}
os() << ")";
str_ += ")";
}
std::string ReadWholeFile(const std::string &path) {
......@@ -874,7 +911,8 @@ void CodeGenC::PrintBuiltinCodes() {
auto source =
ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file);
os() << source << "\n";
str_ += source;
str_ += "\n";
}
namespace detail {
......
......@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter {
void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }
protected:
std::string Compile(const ir::LoweredFunc& function);
std::string Compile(const ir::Buffer& buffer);
void Compile(const ir::LoweredFunc& function);
void GenerateHeaderFile(const ir::Module& module);
......@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter {
// @}
void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) {
os() << "void " << op->name << "(";
os() << "void* _args, int32_t num_args";
os() << ")";
str_ += "void ";
str_ += op->name;
str_ += "(";
str_ += "void* _args, int32_t num_args";
str_ += ")";
}
void PrintShape(const std::vector<Expr>& shape,
......
......@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) {
int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_load(";
str_ += "cinn_avx512_load(";
PrintAbsAddr(op);
os() << ")";
str_ += ")";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_load(";
str_ += "cinn_avx256_load(";
PrintAbsAddr(op);
os() << ")";
str_ += ")";
} else {
CodeGenC::Visit(op);
}
......@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_set1(";
str_ += "cinn_avx512_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")";
str_ += ")";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_set1(";
str_ += "cinn_avx256_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")";
str_ += ")";
} else {
CodeGenC::Visit(op);
}
......@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) {
int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_store(";
str_ += "cinn_avx512_store(";
PrintAbsAddr(op);
os() << ", ";
Print(op->value);
os() << ")";
str_ += ", ";
IrPrinter::Visit(op->value);
str_ += ")";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_store(";
str_ += "cinn_avx256_store(";
PrintAbsAddr(op);
os() << ", ";
Print(op->value);
os() << ")";
str_ += ", ";
IrPrinter::Visit(op->value);
str_ += ")";
} else {
CodeGenC::Visit(op);
}
......@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value;
if (SupportsAVX512()) {
os() << "cinn_avx512_set1(";
Print(value);
os() << ")";
str_ += "cinn_avx512_set1(";
IrPrinter::Visit(value);
str_ += ")";
} else if (SupportsAVX256()) {
os() << "cinn_avx256_set1(";
Print(value);
os() << ")";
str_ += "cinn_avx256_set1(";
IrPrinter::Visit(value);
str_ += ")";
} else {
CINN_NOT_IMPLEMENTED
}
} else {
Print(*op);
IrPrinter::Visit(*op);
}
}
......@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) {
}
int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_" << op->name << "(";
str_ += "cinn_avx512_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]);
os() << ", ";
str_ += ", ";
}
Print(op->args.back());
IrPrinter::Visit(op->args.back());
}
os() << ")";
str_ += ")";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_" << op->name << "(";
str_ += "cinn_avx256_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]);
os() << ", ";
str_ += ", ";
}
PrintVecInputArgument(&op->args.back());
}
os() << ")";
str_ += ")";
} else if (bits == 128) {
os() << "cinn_avx128_" << op->name << "(";
str_ += "cinn_avx128_";
str_ += op->name;
str_ += "(";
if (!op->args.empty()) {
for (int i = 0; i < op->args.size() - 1; i++) {
PrintVecInputArgument(&op->args[i]);
os() << ", ";
str_ += ", ";
}
PrintVecInputArgument(&op->args.back());
}
os() << ")";
str_ += ")";
} else {
CodeGenC::Visit(op);
}
......
......@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC {
template <typename Op>
void PrintAbsAddr(const Op *op) {
os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
str_ += op->tensor.template As<ir::_Tensor_>()->name;
str_ += " + ";
auto index = op->index();
auto *ramp_n = index.template As<ir::Ramp>();
if (ramp_n) {
CHECK(!ramp_n->base.template As<ir::Ramp>())
<< "base of a Ramp node should not be Ramp type";
Print(ramp_n->base);
IrPrinter::Visit(ramp_n->base);
} else {
Print(op->index());
IrPrinter::Visit(op->index());
}
}
......@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op,
// TODO(Superjomn) Consider support BLAS.
int bits = a.type().bits() * a.type().lanes();
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_" << op_repr << "(";
str_ += "cinn_avx512_";
str_ += op_repr;
str_ += "(";
PrintVecInputArgument(&a);
os() << ", ";
str_ += ", ";
PrintVecInputArgument(&b);
os() << ")";
str_ += ")";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_" << op_repr << "(";
str_ += "cinn_avx256_";
str_ += op_repr;
str_ += "(";
PrintVecInputArgument(&a);
os() << ", ";
str_ += ", ";
PrintVecInputArgument(&b);
os() << ")";
str_ += ")";
} else {
CodeGenC::Visit(op);
}
......
......@@ -62,6 +62,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
CodeGenC::inline_builtin_codes_ = false;
if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader);
str_ = "";
std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source;
......@@ -71,6 +72,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
if (!outputs.cuda_source_name.empty()) {
auto source = Compile(module, OutputKind::CImpl);
str_ = "";
std::ofstream file(outputs.cuda_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name;
file << source;
......@@ -79,9 +81,8 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
}
}
std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
Print(Expr(func));
return ss_.str();
void CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
IrPrinter::Visit(Expr(func));
}
std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
......@@ -117,10 +118,10 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
// clear names valid within scope when enter a new function
vectorized_tensor_names_.clear();
os() << "__global__\n";
str_ += "__global__\n";
PrintFunctionDeclaration(op);
os() << "\n";
str_ += "\n";
DoIndent();
......@@ -145,15 +146,16 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
if (!func_body.As<ir::Block>()) {
func_body = ir::Block::Make({func_body});
}
Print(func_body);
IrPrinter::Visit(func_body);
}
void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) {
if (utils::Startswith(op->name, "threadIdx") ||
utils::Startswith(op->name, "blockIdx")) {
os() << "(int)" + op->name;
str_ += "(int)";
str_ += op->name;
} else {
os() << op->name;
str_ += op->name;
}
}
......@@ -163,58 +165,63 @@ void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) {
}
void CodeGenCUDA_Dev::Visit(const ir::Min *op) {
os() << "min(";
Print(op->a());
os() << ", ";
Print(op->b());
os() << ")";
str_ += "min(";
IrPrinter::Visit(op->a());
str_ += ", ";
IrPrinter::Visit(op->b());
str_ += ")";
}
void CodeGenCUDA_Dev::Visit(const ir::Max *op) {
os() << "max(";
Print(op->a());
os() << ", ";
Print(op->b());
os() << ")";
str_ += "max(";
IrPrinter::Visit(op->a());
str_ += ", ";
IrPrinter::Visit(op->b());
str_ += ")";
}
void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
os() << "void ";
str_ += "void ";
if (op->cuda_axis_info.valid()) {
int thread_num = 1;
for (int i = 0; i < 3; i++) {
thread_num *= op->cuda_axis_info.block_dim(i);
}
os() << "__launch_bounds__(" << thread_num << ") ";
str_ += "__launch_bounds__(";
str_ += std::to_string(thread_num);
str_ += ") ";
}
os() << op->name << "(";
str_ += op->name;
str_ += "(";
for (int i = 0; i < op->args.size() - 1; i++) {
auto &arg = op->args[i];
PrintFuncArg(arg);
os() << ", ";
str_ += ", ";
}
if (!op->args.empty()) {
PrintFuncArg(op->args.back());
}
os() << ")";
str_ += ")";
}
void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) {
if (arg.is_buffer()) {
// In CUDA kernel, only primitive type is supported, so we replace the
// buffer with T*j
if (arg.is_input()) os() << "const ";
os() << GetTypeRepr(arg.buffer_arg()->dtype);
os() << "* ";
os() << kCKeywordRestrict << " ";
os() << ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
if (arg.is_input()) str_ += "const ";
str_ += GetTypeRepr(arg.buffer_arg()->dtype);
str_ += "* ";
str_ += kCKeywordRestrict;
str_ += " ";
str_ += ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
} else if (arg.is_var()) {
if (arg.var_arg()->type().is_cpp_handle()) {
os() << kCKeywordRestrict;
str_ += kCKeywordRestrict;
}
os() << GetTypeRepr(arg.type()) << " ";
os() << arg.name();
str_ += GetTypeRepr(arg.type());
str_ += " ";
str_ += arg.name();
} else {
CINN_NOT_IMPLEMENTED
}
......@@ -230,7 +237,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
PrintIncludes();
if (for_nvrtc_) {
os() << "\nextern \"C\" {\n\n";
str_ += "\nextern \"C\" {\n\n";
}
PrintBuiltinCodes();
......@@ -243,27 +250,30 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
}
if (for_nvrtc_) {
os() << "\n\n}";
str_ += "\n\n}";
}
return ss_.str();
return str_;
}
void CodeGenCUDA_Dev::PrintIncludes() { os() << GetSourceHeader(); }
void CodeGenCUDA_Dev::PrintIncludes() { str_ += GetSourceHeader(); }
void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
CHECK_NE(buffer->type(), Void());
auto print_gpu_memory = [&](const std::string &mark) {
os() << mark << GetTypeRepr(buffer->dtype) << " " << buffer->name << " ";
str_ += mark;
str_ += GetTypeRepr(buffer->dtype);
str_ += " ";
str_ += buffer->name;
str_ += " ";
os() << "[ ";
str_ += "[ ";
Expr buffer_size(1);
for (int i = 0; i < buffer->shape.size(); i++) {
buffer_size = buffer_size * buffer->shape[i];
}
optim::Simplify(&buffer_size);
Print(buffer_size);
os() << " ]";
IrPrinter::Visit(buffer_size);
str_ += " ]";
};
switch (buffer->memory_type) {
case ir::MemoryType::GPUShared:
......@@ -281,46 +291,47 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
}
void CodeGenCUDA_Dev::Visit(const ir::Call *op) {
os() << op->name + "(";
str_ += op->name;
str_ += "(";
if (!op->read_args.empty()) {
for (int i = 0; i < op->read_args.size() - 1; i++) {
auto &arg = op->read_args[i];
if (arg.as_tensor()) {
os() << arg.as_tensor()->name;
os() << ", ";
str_ += arg.as_tensor()->name;
str_ += ", ";
} else {
Print(arg);
os() << ", ";
IrPrinter::Visit(arg);
str_ += ", ";
}
}
if (op->read_args.back().as_tensor()) {
os() << op->read_args.back().as_tensor()->name;
str_ += op->read_args.back().as_tensor()->name;
} else {
Print(op->read_args.back());
IrPrinter::Visit(op->read_args.back());
}
}
if (!op->write_args.empty()) {
os() << ", ";
str_ += ", ";
for (int i = 0; i < op->write_args.size() - 1; i++) {
auto &arg = op->write_args[i];
if (arg.as_tensor()) {
os() << arg.as_tensor()->name;
os() << ", ";
str_ += arg.as_tensor()->name;
str_ += ", ";
} else {
Print(arg);
os() << ", ";
IrPrinter::Visit(arg);
str_ += ", ";
}
}
if (op->write_args.back().as_tensor()) {
os() << op->write_args.back().as_tensor()->name;
str_ += op->write_args.back().as_tensor()->name;
} else {
Print(op->write_args.back());
IrPrinter::Visit(op->write_args.back());
}
}
os() << ")";
str_ += ")";
}
void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
......@@ -331,20 +342,21 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
if (op->type().is_customized() &&
utils::Startswith(op->type().customized_type(),
common::customized_type::kcuda_builtin_vector_t)) {
os() << GetTypeRepr(op->type());
str_ += GetTypeRepr(op->type());
if (op->type().is_cpp_handle()) {
os() << " " << kCKeywordRestrict;
str_ += " ";
str_ += kCKeywordRestrict;
}
os() << " ";
Print(op->symbol);
str_ += " ";
IrPrinter::Visit(op->symbol);
vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol));
// skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not
// overloaded.
if (op->body.As<ir::IntImm>() && op->body.As<ir::IntImm>()->value == 0) {
return;
}
os() << " = ";
Print(op->body);
str_ += " = ";
IrPrinter::Visit(op->body);
} else {
CodeGenC::Visit(op);
}
......@@ -374,10 +386,14 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op,
return false;
}
if (is_store && tensor->type().is_cpp_handle()) {
os() << tensor->name << "[" << index << "]";
str_ += tensor->name;
str_ += "[";
str_ += std::to_string(index);
str_ += "]";
} else {
os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".")
<< index2suffix[index];
str_ += tensor->name;
str_ += (tensor->type().is_cpp_handle() ? "->" : ".");
str_ += index2suffix[index];
}
return true;
}
......@@ -396,8 +412,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) {
// accesses element at a cuda built-in vector, others still resolve to
// CodeGenC
if (PrintBuiltinVectorAccess(op, op->index(), true)) {
os() << " = ";
Print(op->value);
str_ += " = ";
IrPrinter::Visit(op->value);
} else {
CodeGenC::Visit(op);
}
......
......@@ -54,7 +54,7 @@ class CodeGenCUDA_Dev : public CodeGenC {
//! Compile on NVRTC.
std::string Compile(const ir::Module& module, bool for_nvrtc = true);
std::string Compile(const ir::LoweredFunc& func);
void Compile(const ir::LoweredFunc& func);
/**
* \brief Print a function argument in CUDA syntax. Currently, just some
......
......@@ -44,14 +44,24 @@ struct Type::Storage {
Type::~Type() {}
std::ostream &operator<<(std::ostream &os, const Type &t) {
if (t.is_cpp_const()) os << "const ";
os << Type2Str(t);
std::string Type::to_string() const {
std::string ret = "";
if (is_cpp_const()) ret += "const ";
ret += Type2Str(*this);
if (lanes() > 1) {
ret += "<";
ret += std::to_string(lanes());
ret += ">";
}
if (is_cpp_handle()) ret += "*";
if (is_cpp_handle2()) ret += "**";
if (t.lanes() > 1) os << "<" << t.lanes() << ">";
if (t.is_cpp_handle()) os << "*";
if (t.is_cpp_handle2()) os << "**";
return ret;
}
std::ostream &operator<<(std::ostream &os, const Type &t) {
os << t.to_string();
return os;
}
......
......@@ -149,6 +149,8 @@ struct Type {
//! Check if a dtype is supported in CINN yet.
bool is_supported() const;
std::string to_string() const;
friend std::ostream& operator<<(std::ostream& os, const Type& t);
~Type();
......
......@@ -32,83 +32,104 @@ namespace ir {
using common::bfloat16;
using common::float16;
void IrPrinter::Print(Expr e) { IRVisitorRequireReImpl::Visit(&e); }
void IrPrinter::Print(const Expr &e) {
IRVisitorRequireReImpl::Visit(&e);
os_ << str_;
str_ = "";
}
void IrPrinter::Print(const std::vector<Expr> &exprs,
const std::string &splitter) {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
Print(exprs[i]);
os_ << splitter;
Visit(exprs[i]);
str_ += splitter;
}
if (!exprs.empty()) Print(exprs.back());
if (!exprs.empty()) Visit(exprs.back());
os_ << str_;
str_ = "";
}
void IrPrinter::Visit(const IntImm *x) {
if (x->type().is_int(64)) {
os_ << x->value << "ll";
str_ += std::to_string(x->value);
str_ += "ll";
} else if (x->type().is_int(32)) {
os_ << x->value;
str_ += std::to_string(x->value);
} else if (x->type().is_int(16)) {
os_ << "(int16_t)" << x->value;
str_ += "(int16_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_int(8)) {
os_ << "(int8_t)" << x->value;
str_ += "(int8_t)";
str_ += std::to_string(x->value);
} else {
LOG(FATAL) << "Not support int type: " << x->type();
}
}
void IrPrinter::Visit(const UIntImm *x) {
if (x->type().is_uint(64)) {
os_ << x->value << "ull";
str_ += std::to_string(x->value);
str_ += "ull";
} else if (x->type().is_uint(32)) {
os_ << x->value;
str_ += std::to_string(x->value);
} else if (x->type().is_uint(16)) {
os_ << "(uint16_t)" << x->value;
str_ += "(uint16_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_uint(8)) {
os_ << "(uint8_t)" << x->value;
str_ += "(uint8_t)";
str_ += std::to_string(x->value);
} else if (x->type().is_uint(1)) {
if (x->value) {
os_ << "true";
str_ += "true";
} else {
os_ << "false";
str_ += "false";
}
} else {
LOG(FATAL) << "Not support uint type: " << x->type();
}
}
void IrPrinter::Visit(const FloatImm *x) {
std::ostringstream ss;
if (x->type().is_float16()) {
if (std::isinf(x->value)) {
os_ << "cinn::common::raw_uint16_to_float16(0x7c00)";
ss << "cinn::common::raw_uint16_to_float16(0x7c00)";
} else if (std::isnan(x->value)) {
os_ << "cinn::common::raw_uint16_to_float16(0x7e00)";
ss << "cinn::common::raw_uint16_to_float16(0x7e00)";
} else {
os_ << "(float16)"
<< std::setprecision(std::numeric_limits<float16>::max_digits10)
<< static_cast<float16>(x->value) << "f";
ss << "(float16)";
ss << std::setprecision(std::numeric_limits<float16>::max_digits10);
ss << static_cast<float16>(x->value) << "f";
}
} else if (x->type().is_bfloat16()) {
if (std::isinf(x->value)) {
os_ << "cinn::common::raw_uint16_to_bfloat16(0x7F80)";
ss << "cinn::common::raw_uint16_to_bfloat16(0x7F80)";
} else if (std::isnan(x->value)) {
os_ << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)";
ss << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)";
} else {
os_ << "(bfloat16)"
<< std::setprecision(std::numeric_limits<bfloat16>::max_digits10)
<< static_cast<bfloat16>(x->value) << "f";
ss << "(bfloat16)";
ss << std::setprecision(std::numeric_limits<bfloat16>::max_digits10);
ss << static_cast<bfloat16>(x->value) << "f";
}
} else if (x->type().is_float(32)) {
os_ << std::setprecision(std::numeric_limits<float>::max_digits10)
<< std::showpoint << x->value;
ss << std::setprecision(std::numeric_limits<float>::max_digits10);
ss << std::showpoint;
ss << x->value;
if (std::isfinite(x->value)) {
os_ << "f";
ss << "f";
}
} else if (x->type().is_float(64)) {
os_ << std::setprecision(std::numeric_limits<double>::max_digits10)
<< std::showpoint << x->value;
ss << std::setprecision(std::numeric_limits<double>::max_digits10);
ss << std::showpoint;
ss << x->value;
} else {
LOG(FATAL) << "Not support float type: " << x->type();
}
str_ += ss.str();
}
void IrPrinter::Visit(const StringImm *x) {
str_ += "\"";
str_ += x->value;
str_ += "\"";
}
void IrPrinter::Visit(const StringImm *x) { os_ << "\"" << x->value << "\""; }
void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); }
void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); }
void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); }
......@@ -123,36 +144,38 @@ void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); }
void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); }
void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); }
void IrPrinter::Visit(const Not *x) {
os_ << "!";
Print(x->v());
str_ += "!";
Visit(x->v());
}
void IrPrinter::Visit(const Min *x) {
os_ << "cinn_min(";
Print(x->a());
os_ << ", ";
Print(x->b());
os_ << ")";
str_ += "cinn_min(";
Visit(x->a());
str_ += ", ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Max *x) {
os_ << "cinn_max(";
Print(x->a());
os_ << ", ";
Print(x->b());
os_ << ")";
str_ += "cinn_max(";
Visit(x->a());
str_ += ", ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Minus *x) {
os_ << "-(";
Print(x->v());
os_ << ")";
str_ += "-(";
Visit(x->v());
str_ += ")";
}
void IrPrinter::Visit(const For *x) {
if (x->is_parallel()) {
os() << "parallel for (";
str_ += "parallel for (";
} else if (x->is_unrolled()) {
os() << "unroll for (";
str_ += "unroll for (";
} else if (x->is_vectorized()) {
int factor = x->vectorize_info().factor;
os() << "vectorize[" << factor << "] for (";
str_ += "vectorize[";
str_ += std::to_string(factor);
str_ += "] for (";
} else if (x->is_binded()) {
auto &bind_info = x->bind_info();
if (bind_info.valid()) {
......@@ -160,181 +183,189 @@ void IrPrinter::Visit(const For *x) {
auto for_type = bind_info.for_type;
std::string prefix =
for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx.";
os() << "thread_bind[" << prefix << axis_name << "] for (";
str_ += "thread_bind[";
str_ += prefix;
str_ += axis_name;
str_ += "] for (";
} else {
os() << "thread_bind[invalid info] for (";
str_ += "thread_bind[invalid info] for (";
}
} else if (x->is_serial()) {
os() << "serial for (";
str_ += "serial for (";
} else if (x->is_default()) {
os() << "default for (";
str_ += "default for (";
} else {
os() << "for (";
str_ += "for (";
}
Print(x->loop_var);
os_ << ", ";
Print(x->min);
os_ << ", ";
Print(x->extent);
os_ << ")\n";
Visit(x->loop_var);
str_ += ", ";
Visit(x->min);
str_ += ", ";
Visit(x->extent);
str_ += ")\n";
DoIndent();
Print(x->body);
Visit(x->body);
}
void IrPrinter::Visit(const PolyFor *x) {
if (x->is_parallel()) {
os() << "parallel poly_for (";
str_ += "parallel poly_for (";
} else {
os() << "poly_for (";
str_ += "poly_for (";
}
Print(x->iterator);
os_ << ", ";
Print(x->init);
os_ << ", ";
Print(x->condition);
os_ << ", ";
Print(x->inc);
os_ << ")\n";
Visit(x->iterator);
str_ += ", ";
Visit(x->init);
str_ += ", ";
Visit(x->condition);
str_ += ", ";
Visit(x->inc);
str_ += ")\n";
DoIndent();
Print(x->body);
Visit(x->body);
}
void IrPrinter::Visit(const IfThenElse *x) {
os_ << "if (";
Print(x->condition);
os_ << ") {\n";
str_ += "if (";
Visit(x->condition);
str_ += ") {\n";
IncIndent();
DoIndent();
Print(x->true_case);
Visit(x->true_case);
DecIndent();
os() << "\n";
str_ += "\n";
DoIndent();
os() << "}";
str_ += "}";
if (x->false_case.defined()) {
os_ << " else {\n";
str_ += " else {\n";
IncIndent();
DoIndent();
Print(x->false_case);
os() << "\n";
Visit(x->false_case);
str_ += "\n";
DecIndent();
DoIndent();
os_ << "}";
str_ += "}";
}
}
void IrPrinter::Visit(const Block *x) {
os_ << "{\n";
str_ += "{\n";
IncIndent();
for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) {
DoIndent();
Print(x->stmts[i]);
os_ << "\n";
Visit(x->stmts[i]);
str_ += "\n";
}
if (!x->stmts.empty()) {
DoIndent();
Print(x->stmts.back());
Visit(x->stmts.back());
}
DecIndent();
os_ << "\n";
str_ += "\n";
DoIndent();
os_ << "}";
str_ += "}";
}
void IrPrinter::Visit(const Call *x) {
os_ << x->name << "(";
str_ += x->name;
str_ += "(";
if (!x->read_args.empty()) {
for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) {
Print(x->read_args[i]);
os_ << ", ";
Visit(x->read_args[i]);
str_ += ", ";
}
Print(x->read_args.back());
Visit(x->read_args.back());
}
if (!x->write_args.empty()) {
if (!x->read_args.empty()) os() << ", ";
if (!x->read_args.empty()) str_ += ", ";
for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) {
Print(x->write_args[i]);
os_ << ", ";
Visit(x->write_args[i]);
str_ += ", ";
}
Print(x->write_args.back());
Visit(x->write_args.back());
}
os_ << ")";
str_ += ")";
}
void IrPrinter::Visit(const Cast *x) {
os() << x->type();
os() << "(";
os() << x->v();
os() << ")";
str_ += x->type().to_string();
str_ += "(";
Visit(x->v());
str_ += ")";
}
void IrPrinter::Visit(const _Module_ *x) {}
void IrPrinter::Visit(const _Var_ *x) { os_ << x->name; }
void IrPrinter::Visit(const _Var_ *x) { str_ += x->name; }
void IrPrinter::Visit(const Alloc *x) {
auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer);
os_ << "alloc(" << buffer->name << ", ";
Print(x->extents);
os_ << ")";
str_ += "alloc(";
str_ += buffer->name;
str_ += ", ";
Visit(x->extents);
str_ += ")";
}
void IrPrinter::Visit(const Select *x) {
os_ << "select(";
Print(x->condition);
os_ << ", ";
Print(x->true_value);
os_ << ", ";
Print(x->false_value);
os_ << ")";
str_ += "select(";
Visit(x->condition);
str_ += ", ";
Visit(x->true_value);
str_ += ", ";
Visit(x->false_value);
str_ += ")";
}
void IrPrinter::Visit(const Load *x) {
if (x->is_addr_tensor()) {
auto *tensor = x->tensor.As<ir::_Tensor_>();
CHECK(tensor);
os_ << tensor->name;
str_ += tensor->name;
} else if (x->is_addr_scalar()) {
Print(x->tensor);
Visit(x->tensor);
} else {
CINN_NOT_IMPLEMENTED
}
os_ << "[";
str_ += "[";
for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
Print(x->indices[i]);
os() << ", ";
Visit(x->indices[i]);
str_ += ", ";
}
if (!x->indices.empty()) Print(x->indices.back());
os_ << "]";
if (!x->indices.empty()) Visit(x->indices.back());
str_ += "]";
}
void IrPrinter::Visit(const Store *x) {
if (x->is_addr_tensor()) {
auto *tensor_node = x->tensor.As<ir::_Tensor_>();
CHECK(tensor_node);
os_ << tensor_node->name;
str_ += tensor_node->name;
} else if (x->is_addr_scalar()) {
Print(x->tensor);
Visit(x->tensor);
} else {
CINN_NOT_IMPLEMENTED
}
os_ << "[";
str_ += "[";
for (std::size_t i = 0; i + 1 < x->indices.size(); i++) {
Print(x->indices[i]);
os() << ", ";
Visit(x->indices[i]);
str_ += ", ";
}
if (!x->indices.empty()) Print(x->indices.back());
os_ << "] = ";
Print(x->value);
if (!x->indices.empty()) Visit(x->indices.back());
str_ += "] = ";
Visit(x->value);
}
void IrPrinter::Visit(const Free *x) {
auto *buffer = x->destination.As<ir::_Buffer_>();
CHECK(buffer);
os_ << "free(" << buffer->name << ")";
str_ += "free(";
str_ += buffer->name;
str_ += ")";
}
void IrPrinter::DoIndent() { os_ << std::string(indent_, ' '); }
void IrPrinter::DoIndent() { str_ += std::string(indent_, ' '); }
void IrPrinter::IncIndent() { indent_ += indent_unit; }
void IrPrinter::DecIndent() { indent_ -= indent_unit; }
......@@ -345,126 +376,139 @@ void IrPrinter::Visit(const _Buffer_ *x) {
std::back_inserter(dim_names),
[&](const Expr &x) { return utils::GetStreamCnt(x); });
os_ << "_Buffer_<" << x->type() << ": " << utils::Join(dim_names, ",") << ">("
<< x->name << ")";
str_ += "_Buffer_<";
str_ += x->type().to_string();
str_ += ": ";
str_ += utils::Join(dim_names, ",");
str_ += ">(";
str_ += x->name;
str_ += ")";
}
void IrPrinter::Visit(const _Tensor_ *x) {
os_ << "Tensor(";
os() << x->name << ", ";
os() << "[";
str_ += "Tensor(";
str_ += x->name;
str_ += ", ";
str_ += "[";
if (!x->shape.empty()) {
for (std::size_t i = 0; i + 1 < x->shape.size(); i++) {
Print(x->shape[i]);
os() << ",";
Visit(x->shape[i]);
str_ += ",";
}
Print(x->shape.back());
Visit(x->shape.back());
}
os_ << "])";
str_ += "])";
}
void IrPrinter::Visit(const _LoweredFunc_ *f) {
os_ << "function " << f->name << " ";
str_ += "function ";
str_ += f->name;
str_ += " ";
std::vector<std::string> arg_names;
for (auto &arg : f->args) {
arg_names.push_back(arg.name());
}
os_ << "(" << utils::Join(arg_names, ", ") << ")\n";
str_ += "(";
str_ += utils::Join(arg_names, ", ");
str_ += ")\n";
Print(f->body);
Visit(f->body);
}
void IrPrinter::Visit(const Let *f) {
CHECK(f->type().valid());
os() << f->type() << " ";
Print(f->symbol);
str_ += f->type().to_string();
str_ += " ";
Visit(f->symbol);
if (f->body.defined()) {
os() << " = ";
Print(f->body);
str_ += " = ";
Visit(f->body);
}
}
void IrPrinter::Visit(const Reduce *f) {
os() << "Reduce(";
str_ += "Reduce(";
switch (f->reduce_type) {
case Reduce::ReduceType::kSum:
os() << "sum";
str_ += "sum";
break;
case Reduce::ReduceType::kSub:
os() << "sub";
str_ += "sub";
break;
case Reduce::ReduceType::kDiv:
os() << "Div";
str_ += "Div";
break;
case Reduce::ReduceType::kMul:
os() << "Mul";
str_ += "Mul";
break;
case Reduce::ReduceType::kMax:
os() << "Max";
str_ += "Max";
break;
case Reduce::ReduceType::kMin:
os() << "Min";
str_ += "Min";
break;
case Reduce::ReduceType::kAll:
os() << "&&";
str_ += "&&";
break;
case Reduce::ReduceType::kAny:
os() << "||";
str_ += "||";
break;
}
os() << ", ";
Print(f->body);
os() << ",";
Print(f->init);
os() << ")";
str_ += ", ";
Visit(f->body);
str_ += ",";
Visit(f->init);
str_ += ")";
}
void IrPrinter::Visit(const Ramp *x) {
os() << "Ramp(";
Print(x->base);
os() << ",";
Print(x->stride);
os() << ",";
os() << x->lanes;
os() << ")";
str_ += "Ramp(";
Visit(x->base);
str_ += ",";
Visit(x->stride);
str_ += ",";
str_ += std::to_string(x->lanes);
str_ += ")";
}
void IrPrinter::Visit(const Broadcast *x) {
os() << "Broadcast(";
Print(x->value);
os() << ",";
os() << x->lanes;
os() << ")";
str_ += "Broadcast(";
Visit(x->value);
str_ += ",";
str_ += std::to_string(x->lanes);
str_ += ")";
}
void IrPrinter::Visit(const FracOp *x) {
os() << "(";
Print(x->a());
os() << " / ";
Print(x->b());
os() << ")";
str_ += "(";
Visit(x->a());
str_ += " / ";
Visit(x->b());
str_ += ")";
}
void IrPrinter::Visit(const Product *x) {
os() << "(";
str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Print(x->operand(i));
os() << " * ";
Visit(x->operand(i));
str_ += " * ";
}
if (!x->operands().empty()) Print(x->operands().back());
os() << ")";
if (!x->operands().empty()) Visit(x->operands().back());
str_ += ")";
}
void IrPrinter::Visit(const Sum *x) {
os() << "(";
str_ += "(";
for (std::size_t i = 0; i + 1 < x->operands().size(); i++) {
Print(x->operand(i));
os() << " + ";
Visit(x->operand(i));
str_ += " + ";
}
if (!x->operands().empty()) Print(x->operands().back());
os() << ")";
if (!x->operands().empty()) Visit(x->operands().back());
str_ += ")";
}
void IrPrinter::Visit(const PrimitiveNode *x) {
os() << x->name << "(";
str_ += x->name;
str_ += "(";
std::vector<std::string> args_repr;
for (auto &args : x->arguments) {
std::vector<std::string> arg_repr;
......@@ -474,41 +518,46 @@ void IrPrinter::Visit(const PrimitiveNode *x) {
args_repr.push_back(utils::Join(arg_repr, ","));
}
os() << utils::Join(args_repr, ",");
os() << ")";
str_ += utils::Join(args_repr, ",");
str_ += ")";
}
void IrPrinter::Visit(const _BufferRange_ *x) {
auto *buffer = x->buffer.As<ir::_Buffer_>();
CHECK(buffer);
os() << buffer->name << "[";
str_ += buffer->name;
str_ += "[";
for (std::size_t i = 0; i < x->ranges.size(); i++) {
if (i) os() << ", ";
if (i) str_ += ", ";
auto &range = x->ranges[i];
os() << range->name << "(";
str_ += range->name;
str_ += "(";
if (range->lower_bound.defined()) {
os() << range->lower_bound << ":";
Visit(range->lower_bound);
str_ += ":";
} else {
os() << "undefined:";
str_ += "undefined:";
}
if (range->upper_bound.defined()) {
os() << range->upper_bound;
Visit(range->upper_bound);
} else {
os() << "undefined";
str_ += "undefined";
}
os() << ")";
str_ += ")";
}
os() << "]";
str_ += "]";
}
void IrPrinter::Visit(const ScheduleBlock *x) {}
void IrPrinter::Visit(const ScheduleBlockRealize *x) {
auto *schedule_block = x->schedule_block.As<ScheduleBlock>();
os() << "ScheduleBlock(" << schedule_block->name << ")\n";
str_ += "ScheduleBlock(";
str_ += schedule_block->name;
str_ += ")\n";
DoIndent();
os() << "{\n";
str_ += "{\n";
// print block vars and bindings
auto iter_vars = schedule_block->iter_vars;
auto iter_values = x->iter_values;
......@@ -516,54 +565,61 @@ void IrPrinter::Visit(const ScheduleBlockRealize *x) {
IncIndent();
if (!iter_vars.empty()) DoIndent();
for (std::size_t i = 0; i < iter_vars.size(); i++) {
if (i) os() << ", ";
os() << iter_vars[i]->name;
if (i) str_ += ", ";
str_ += iter_vars[i]->name;
}
if (!iter_vars.empty()) os() << " = axis.bind(";
if (!iter_vars.empty()) str_ += " = axis.bind(";
for (std::size_t i = 0; i < iter_values.size(); i++) {
if (i) os() << ", ";
os() << iter_values[i];
if (i) str_ += ", ";
Visit(iter_values[i]);
}
if (!iter_vars.empty()) os() << ")\n";
if (!iter_vars.empty()) str_ += ")\n";
// print block body
if (!schedule_block->read_buffers.empty()) {
DoIndent();
os() << "read_buffers(";
str_ += "read_buffers(";
auto &read_buffers = schedule_block->read_buffers;
for (std::size_t i = 0; i < read_buffers.size(); i++) {
if (i) os() << ", ";
Print(read_buffers[i]);
if (i) str_ += ", ";
Visit(read_buffers[i]);
}
os() << ")\n";
str_ += ")\n";
}
if (!schedule_block->write_buffers.empty()) {
DoIndent();
os() << "write_buffers(";
str_ += "write_buffers(";
auto &write_buffers = schedule_block->write_buffers;
for (std::size_t i = 0; i < write_buffers.size(); i++) {
if (i) os() << ", ";
Print(write_buffers[i]);
if (i) str_ += ", ";
Visit(write_buffers[i]);
}
os() << ")\n";
str_ += ")\n";
}
if (!schedule_block->attrs.empty()) {
DoIndent();
os() << "attrs(";
str_ += "attrs(";
bool comma = false;
for (auto &&kv : schedule_block->attrs) {
if (comma) os() << ", ";
os() << kv.first << ":";
absl::visit([this](auto &&arg) { this->os() << arg; }, kv.second);
if (comma) str_ += ", ";
str_ += kv.first;
str_ += ":";
absl::visit(
[this](auto &&arg) {
std::ostringstream ss;
ss << arg;
this->str_ += ss.str();
},
kv.second);
comma = true;
}
os() << ")\n";
str_ += ")\n";
}
DoIndent();
Print(schedule_block->body);
os() << "\n";
Visit(schedule_block->body);
str_ += "\n";
DecIndent();
DoIndent();
os() << "}";
str_ += "}";
}
void IrPrinter::Visit(const IntrinsicOp *x) {
......@@ -578,50 +634,52 @@ void IrPrinter::Visit(const IntrinsicOp *x) {
}
}
void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) {
os() << runtime::intrinsic::buffer_get_data_handle;
Print(x->buffer);
os() << ")";
str_ += runtime::intrinsic::buffer_get_data_handle;
Visit(x->buffer);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) {
os() << runtime::intrinsic::buffer_get_data_const_handle;
Print(x->buffer);
os() << ")";
str_ += runtime::intrinsic::buffer_get_data_const_handle;
Visit(x->buffer);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::PodValueToX *x) {
os() << "pod_value_to_";
os() << x->GetOutputType(0);
os() << "(";
Print(x->pod_value_ptr);
os() << ")";
str_ += "pod_value_to_";
str_ += x->GetOutputType(0).to_string();
str_ += "(";
Visit(x->pod_value_ptr);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BufferCreate *x) {
os() << runtime::intrinsic::buffer_create;
os() << "()";
str_ += runtime::intrinsic::buffer_create;
str_ += "()";
}
void IrPrinter::Visit(const intrinsics::GetAddr *x) {
os() << "get_addr(";
Print(x->data);
os() << ")";
str_ += "get_addr(";
Visit(x->data);
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) {
os() << runtime::intrinsic::args_construct_repr;
os() << "(";
Print(std::vector<Expr>(x->args.begin(), x->args.end()));
os() << ")";
str_ += runtime::intrinsic::args_construct_repr;
str_ += "(";
Visit(std::vector<Expr>(x->args.begin(), x->args.end()));
str_ += ")";
}
void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) {
os_ << runtime::intrinsic::builtin_intrin_repr << "_";
os_ << x->name << "(";
str_ += runtime::intrinsic::builtin_intrin_repr;
str_ += "_";
str_ += x->name;
str_ += "(";
if (!x->args.empty()) {
for (std::size_t i = 0; i + 1 < x->args.size(); i++) {
Print(x->args[i]);
os_ << ", ";
Visit(x->args[i]);
str_ += ", ";
}
Print(x->args.back());
Visit(x->args.back());
}
os_ << ")";
str_ += ")";
}
std::ostream &operator<<(std::ostream &os, Expr a) {
......
......@@ -30,10 +30,10 @@ namespace ir {
class Module;
struct IrPrinter : public IRVisitorRequireReImpl<void> {
explicit IrPrinter(std::ostream &os) : os_(os) {}
explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}
//! Emit an expression on the output stream.
void Print(Expr e);
void Print(const Expr &e);
//! Emit a expression list with , splitted.
void Print(const std::vector<Expr> &exprs,
const std::string &splitter = ", ");
......@@ -50,6 +50,17 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
std::ostream &os() { return os_; }
void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
void Visit(const std::vector<Expr> &exprs,
const std::string &splitter = ", ") {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
Visit(exprs[i]);
str_ += splitter;
}
if (!exprs.empty()) Visit(exprs.back());
}
#define __(op__) void Visit(const op__ *x) override;
NODETY_FORALL(__)
#undef __
......@@ -58,6 +69,9 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
INTRINSIC_KIND_FOR_EACH(__)
#undef __
protected:
std::string str_;
private:
std::ostream &os_;
uint16_t indent_{};
......@@ -71,11 +85,13 @@ std::ostream &operator<<(std::ostream &os, const Module &m);
template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op,
const BinaryOpNode<IRN> *x) {
os_ << "(";
Print(x->a());
os_ << " " + op + " ";
Print(x->b());
os_ << ")";
str_ += "(";
Visit(x->a());
str_ += " ";
str_ += op;
str_ += " ";
Visit(x->b());
str_ += ")";
}
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册