Unverified commit 68b0cf92, authored by 傅剑寒, committed by GitHub

【CINN】refactor codegen for cinn (#55955)

* refactor codegen for cinn

* add to_string() to types that cannot be appended to a string with +=

* fix a multi-thread bug caused by a static variable

* delete dead code and comments
Parent db96ae58
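The gist of the refactor, visible in the hunks below, is that code generation no longer streams into a std::ostream via os() <<; each printer accumulates its output in a per-instance std::string member str_ using +=, and values such as Type gain a to_string() so they can be appended as well. The following is a minimal, self-contained sketch of that pattern only; SimpleType and SimplePrinter are illustrative stand-ins, not the actual CINN classes:

    #include <iostream>
    #include <string>

    // Stand-in for a type object that cannot be "+="-ed to a string directly,
    // so it renders itself first (mirrors the new Type::to_string()).
    struct SimpleType {
      bool is_const = false;
      std::string base = "float32";
      std::string to_string() const {
        std::string ret;
        if (is_const) ret += "const ";
        ret += base;
        return ret;
      }
    };

    // Stand-in for a printer: output goes into a per-instance buffer instead
    // of a shared std::ostream.
    class SimplePrinter {
     public:
      void PrintCast(const SimpleType &t, const std::string &expr) {
        str_ += "(";
        str_ += t.to_string();
        str_ += ")";
        str_ += expr;
      }
      const std::string &str() const { return str_; }

     protected:
      std::string str_;  // per-instance output buffer
    };

    int main() {
      SimpleType t;
      t.is_const = true;  // any Type-like value works here
      SimplePrinter p;
      p.PrintCast(t, "x");
      std::cout << p.str() << "\n";  // prints: (const float32)x
      return 0;
    }

Because the buffer is a non-static member, two printers running concurrently no longer share state, which is presumably the kind of issue the "static var" bullet in the commit message refers to.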
This diff is collapsed.
@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter {
   void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }

  protected:
-  std::string Compile(const ir::LoweredFunc& function);
-  std::string Compile(const ir::Buffer& buffer);
+  void Compile(const ir::LoweredFunc& function);

   void GenerateHeaderFile(const ir::Module& module);
@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter {
   // @}

   void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) {
-    os() << "void " << op->name << "(";
-    os() << "void* _args, int32_t num_args";
-    os() << ")";
+    str_ += "void ";
+    str_ += op->name;
+    str_ += "(";
+    str_ += "void* _args, int32_t num_args";
+    str_ += ")";
   }

   void PrintShape(const std::vector<Expr>& shape,
......
@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_load(";
+    str_ += "cinn_avx512_load(";
     PrintAbsAddr(op);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_load(";
+    str_ += "cinn_avx256_load(";
     PrintAbsAddr(op);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_set1(";
+    str_ += "cinn_avx512_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_set1(";
+    str_ += "cinn_avx256_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_store(";
+    str_ += "cinn_avx512_store(";
     PrintAbsAddr(op);
-    os() << ", ";
-    Print(op->value);
-    os() << ")";
+    str_ += ", ";
+    IrPrinter::Visit(op->value);
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_store(";
+    str_ += "cinn_avx256_store(";
     PrintAbsAddr(op);
-    os() << ", ";
-    Print(op->value);
-    os() << ")";
+    str_ += ", ";
+    IrPrinter::Visit(op->value);
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
     Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value;
     if (SupportsAVX512()) {
-      os() << "cinn_avx512_set1(";
-      Print(value);
-      os() << ")";
+      str_ += "cinn_avx512_set1(";
+      IrPrinter::Visit(value);
+      str_ += ")";
     } else if (SupportsAVX256()) {
-      os() << "cinn_avx256_set1(";
-      Print(value);
-      os() << ")";
+      str_ += "cinn_avx256_set1(";
+      IrPrinter::Visit(value);
+      str_ += ")";
     } else {
       CINN_NOT_IMPLEMENTED
     }
   } else {
-    Print(*op);
+    IrPrinter::Visit(*op);
   }
 }
@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) {
   }
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_" << op->name << "(";
+    str_ += "cinn_avx512_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
         PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
-      Print(op->args.back());
+      IrPrinter::Visit(op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_" << op->name << "(";
+    str_ += "cinn_avx256_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
         PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
       PrintVecInputArgument(&op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else if (bits == 128) {
-    os() << "cinn_avx128_" << op->name << "(";
+    str_ += "cinn_avx128_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
        PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
       PrintVecInputArgument(&op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
......
@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC {
   template <typename Op>
   void PrintAbsAddr(const Op *op) {
-    os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
+    str_ += op->tensor.template As<ir::_Tensor_>()->name;
+    str_ += " + ";
     auto index = op->index();
     auto *ramp_n = index.template As<ir::Ramp>();
     if (ramp_n) {
       CHECK(!ramp_n->base.template As<ir::Ramp>())
           << "base of a Ramp node should not be Ramp type";
-      Print(ramp_n->base);
+      IrPrinter::Visit(ramp_n->base);
     } else {
-      Print(op->index());
+      IrPrinter::Visit(op->index());
     }
   }
@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op,
   // TODO(Superjomn) Consider support BLAS.
   int bits = a.type().bits() * a.type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_" << op_repr << "(";
+    str_ += "cinn_avx512_";
+    str_ += op_repr;
+    str_ += "(";
     PrintVecInputArgument(&a);
-    os() << ", ";
+    str_ += ", ";
     PrintVecInputArgument(&b);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_" << op_repr << "(";
+    str_ += "cinn_avx256_";
+    str_ += op_repr;
+    str_ += "(";
     PrintVecInputArgument(&a);
-    os() << ", ";
+    str_ += ", ";
     PrintVecInputArgument(&b);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
......
@@ -62,6 +62,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   CodeGenC::inline_builtin_codes_ = false;
   if (!outputs.c_header_name.empty()) {
     auto source = Compile(module, OutputKind::CHeader);
+    str_ = "";
     std::ofstream file(outputs.c_header_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
     file << source;
@@ -71,6 +72,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   if (!outputs.cuda_source_name.empty()) {
     auto source = Compile(module, OutputKind::CImpl);
+    str_ = "";
     std::ofstream file(outputs.cuda_source_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name;
     file << source;
@@ -79,9 +81,8 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   }
 }

-std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
-  Print(Expr(func));
-  return ss_.str();
+void CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
+  IrPrinter::Visit(Expr(func));
 }

 std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
@@ -117,10 +118,10 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
 void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
   // clear names valid within scope when enter a new function
   vectorized_tensor_names_.clear();
-  os() << "__global__\n";
+  str_ += "__global__\n";
   PrintFunctionDeclaration(op);
-  os() << "\n";
+  str_ += "\n";
   DoIndent();
@@ -145,15 +146,16 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
   if (!func_body.As<ir::Block>()) {
     func_body = ir::Block::Make({func_body});
   }
-  Print(func_body);
+  IrPrinter::Visit(func_body);
 }

 void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) {
   if (utils::Startswith(op->name, "threadIdx") ||
       utils::Startswith(op->name, "blockIdx")) {
-    os() << "(int)" + op->name;
+    str_ += "(int)";
+    str_ += op->name;
   } else {
-    os() << op->name;
+    str_ += op->name;
   }
 }
@@ -163,58 +165,63 @@ void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) {
 }

 void CodeGenCUDA_Dev::Visit(const ir::Min *op) {
-  os() << "min(";
-  Print(op->a());
-  os() << ", ";
-  Print(op->b());
-  os() << ")";
+  str_ += "min(";
+  IrPrinter::Visit(op->a());
+  str_ += ", ";
+  IrPrinter::Visit(op->b());
+  str_ += ")";
 }

 void CodeGenCUDA_Dev::Visit(const ir::Max *op) {
-  os() << "max(";
-  Print(op->a());
-  os() << ", ";
-  Print(op->b());
-  os() << ")";
+  str_ += "max(";
+  IrPrinter::Visit(op->a());
+  str_ += ", ";
+  IrPrinter::Visit(op->b());
+  str_ += ")";
 }

 void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
-  os() << "void ";
+  str_ += "void ";
   if (op->cuda_axis_info.valid()) {
     int thread_num = 1;
     for (int i = 0; i < 3; i++) {
       thread_num *= op->cuda_axis_info.block_dim(i);
     }
-    os() << "__launch_bounds__(" << thread_num << ") ";
+    str_ += "__launch_bounds__(";
+    str_ += std::to_string(thread_num);
+    str_ += ") ";
   }

-  os() << op->name << "(";
+  str_ += op->name;
+  str_ += "(";
   for (int i = 0; i < op->args.size() - 1; i++) {
     auto &arg = op->args[i];
     PrintFuncArg(arg);
-    os() << ", ";
+    str_ += ", ";
   }
   if (!op->args.empty()) {
     PrintFuncArg(op->args.back());
   }
-  os() << ")";
+  str_ += ")";
 }

 void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) {
   if (arg.is_buffer()) {
     // In CUDA kernel, only primitive type is supported, so we replace the
     // buffer with T*j
-    if (arg.is_input()) os() << "const ";
-    os() << GetTypeRepr(arg.buffer_arg()->dtype);
-    os() << "* ";
-    os() << kCKeywordRestrict << " ";
-    os() << ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
+    if (arg.is_input()) str_ += "const ";
+    str_ += GetTypeRepr(arg.buffer_arg()->dtype);
+    str_ += "* ";
+    str_ += kCKeywordRestrict;
+    str_ += " ";
+    str_ += ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
   } else if (arg.is_var()) {
     if (arg.var_arg()->type().is_cpp_handle()) {
-      os() << kCKeywordRestrict;
+      str_ += kCKeywordRestrict;
     }
-    os() << GetTypeRepr(arg.type()) << " ";
-    os() << arg.name();
+    str_ += GetTypeRepr(arg.type());
+    str_ += " ";
+    str_ += arg.name();
   } else {
     CINN_NOT_IMPLEMENTED
   }
@@ -230,7 +237,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
   PrintIncludes();

   if (for_nvrtc_) {
-    os() << "\nextern \"C\" {\n\n";
+    str_ += "\nextern \"C\" {\n\n";
   }

   PrintBuiltinCodes();
@@ -243,27 +250,30 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
   }

   if (for_nvrtc_) {
-    os() << "\n\n}";
+    str_ += "\n\n}";
   }

-  return ss_.str();
+  return str_;
 }

-void CodeGenCUDA_Dev::PrintIncludes() { os() << GetSourceHeader(); }
+void CodeGenCUDA_Dev::PrintIncludes() { str_ += GetSourceHeader(); }

 void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
   CHECK_NE(buffer->type(), Void());
   auto print_gpu_memory = [&](const std::string &mark) {
-    os() << mark << GetTypeRepr(buffer->dtype) << " " << buffer->name << " ";
-    os() << "[ ";
+    str_ += mark;
+    str_ += GetTypeRepr(buffer->dtype);
+    str_ += " ";
+    str_ += buffer->name;
+    str_ += " ";
+    str_ += "[ ";
     Expr buffer_size(1);
     for (int i = 0; i < buffer->shape.size(); i++) {
       buffer_size = buffer_size * buffer->shape[i];
     }
     optim::Simplify(&buffer_size);
-    Print(buffer_size);
-    os() << " ]";
+    IrPrinter::Visit(buffer_size);
+    str_ += " ]";
   };
   switch (buffer->memory_type) {
     case ir::MemoryType::GPUShared:
@@ -281,46 +291,47 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
 }

 void CodeGenCUDA_Dev::Visit(const ir::Call *op) {
-  os() << op->name + "(";
+  str_ += op->name;
+  str_ += "(";

   if (!op->read_args.empty()) {
     for (int i = 0; i < op->read_args.size() - 1; i++) {
       auto &arg = op->read_args[i];
       if (arg.as_tensor()) {
-        os() << arg.as_tensor()->name;
-        os() << ", ";
+        str_ += arg.as_tensor()->name;
+        str_ += ", ";
       } else {
-        Print(arg);
-        os() << ", ";
+        IrPrinter::Visit(arg);
+        str_ += ", ";
       }
     }
     if (op->read_args.back().as_tensor()) {
-      os() << op->read_args.back().as_tensor()->name;
+      str_ += op->read_args.back().as_tensor()->name;
     } else {
-      Print(op->read_args.back());
+      IrPrinter::Visit(op->read_args.back());
     }
   }

   if (!op->write_args.empty()) {
-    os() << ", ";
+    str_ += ", ";
     for (int i = 0; i < op->write_args.size() - 1; i++) {
       auto &arg = op->write_args[i];
       if (arg.as_tensor()) {
-        os() << arg.as_tensor()->name;
-        os() << ", ";
+        str_ += arg.as_tensor()->name;
+        str_ += ", ";
       } else {
-        Print(arg);
-        os() << ", ";
+        IrPrinter::Visit(arg);
+        str_ += ", ";
       }
     }
     if (op->write_args.back().as_tensor()) {
-      os() << op->write_args.back().as_tensor()->name;
+      str_ += op->write_args.back().as_tensor()->name;
     } else {
-      Print(op->write_args.back());
+      IrPrinter::Visit(op->write_args.back());
     }
   }

-  os() << ")";
+  str_ += ")";
 }

 void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
@@ -331,20 +342,21 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
   if (op->type().is_customized() &&
       utils::Startswith(op->type().customized_type(),
                         common::customized_type::kcuda_builtin_vector_t)) {
-    os() << GetTypeRepr(op->type());
+    str_ += GetTypeRepr(op->type());
     if (op->type().is_cpp_handle()) {
-      os() << " " << kCKeywordRestrict;
+      str_ += " ";
+      str_ += kCKeywordRestrict;
     }
-    os() << " ";
-    Print(op->symbol);
+    str_ += " ";
+    IrPrinter::Visit(op->symbol);
     vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol));
     // skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not
     // overloaded.
     if (op->body.As<ir::IntImm>() && op->body.As<ir::IntImm>()->value == 0) {
       return;
     }
-    os() << " = ";
-    Print(op->body);
+    str_ += " = ";
+    IrPrinter::Visit(op->body);
   } else {
     CodeGenC::Visit(op);
   }
@@ -374,10 +386,14 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op,
     return false;
   }
   if (is_store && tensor->type().is_cpp_handle()) {
-    os() << tensor->name << "[" << index << "]";
+    str_ += tensor->name;
+    str_ += "[";
+    str_ += std::to_string(index);
+    str_ += "]";
   } else {
-    os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".")
-         << index2suffix[index];
+    str_ += tensor->name;
+    str_ += (tensor->type().is_cpp_handle() ? "->" : ".");
+    str_ += index2suffix[index];
   }
   return true;
 }
@@ -396,8 +412,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) {
   // accesses element at a cuda built-in vector, others still resolve to
   // CodeGenC
   if (PrintBuiltinVectorAccess(op, op->index(), true)) {
-    os() << " = ";
-    Print(op->value);
+    str_ += " = ";
+    IrPrinter::Visit(op->value);
   } else {
     CodeGenC::Visit(op);
   }
......
@@ -54,7 +54,7 @@ class CodeGenCUDA_Dev : public CodeGenC {
   //! Compile on NVRTC.
   std::string Compile(const ir::Module& module, bool for_nvrtc = true);

-  std::string Compile(const ir::LoweredFunc& func);
+  void Compile(const ir::LoweredFunc& func);

   /**
    * \brief Print a function argument in CUDA syntax. Currently, just some
......
@@ -44,14 +44,24 @@ struct Type::Storage {
 Type::~Type() {}

-std::ostream &operator<<(std::ostream &os, const Type &t) {
-  if (t.is_cpp_const()) os << "const ";
-  os << Type2Str(t);
-
-  if (t.lanes() > 1) os << "<" << t.lanes() << ">";
-  if (t.is_cpp_handle()) os << "*";
-  if (t.is_cpp_handle2()) os << "**";
+std::string Type::to_string() const {
+  std::string ret = "";
+  if (is_cpp_const()) ret += "const ";
+  ret += Type2Str(*this);
+
+  if (lanes() > 1) {
+    ret += "<";
+    ret += std::to_string(lanes());
+    ret += ">";
+  }
+
+  if (is_cpp_handle()) ret += "*";
+  if (is_cpp_handle2()) ret += "**";
+
+  return ret;
+}
+
+std::ostream &operator<<(std::ostream &os, const Type &t) {
+  os << t.to_string();
   return os;
 }
......
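For reference, the strings the new Type::to_string() builds follow directly from the branches above; a few illustrative cases (the base spellings are an assumption about what Type2Str returns, which is outside this diff):

    // plain 32-bit float type           -> "float32"
    // const type with 4 lanes           -> "const float32<4>"
    // cpp handle (pointer) to the type  -> "float32*"
    // double handle                     -> "float32**"

The stream operator<< now simply delegates to to_string(), so both the old streaming printers and the new string-appending printers produce the same spelling.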
@@ -149,6 +149,8 @@ struct Type {
   //! Check if a dtype is supported in CINN yet.
   bool is_supported() const;

+  std::string to_string() const;
+
   friend std::ostream& operator<<(std::ostream& os, const Type& t);

   ~Type();
......
This diff is collapsed.
@@ -30,10 +30,10 @@ namespace ir {
 class Module;

 struct IrPrinter : public IRVisitorRequireReImpl<void> {
-  explicit IrPrinter(std::ostream &os) : os_(os) {}
+  explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}

   //! Emit an expression on the output stream.
-  void Print(Expr e);
+  void Print(const Expr &e);
   //! Emit a expression list with , splitted.
   void Print(const std::vector<Expr> &exprs,
              const std::string &splitter = ", ");
@@ -50,6 +50,17 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
   std::ostream &os() { return os_; }

+  void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
+  void Visit(const std::vector<Expr> &exprs,
+             const std::string &splitter = ", ") {
+    for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
+      Visit(exprs[i]);
+      str_ += splitter;
+    }
+    if (!exprs.empty()) Visit(exprs.back());
+  }
+
 #define __(op__) void Visit(const op__ *x) override;
   NODETY_FORALL(__)
 #undef __
@@ -58,6 +69,9 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
   INTRINSIC_KIND_FOR_EACH(__)
 #undef __

+ protected:
+  std::string str_;
+
  private:
   std::ostream &os_;
   uint16_t indent_{};
@@ -71,11 +85,13 @@ std::ostream &operator<<(std::ostream &os, const Module &m);
 template <typename IRN>
 void IrPrinter::PrintBinaryOp(const std::string &op,
                               const BinaryOpNode<IRN> *x) {
-  os_ << "(";
-  Print(x->a());
-  os_ << " " + op + " ";
-  Print(x->b());
-  os_ << ")";
+  str_ += "(";
+  Visit(x->a());
+  str_ += " ";
+  str_ += op;
+  str_ += " ";
+  Visit(x->b());
+  str_ += ")";
 }

 } // namespace ir
......