Unverified commit 68b0cf92, authored by 傅剑寒, committed by GitHub

【CINN】refactor codegen for cinn (#55955)

* refactor codegen for cinn

* add to_string to some types which can't be += with string

* fix multi-thread bug caused by static var

* delete dead code and comments
Parent db96ae58
This diff is collapsed.
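For orientation before the per-file hunks: the refactor replaces stream insertion on the printer's std::ostream with concatenation onto a std::string member (str_), so each printer owns its output buffer rather than sharing a stream, which relates to the static-variable multi-thread fix noted in the message. A minimal sketch of the before/after pattern, with illustrative names only; note that non-string operands now need an explicit conversion, which is why the patch adds Type::to_string and the std::to_string calls:

#include <ostream>
#include <string>

// Before: chained stream insertion; any streamable type works directly.
void EmitBefore(std::ostream &os, const std::string &name, int num_args) {
  os << "void " << name << "(" << num_args << ")";
}

// After: concatenation onto a per-printer buffer; ints and other
// non-string types must be converted explicitly via std::to_string.
void EmitAfter(std::string &str_, const std::string &name, int num_args) {
  str_ += "void ";
  str_ += name;
  str_ += "(";
  str_ += std::to_string(num_args);
  str_ += ")";
}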
@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter {
   void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }

  protected:
-  std::string Compile(const ir::LoweredFunc& function);
-  std::string Compile(const ir::Buffer& buffer);
+  void Compile(const ir::LoweredFunc& function);

   void GenerateHeaderFile(const ir::Module& module);
@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter {
   // @}
   void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) {
-    os() << "void " << op->name << "(";
-    os() << "void* _args, int32_t num_args";
-    os() << ")";
+    str_ += "void ";
+    str_ += op->name;
+    str_ += "(";
+    str_ += "void* _args, int32_t num_args";
+    str_ += ")";
   }
   void PrintShape(const std::vector<Expr>& shape,
......
@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_load(";
+    str_ += "cinn_avx512_load(";
     PrintAbsAddr(op);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_load(";
+    str_ += "cinn_avx256_load(";
     PrintAbsAddr(op);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_set1(";
+    str_ += "cinn_avx512_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_set1(";
+    str_ += "cinn_avx256_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_store(";
+    str_ += "cinn_avx512_store(";
     PrintAbsAddr(op);
-    os() << ", ";
-    Print(op->value);
-    os() << ")";
+    str_ += ", ";
+    IrPrinter::Visit(op->value);
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_store(";
+    str_ += "cinn_avx256_store(";
     PrintAbsAddr(op);
-    os() << ", ";
-    Print(op->value);
-    os() << ")";
+    str_ += ", ";
+    IrPrinter::Visit(op->value);
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
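As a reading aid, a hedged sketch of the text these branches emit for a 512-bit (16 x float) copy; tensor names and offsets are invented:

//   cinn_avx512_store(C + i * 16, cinn_avx512_load(A + i * 16));
//
// With AVX256 and a 256-bit type the same IR lowers to the cinn_avx256_*
// variants; any other width falls back to scalar CodeGenC emission.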
@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
     Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value;
     if (SupportsAVX512()) {
-      os() << "cinn_avx512_set1(";
-      Print(value);
-      os() << ")";
+      str_ += "cinn_avx512_set1(";
+      IrPrinter::Visit(value);
+      str_ += ")";
     } else if (SupportsAVX256()) {
-      os() << "cinn_avx256_set1(";
-      Print(value);
-      os() << ")";
+      str_ += "cinn_avx256_set1(";
+      IrPrinter::Visit(value);
+      str_ += ")";
     } else {
       CINN_NOT_IMPLEMENTED
     }
   } else {
-    Print(*op);
+    IrPrinter::Visit(*op);
   }
 }
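PrintVecInputArgument normalizes scalar operands: a lane-1 value (or the payload of a Broadcast node) is splatted with set1 so it can meet a vector operand. A compilable distillation, with the target checks passed in as flags rather than queried from the printer:

#include <string>

std::string SplatScalar(const std::string &scalar, bool avx512) {
  // e.g. SplatScalar("0.5f", true) == "cinn_avx512_set1(0.5f)"
  std::string out = avx512 ? "cinn_avx512_set1(" : "cinn_avx256_set1(";
  out += scalar;
  out += ")";
  return out;
}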
@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) {
   }
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_" << op->name << "(";
+    str_ += "cinn_avx512_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
         PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
-      Print(op->args.back());
+      IrPrinter::Visit(op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_" << op->name << "(";
+    str_ += "cinn_avx256_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
         PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
       PrintVecInputArgument(&op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else if (bits == 128) {
-    os() << "cinn_avx128_" << op->name << "(";
+    str_ += "cinn_avx128_";
+    str_ += op->name;
+    str_ += "(";
     if (!op->args.empty()) {
       for (int i = 0; i < op->args.size() - 1; i++) {
         PrintVecInputArgument(&op->args[i]);
-        os() << ", ";
+        str_ += ", ";
       }
       PrintVecInputArgument(&op->args.back());
     }
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
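The three branches differ only in the wrapper-family prefix, chosen by total bit width (element bits times lanes). A compact restatement; note that the 128-bit branch carries no feature check in the code above:

#include <string>

std::string IntrinPrefix(int bits, bool avx512, bool avx256) {
  if (avx512 && bits == 512) return "cinn_avx512_";
  if (avx256 && bits == 256) return "cinn_avx256_";
  if (bits == 128) return "cinn_avx128_";  // unconditional in the patch
  return "";  // empty: caller falls back to CodeGenC::Visit
}
// IntrinPrefix(512, true, true) + op->name -> e.g. "cinn_avx512_exp"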
......
@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC {
   template <typename Op>
   void PrintAbsAddr(const Op *op) {
-    os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
+    str_ += op->tensor.template As<ir::_Tensor_>()->name;
+    str_ += " + ";
     auto index = op->index();
     auto *ramp_n = index.template As<ir::Ramp>();
     if (ramp_n) {
       CHECK(!ramp_n->base.template As<ir::Ramp>())
           << "base of a Ramp node should not be Ramp type";
-      Print(ramp_n->base);
+      IrPrinter::Visit(ramp_n->base);
     } else {
-      Print(op->index());
+      IrPrinter::Visit(op->index());
     }
   }
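PrintAbsAddr builds the pointer operand passed to the AVX wrappers: tensor base plus a scalar offset. For a vectorized access the index is a Ramp(base, stride, lanes) node and only its base is printed, since the wrapper consumes a whole vector starting at that address. An illustration with invented names:

// IR access:   A[Ramp(i * 16, 1, 16)]      (16 consecutive lanes)
// printed as:  A + i * 16
//
// IR access:   A[j]                        (scalar index)
// printed as:  A + j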
@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op,
   // TODO(Superjomn) Consider support BLAS.
   int bits = a.type().bits() * a.type().lanes();
   if (SupportsAVX512() && bits == 512) {
-    os() << "cinn_avx512_" << op_repr << "(";
+    str_ += "cinn_avx512_";
+    str_ += op_repr;
+    str_ += "(";
     PrintVecInputArgument(&a);
-    os() << ", ";
+    str_ += ", ";
     PrintVecInputArgument(&b);
-    os() << ")";
+    str_ += ")";
   } else if (SupportsAVX256() && bits == 256) {
-    os() << "cinn_avx256_" << op_repr << "(";
+    str_ += "cinn_avx256_";
+    str_ += op_repr;
+    str_ += "(";
     PrintVecInputArgument(&a);
-    os() << ", ";
+    str_ += ", ";
     PrintVecInputArgument(&b);
-    os() << ")";
+    str_ += ")";
   } else {
     CodeGenC::Visit(op);
   }
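Because both operands go through PrintVecInputArgument, scalars are splatted automatically. A hedged example of the emitted text for an 8 x float add with one scalar operand (names invented):

//   cinn_avx256_add(x, cinn_avx256_set1(1.0f))
//
// where op_repr is "add" and the second operand was the scalar 1.0f.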
......
@@ -62,6 +62,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   CodeGenC::inline_builtin_codes_ = false;
   if (!outputs.c_header_name.empty()) {
     auto source = Compile(module, OutputKind::CHeader);
+    str_ = "";
     std::ofstream file(outputs.c_header_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
     file << source;
@@ -71,6 +72,7 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   if (!outputs.cuda_source_name.empty()) {
     auto source = Compile(module, OutputKind::CImpl);
+    str_ = "";
     std::ofstream file(outputs.cuda_source_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name;
     file << source;
@@ -79,9 +81,8 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module,
   }
 }

-std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
-  Print(Expr(func));
-  return ss_.str();
+void CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
+  IrPrinter::Visit(Expr(func));
 }

 std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
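Two consequences of accumulating into str_ appear here: the buffer must be emptied between the CHeader and CImpl passes (the new str_ = "" lines), and the per-function Compile now returns void because its output lands in str_ as a side effect. A sketch of the write-then-reset discipline, using a hypothetical helper that is not part of the patch:

#include <fstream>
#include <string>

// Hypothetical helper: flush the accumulated text to disk and leave the
// buffer empty for the next output kind.
void FlushTo(std::string &buffer, const std::string &path) {
  std::ofstream file(path);
  file << buffer;
  buffer.clear();  // mirrors `str_ = "";` after each Compile(module, kind)
}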
@@ -117,10 +118,10 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
 void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
   // clear names valid within scope when enter a new function
   vectorized_tensor_names_.clear();
-  os() << "__global__\n";
+  str_ += "__global__\n";
   PrintFunctionDeclaration(op);
-  os() << "\n";
+  str_ += "\n";
   DoIndent();
@@ -145,15 +146,16 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
   if (!func_body.As<ir::Block>()) {
     func_body = ir::Block::Make({func_body});
   }
-  Print(func_body);
+  IrPrinter::Visit(func_body);
 }
 void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) {
   if (utils::Startswith(op->name, "threadIdx") ||
       utils::Startswith(op->name, "blockIdx")) {
-    os() << "(int)" + op->name;
+    str_ += "(int)";
+    str_ += op->name;
   } else {
-    os() << op->name;
+    str_ += op->name;
   }
 }
@@ -163,58 +165,63 @@ void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) {
 }

 void CodeGenCUDA_Dev::Visit(const ir::Min *op) {
-  os() << "min(";
-  Print(op->a());
-  os() << ", ";
-  Print(op->b());
-  os() << ")";
+  str_ += "min(";
+  IrPrinter::Visit(op->a());
+  str_ += ", ";
+  IrPrinter::Visit(op->b());
+  str_ += ")";
 }

 void CodeGenCUDA_Dev::Visit(const ir::Max *op) {
-  os() << "max(";
-  Print(op->a());
-  os() << ", ";
-  Print(op->b());
-  os() << ")";
+  str_ += "max(";
+  IrPrinter::Visit(op->a());
+  str_ += ", ";
+  IrPrinter::Visit(op->b());
+  str_ += ")";
 }
 void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
-  os() << "void ";
+  str_ += "void ";
   if (op->cuda_axis_info.valid()) {
     int thread_num = 1;
     for (int i = 0; i < 3; i++) {
       thread_num *= op->cuda_axis_info.block_dim(i);
     }
-    os() << "__launch_bounds__(" << thread_num << ") ";
+    str_ += "__launch_bounds__(";
+    str_ += std::to_string(thread_num);
+    str_ += ") ";
   }
-  os() << op->name << "(";
+  str_ += op->name;
+  str_ += "(";
   for (int i = 0; i < op->args.size() - 1; i++) {
     auto &arg = op->args[i];
     PrintFuncArg(arg);
-    os() << ", ";
+    str_ += ", ";
   }
   if (!op->args.empty()) {
     PrintFuncArg(op->args.back());
   }
-  os() << ")";
+  str_ += ")";
 }
 void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) {
   if (arg.is_buffer()) {
     // In CUDA kernel, only primitive type is supported, so we replace the
     // buffer with T*
-    if (arg.is_input()) os() << "const ";
-    os() << GetTypeRepr(arg.buffer_arg()->dtype);
-    os() << "* ";
-    os() << kCKeywordRestrict << " ";
-    os() << ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
+    if (arg.is_input()) str_ += "const ";
+    str_ += GetTypeRepr(arg.buffer_arg()->dtype);
+    str_ += "* ";
+    str_ += kCKeywordRestrict;
+    str_ += " ";
+    str_ += ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
   } else if (arg.is_var()) {
     if (arg.var_arg()->type().is_cpp_handle()) {
-      os() << kCKeywordRestrict;
+      str_ += kCKeywordRestrict;
     }
-    os() << GetTypeRepr(arg.type()) << " ";
-    os() << arg.name();
+    str_ += GetTypeRepr(arg.type());
+    str_ += " ";
+    str_ += arg.name();
   } else {
     CINN_NOT_IMPLEMENTED
   }
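Taken together with PrintFunctionDeclaration above, a hedged example of the declaration text this produces for a kernel with valid cuda_axis_info and block_dim = (16, 16, 1), so thread_num = 256; kernel and parameter names are invented, and kCKeywordRestrict is assumed to expand to __restrict__:

// __global__
// void __launch_bounds__(256) fn_add(const float* __restrict__ x,
//                                    const float* __restrict__ y,
//                                    float* __restrict__ out)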
@@ -230,7 +237,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
   PrintIncludes();

   if (for_nvrtc_) {
-    os() << "\nextern \"C\" {\n\n";
+    str_ += "\nextern \"C\" {\n\n";
   }

   PrintBuiltinCodes();
@@ -243,27 +250,30 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
   }

   if (for_nvrtc_) {
-    os() << "\n\n}";
+    str_ += "\n\n}";
   }

-  return ss_.str();
+  return str_;
 }

-void CodeGenCUDA_Dev::PrintIncludes() { os() << GetSourceHeader(); }
+void CodeGenCUDA_Dev::PrintIncludes() { str_ += GetSourceHeader(); }
 void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
   CHECK_NE(buffer->type(), Void());
   auto print_gpu_memory = [&](const std::string &mark) {
-    os() << mark << GetTypeRepr(buffer->dtype) << " " << buffer->name << " ";
-    os() << "[ ";
+    str_ += mark;
+    str_ += GetTypeRepr(buffer->dtype);
+    str_ += " ";
+    str_ += buffer->name;
+    str_ += " ";
+    str_ += "[ ";
     Expr buffer_size(1);
     for (int i = 0; i < buffer->shape.size(); i++) {
       buffer_size = buffer_size * buffer->shape[i];
     }
     optim::Simplify(&buffer_size);
-    Print(buffer_size);
-    os() << " ]";
+    IrPrinter::Visit(buffer_size);
+    str_ += " ]";
   };
   switch (buffer->memory_type) {
     case ir::MemoryType::GPUShared:
@@ -281,46 +291,47 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
 }
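A hedged one-line example of what print_gpu_memory emits for a GPUShared buffer of 32 x 32 floats (buffer name invented, and assuming the GPUShared mark is "__shared__ "); note optim::Simplify folds the shape product into a single literal first:

//   __shared__ float _A_temp [ 1024 ]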
 void CodeGenCUDA_Dev::Visit(const ir::Call *op) {
-  os() << op->name + "(";
+  str_ += op->name;
+  str_ += "(";
   if (!op->read_args.empty()) {
     for (int i = 0; i < op->read_args.size() - 1; i++) {
       auto &arg = op->read_args[i];
       if (arg.as_tensor()) {
-        os() << arg.as_tensor()->name;
-        os() << ", ";
+        str_ += arg.as_tensor()->name;
+        str_ += ", ";
       } else {
-        Print(arg);
-        os() << ", ";
+        IrPrinter::Visit(arg);
+        str_ += ", ";
       }
     }
     if (op->read_args.back().as_tensor()) {
-      os() << op->read_args.back().as_tensor()->name;
+      str_ += op->read_args.back().as_tensor()->name;
     } else {
-      Print(op->read_args.back());
+      IrPrinter::Visit(op->read_args.back());
     }
   }

   if (!op->write_args.empty()) {
-    os() << ", ";
+    str_ += ", ";
     for (int i = 0; i < op->write_args.size() - 1; i++) {
       auto &arg = op->write_args[i];
       if (arg.as_tensor()) {
-        os() << arg.as_tensor()->name;
-        os() << ", ";
+        str_ += arg.as_tensor()->name;
+        str_ += ", ";
       } else {
-        Print(arg);
-        os() << ", ";
+        IrPrinter::Visit(arg);
+        str_ += ", ";
       }
     }
     if (op->write_args.back().as_tensor()) {
-      os() << op->write_args.back().as_tensor()->name;
+      str_ += op->write_args.back().as_tensor()->name;
     } else {
-      Print(op->write_args.back());
+      IrPrinter::Visit(op->write_args.back());
     }
   }
-  os() << ")";
+  str_ += ")";
 }
 void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
@@ -331,20 +342,21 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
   if (op->type().is_customized() &&
       utils::Startswith(op->type().customized_type(),
                         common::customized_type::kcuda_builtin_vector_t)) {
-    os() << GetTypeRepr(op->type());
+    str_ += GetTypeRepr(op->type());
     if (op->type().is_cpp_handle()) {
-      os() << " " << kCKeywordRestrict;
+      str_ += " ";
+      str_ += kCKeywordRestrict;
     }
-    os() << " ";
-    Print(op->symbol);
+    str_ += " ";
+    IrPrinter::Visit(op->symbol);
     vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol));
     // skip "=0" in "half8 temp = 0;" since the operator= of half8 may not
     // be overloaded.
     if (op->body.As<ir::IntImm>() && op->body.As<ir::IntImm>()->value == 0) {
       return;
     }
-    os() << " = ";
-    Print(op->body);
+    str_ += " = ";
+    IrPrinter::Visit(op->body);
   } else {
     CodeGenC::Visit(op);
   }
@@ -374,10 +386,14 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op,
     return false;
   }
   if (is_store && tensor->type().is_cpp_handle()) {
-    os() << tensor->name << "[" << index << "]";
+    str_ += tensor->name;
+    str_ += "[";
+    str_ += std::to_string(index);
+    str_ += "]";
   } else {
-    os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".")
-         << index2suffix[index];
+    str_ += tensor->name;
+    str_ += (tensor->type().is_cpp_handle() ? "->" : ".");
+    str_ += index2suffix[index];
   }
   return true;
 }
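An illustration of the two access shapes printed for a cuda built-in vector such as half8; the component-suffix mapping (0 to x, 1 to y, 2 to z, ...) is an assumption about index2suffix, and the names are invented:

//   read of lane 2:                     temp.z   (temp->z if the type is a handle)
//   store to lane 2 through a handle:   temp[2] = ...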
@@ -396,8 +412,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) {
   // accesses element at a cuda built-in vector, others still resolve to
   // CodeGenC
   if (PrintBuiltinVectorAccess(op, op->index(), true)) {
-    os() << " = ";
-    Print(op->value);
+    str_ += " = ";
+    IrPrinter::Visit(op->value);
   } else {
     CodeGenC::Visit(op);
   }
......
@@ -54,7 +54,7 @@ class CodeGenCUDA_Dev : public CodeGenC {
   //! Compile on NVRTC.
   std::string Compile(const ir::Module& module, bool for_nvrtc = true);

-  std::string Compile(const ir::LoweredFunc& func);
+  void Compile(const ir::LoweredFunc& func);

   /**
    * \brief Print a function argument in CUDA syntax. Currently, just some
......
@@ -44,14 +44,24 @@ struct Type::Storage {

 Type::~Type() {}

-std::ostream &operator<<(std::ostream &os, const Type &t) {
-  if (t.is_cpp_const()) os << "const ";
-  os << Type2Str(t);
-  if (t.lanes() > 1) os << "<" << t.lanes() << ">";
-  if (t.is_cpp_handle()) os << "*";
-  if (t.is_cpp_handle2()) os << "**";
+std::string Type::to_string() const {
+  std::string ret = "";
+  if (is_cpp_const()) ret += "const ";
+  ret += Type2Str(*this);
+  if (lanes() > 1) {
+    ret += "<";
+    ret += std::to_string(lanes());
+    ret += ">";
+  }
+  if (is_cpp_handle()) ret += "*";
+  if (is_cpp_handle2()) ret += "**";
+  return ret;
+}
+
+std::ostream &operator<<(std::ostream &os, const Type &t) {
+  os << t.to_string();
   return os;
 }
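A standalone restatement of the formatting rules in Type::to_string, with Type2Str replaced by a plain scalar-name parameter so the expected strings are easy to check:

#include <string>

std::string FormatType(const std::string &scalar, bool is_const, int lanes,
                       int pointer_depth) {
  std::string ret;
  if (is_const) ret += "const ";
  ret += scalar;
  if (lanes > 1) {
    ret += "<";
    ret += std::to_string(lanes);
    ret += ">";
  }
  ret += std::string(pointer_depth, '*');  // "*" for handle, "**" for handle2
  return ret;
}
// FormatType("float", false, 4, 1) == "float<4>*"
// FormatType("int32", true, 1, 0)  == "const int32"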
......
@@ -149,6 +149,8 @@ struct Type {
   //! Check if a dtype is supported in CINN yet.
   bool is_supported() const;

+  std::string to_string() const;
+
   friend std::ostream& operator<<(std::ostream& os, const Type& t);

   ~Type();
......
This diff is collapsed.
@@ -30,10 +30,10 @@ namespace ir {

 class Module;

 struct IrPrinter : public IRVisitorRequireReImpl<void> {
-  explicit IrPrinter(std::ostream &os) : os_(os) {}
+  explicit IrPrinter(std::ostream &os) : os_(os), str_("") {}

   //! Emit an expression on the output stream.
-  void Print(Expr e);
+  void Print(const Expr &e);
   //! Emit an expression list with "," as the splitter.
   void Print(const std::vector<Expr> &exprs,
              const std::string &splitter = ", ");
@@ -50,6 +50,17 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
   std::ostream &os() { return os_; }

   void Visit(const Expr &x) { IRVisitorRequireReImpl::Visit(&x); }
+  void Visit(const std::vector<Expr> &exprs,
+             const std::string &splitter = ", ") {
+    for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
+      Visit(exprs[i]);
+      str_ += splitter;
+    }
+    if (!exprs.empty()) Visit(exprs.back());
+  }
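The added overload is the usual join idiom: emit each element followed by the splitter, then the last element alone. (The !exprs.empty() guard in the loop condition is redundant given i + 1 < exprs.size() on a std::size_t index, though harmless.) A string-only equivalent for reference:

#include <string>
#include <vector>

std::string Join(const std::vector<std::string> &parts,
                 const std::string &splitter = ", ") {
  std::string out;
  for (std::size_t i = 0; i + 1 < parts.size(); ++i) {
    out += parts[i];
    out += splitter;
  }
  if (!parts.empty()) out += parts.back();
  return out;  // Join({"a", "b", "c"}) == "a, b, c"
}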
#define __(op__) void Visit(const op__ *x) override;
NODETY_FORALL(__)
#undef __
@@ -58,6 +69,9 @@ struct IrPrinter : public IRVisitorRequireReImpl<void> {
   INTRINSIC_KIND_FOR_EACH(__)
 #undef __

+ protected:
+  std::string str_;
+
  private:
   std::ostream &os_;
   uint16_t indent_{};
@@ -71,11 +85,13 @@ std::ostream &operator<<(std::ostream &os, const Module &m);

 template <typename IRN>
 void IrPrinter::PrintBinaryOp(const std::string &op,
                               const BinaryOpNode<IRN> *x) {
-  os_ << "(";
-  Print(x->a());
-  os_ << " " + op + " ";
-  Print(x->b());
-  os_ << ")";
+  str_ += "(";
+  Visit(x->a());
+  str_ += " ";
+  str_ += op;
+  str_ += " ";
+  Visit(x->b());
+  str_ += ")";
 }
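For completeness, every binary node prints fully parenthesized, so nested expressions need no precedence handling:

//   a + b * c   (IR: Add(a, Mul(b, c)))   prints as   (a + (b * c))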
} // namespace ir
......