未验证 提交 ed6b9567 编写于 作者: 傅剑寒 提交者: GitHub

【CINN】refactor ir_visitor (#55171)

This PR delete middle ir_visitor class and thus we can avoid middle virtual function call and codes look more clean
pcard-72718
上级 4543ca91
...@@ -45,7 +45,9 @@ using namespace ::cinn::ir; // NOLINT ...@@ -45,7 +45,9 @@ using namespace ::cinn::ir; // NOLINT
FeatureExtractor::FeatureExtractor() {} FeatureExtractor::FeatureExtractor() {}
void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } void FeatureExtractor::Visit(const Expr *x) {
IRVisitorRequireReImpl::Visit(x);
}
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) { const common::Target &target) {
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
class FeatureExtractor : public ir::IRVisitor { class FeatureExtractor : public ir::IRVisitorRequireReImpl<void> {
public: public:
FeatureExtractor(); FeatureExtractor();
Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target); Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target);
......
...@@ -71,7 +71,7 @@ bool operator<(const SearchState& left, const SearchState& right) { ...@@ -71,7 +71,7 @@ bool operator<(const SearchState& left, const SearchState& right) {
} }
// Visit every node by expanding all of their fields in dfs order // Visit every node by expanding all of their fields in dfs order
class DfsWithExprsFields : public ir::IRVisitor { class DfsWithExprsFields : public ir::IRVisitorRequireReImpl<void> {
protected: protected:
#define __m(t__) \ #define __m(t__) \
void Visit(const ir::t__* x) override { \ void Visit(const ir::t__* x) override { \
...@@ -85,7 +85,7 @@ class DfsWithExprsFields : public ir::IRVisitor { ...@@ -85,7 +85,7 @@ class DfsWithExprsFields : public ir::IRVisitor {
NODETY_FORALL(__m) NODETY_FORALL(__m)
#undef __m #undef __m
void Visit(const Expr* expr) override { IRVisitor::Visit(expr); } void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); }
}; };
// Generate a reduce hash of a AST tree by combining hash of each AST node // Generate a reduce hash of a AST tree by combining hash of each AST node
......
...@@ -39,11 +39,11 @@ ...@@ -39,11 +39,11 @@
namespace cinn { namespace cinn {
namespace backends { namespace backends {
class LLVMIRVisitor : public ir::IRVisitorBase<llvm::Value *> { class LLVMIRVisitor : public ir::IRVisitorRequireReImpl<llvm::Value *> {
public: public:
LLVMIRVisitor() = default; LLVMIRVisitor() = default;
using ir::IRVisitorBase<llvm::Value *>::Visit; using ir::IRVisitorRequireReImpl<llvm::Value *>::Visit;
#define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0; #define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0;
NODETY_FORALL(__m) NODETY_FORALL(__m)
#undef __m #undef __m
......
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
namespace cinn { namespace cinn {
namespace backends { namespace backends {
class ModularEvaluator : public ir::IRVisitorBase<ModularEntry> { class ModularEvaluator : public ir::IRVisitorRequireReImpl<ModularEntry> {
public: public:
explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map) explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map)
: mod_map_(mod_map) {} : mod_map_(mod_map) {}
ModularEntry Eval(const Expr& e) { ModularEntry Eval(const Expr& e) {
return ir::IRVisitorBase<ModularEntry>::Visit(&e); return ir::IRVisitorRequireReImpl<ModularEntry>::Visit(&e);
} }
ModularEntry Visit(const ir::IntImm* op) { ModularEntry Visit(const ir::IntImm* op) {
......
...@@ -24,7 +24,7 @@ namespace ir { ...@@ -24,7 +24,7 @@ namespace ir {
namespace { namespace {
struct IrNodesCollector : public IRVisitor { struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
using teller_t = std::function<bool(const Expr*)>; using teller_t = std::function<bool(const Expr*)>;
using handler_t = std::function<void(const Expr*)>; using handler_t = std::function<void(const Expr*)>;
......
...@@ -32,7 +32,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { ...@@ -32,7 +32,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
VLOG(5) << "Not equal on Expr, someone not defined"; VLOG(5) << "Not equal on Expr, someone not defined";
} }
bool equal = lhs->node_type() == rhs->node_type(); bool equal = lhs->node_type() == rhs->node_type();
equal = equal && IRVisitorBase<bool, const Expr*>::Visit(&lhs, &rhs); equal = equal && IRVisitorRequireReImpl<bool, const Expr*>::Visit(&lhs, &rhs);
if (!equal) { if (!equal) {
VLOG(5) << "Not equal on Expr, lhs:[type:" VLOG(5) << "Not equal on Expr, lhs:[type:"
......
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
// Determine whether two ir AST trees are euqal by comparing their struct and // Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor // fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorBase<bool, const Expr*> { class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public: public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false) explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {} : allow_name_suffix_diff_(allow_name_suffix_diff) {}
......
...@@ -26,7 +26,7 @@ namespace ir { ...@@ -26,7 +26,7 @@ namespace ir {
//! T might be Expr* or const Expr* //! T might be Expr* or const Expr*
template <typename T = Expr *> template <typename T = Expr *>
class IRMutator : public IRVisitorBase<void, T> { class IRMutator : public IRVisitorRequireReImpl<void, T> {
public: public:
void Visit(const Expr *expr, T op) override; void Visit(const Expr *expr, T op) override;
...@@ -37,22 +37,22 @@ class IRMutator : public IRVisitorBase<void, T> { ...@@ -37,22 +37,22 @@ class IRMutator : public IRVisitorBase<void, T> {
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Expr *expr, T op) { void IRMutator<T>::Visit(const Expr *expr, T op) {
IRVisitorBase<void, T>::Visit(expr, op); IRVisitorRequireReImpl<void, T>::Visit(expr, op);
} }
#define UNARY_OP_IMPL(op__) \ #define UNARY_OP_IMPL(op__) \
template <typename T> \ template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \ void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \ auto *node = op->template As<op__>(); \
IRVisitorBase<void, T>::Visit(&node->v(), &node->v()); \ IRVisitorRequireReImpl<void, T>::Visit(&node->v(), &node->v()); \
} }
#define BINARY_OP_IMPL(op__) \ #define BINARY_OP_IMPL(op__) \
template <typename T> \ template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \ void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \ auto *node = op->template As<op__>(); \
IRVisitorBase<void, T>::Visit(&node->a(), &node->a()); \ IRVisitorRequireReImpl<void, T>::Visit(&node->a(), &node->a()); \
IRVisitorBase<void, T>::Visit(&node->b(), &node->b()); \ IRVisitorRequireReImpl<void, T>::Visit(&node->b(), &node->b()); \
} }
NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL) NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL)
...@@ -77,172 +77,181 @@ void IRMutator<T>::Visit(const Cast *expr, T op) { ...@@ -77,172 +77,181 @@ void IRMutator<T>::Visit(const Cast *expr, T op) {
template <typename T> template <typename T>
void IRMutator<T>::Visit(const For *expr, T op) { void IRMutator<T>::Visit(const For *expr, T op) {
auto *node = op->template As<For>(); auto *node = op->template As<For>();
IRVisitorBase<void, T>::Visit(&node->min, &node->min); IRVisitorRequireReImpl<void, T>::Visit(&node->min, &node->min);
IRVisitorBase<void, T>::Visit(&node->extent, &node->extent); IRVisitorRequireReImpl<void, T>::Visit(&node->extent, &node->extent);
IRVisitorBase<void, T>::Visit(&node->body, &node->body); IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const PolyFor *expr, T op) { void IRMutator<T>::Visit(const PolyFor *expr, T op) {
auto *node = op->template As<PolyFor>(); auto *node = op->template As<PolyFor>();
// IRVisitorBase<void,T>::Visit(&node->iterator, &node->iterator); // IRVisitorRequireReImpl<void,T>::Visit(&node->iterator,
IRVisitorBase<void, T>::Visit(&node->init, &node->init); // &node->iterator);
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition); IRVisitorRequireReImpl<void, T>::Visit(&node->init, &node->init);
IRVisitorBase<void, T>::Visit(&node->inc, &node->inc); IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->body, &node->body); IRVisitorRequireReImpl<void, T>::Visit(&node->inc, &node->inc);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Select *expr, T op) { void IRMutator<T>::Visit(const Select *expr, T op) {
auto *node = op->template As<Select>(); auto *node = op->template As<Select>();
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition); IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->true_value, &node->true_value); IRVisitorRequireReImpl<void, T>::Visit(&node->true_value, &node->true_value);
IRVisitorBase<void, T>::Visit(&node->false_value, &node->false_value); IRVisitorRequireReImpl<void, T>::Visit(&node->false_value,
&node->false_value);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const IfThenElse *expr, T op) { void IRMutator<T>::Visit(const IfThenElse *expr, T op) {
auto *node = op->template As<IfThenElse>(); auto *node = op->template As<IfThenElse>();
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition); IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->true_case, &node->true_case); IRVisitorRequireReImpl<void, T>::Visit(&node->true_case, &node->true_case);
if (node->false_case.defined()) if (node->false_case.defined())
IRVisitorBase<void, T>::Visit(&node->false_case, &node->false_case); IRVisitorRequireReImpl<void, T>::Visit(&node->false_case,
&node->false_case);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Block *expr, T op) { void IRMutator<T>::Visit(const Block *expr, T op) {
auto *node = op->template As<Block>(); auto *node = op->template As<Block>();
for (auto &expr : node->stmts) { for (auto &expr : node->stmts) {
IRVisitorBase<void, T>::Visit(&expr, &expr); IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Call *expr, T op) { void IRMutator<T>::Visit(const Call *expr, T op) {
auto *node = op->template As<Call>(); auto *node = op->template As<Call>();
for (auto &expr : node->read_args) { for (auto &expr : node->read_args) {
IRVisitorBase<void, T>::Visit(&expr, &expr); IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
} }
for (auto &expr : node->write_args) { for (auto &expr : node->write_args) {
IRVisitorBase<void, T>::Visit(&expr, &expr); IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const _Module_ *expr, T op) { void IRMutator<T>::Visit(const _Module_ *expr, T op) {
auto *node = op->template As<_Module_>(); auto *node = op->template As<_Module_>();
for (auto &func : node->functions) { for (auto &func : node->functions) {
IRVisitorBase<void, T>::Visit(&func, &func); IRVisitorRequireReImpl<void, T>::Visit(&func, &func);
} }
for (auto &func : node->buffers) { for (auto &func : node->buffers) {
IRVisitorBase<void, T>::Visit(&func, &func); IRVisitorRequireReImpl<void, T>::Visit(&func, &func);
} }
for (auto &expr : node->submodules) { for (auto &expr : node->submodules) {
IRVisitorBase<void, T>::Visit(&expr, &expr); IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const _Var_ *expr, T op) { void IRMutator<T>::Visit(const _Var_ *expr, T op) {
auto *node = op->template As<ir::_Var_>(); auto *node = op->template As<ir::_Var_>();
if (node->lower_bound.defined()) { if (node->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&node->lower_bound, &node->lower_bound); IRVisitorRequireReImpl<void, T>::Visit(&node->lower_bound,
&node->lower_bound);
} }
if (node->upper_bound.defined()) { if (node->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&node->upper_bound, &node->upper_bound); IRVisitorRequireReImpl<void, T>::Visit(&node->upper_bound,
&node->upper_bound);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Load *expr, T op) { void IRMutator<T>::Visit(const Load *expr, T op) {
auto *node = op->template As<Load>(); auto *node = op->template As<Load>();
for (auto &idx : node->indices) IRVisitorBase<void, T>::Visit(&idx, &idx); for (auto &idx : node->indices)
IRVisitorBase<void, T>::Visit(&node->tensor, &node->tensor); IRVisitorRequireReImpl<void, T>::Visit(&idx, &idx);
IRVisitorRequireReImpl<void, T>::Visit(&node->tensor, &node->tensor);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Store *expr, T op) { void IRMutator<T>::Visit(const Store *expr, T op) {
auto *node = op->template As<Store>(); auto *node = op->template As<Store>();
IRVisitorBase<void, T>::Visit(&node->value, &node->value); IRVisitorRequireReImpl<void, T>::Visit(&node->value, &node->value);
IRVisitorBase<void, T>::Visit(&node->tensor, &node->tensor); IRVisitorRequireReImpl<void, T>::Visit(&node->tensor, &node->tensor);
for (auto &idx : node->indices) IRVisitorBase<void, T>::Visit(&idx, &idx); for (auto &idx : node->indices)
IRVisitorRequireReImpl<void, T>::Visit(&idx, &idx);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Alloc *expr, T op) { void IRMutator<T>::Visit(const Alloc *expr, T op) {
auto *node = op->template As<Alloc>(); auto *node = op->template As<Alloc>();
for (auto &e : node->extents) { for (auto &e : node->extents) {
IRVisitorBase<void, T>::Visit(&e, &e); IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
} }
if (node->condition.defined()) if (node->condition.defined())
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition); IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
if (node->body.defined()) { if (node->body.defined()) {
Expr body(node->body); Expr body(node->body);
IRVisitorBase<void, T>::Visit(&node->body, &body); IRVisitorRequireReImpl<void, T>::Visit(&node->body, &body);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Free *expr, T op) { void IRMutator<T>::Visit(const Free *expr, T op) {
auto *node = op->template As<Free>(); auto *node = op->template As<Free>();
IRVisitorBase<void, T>::Visit(&node->destination, &node->destination); IRVisitorRequireReImpl<void, T>::Visit(&node->destination,
&node->destination);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const _Buffer_ *expr, T op) { void IRMutator<T>::Visit(const _Buffer_ *expr, T op) {
auto *node = op->template As<_Buffer_>(); auto *node = op->template As<_Buffer_>();
for (auto &e : node->shape) { for (auto &e : node->shape) {
IRVisitorBase<void, T>::Visit(&e, &e); IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
} }
for (auto &e : node->strides) { for (auto &e : node->strides) {
IRVisitorBase<void, T>::Visit(&e, &e); IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
} }
IRVisitorBase<void, T>::Visit(&node->elem_offset, &node->elem_offset); IRVisitorRequireReImpl<void, T>::Visit(&node->elem_offset,
&node->elem_offset);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const _Tensor_ *expr, T op) { void IRMutator<T>::Visit(const _Tensor_ *expr, T op) {
auto *node = op->template As<_Tensor_>(); auto *node = op->template As<_Tensor_>();
for (auto &e : node->shape) { for (auto &e : node->shape) {
IRVisitorBase<void, T>::Visit(&e, &e); IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
} }
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const _LoweredFunc_ *expr, T op) { void IRMutator<T>::Visit(const _LoweredFunc_ *expr, T op) {
auto *node = op->template As<_LoweredFunc_>(); auto *node = op->template As<_LoweredFunc_>();
IRVisitorBase<void, T>::Visit(&node->body, &node->body); IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Let *expr, T op) { void IRMutator<T>::Visit(const Let *expr, T op) {
auto *node = op->template As<Let>(); auto *node = op->template As<Let>();
IRVisitorBase<void, T>::Visit(&node->symbol, &node->symbol); IRVisitorRequireReImpl<void, T>::Visit(&node->symbol, &node->symbol);
if (node->body.defined()) if (node->body.defined())
IRVisitorBase<void, T>::Visit(&node->body, &node->body); IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Reduce *expr, T op) { void IRMutator<T>::Visit(const Reduce *expr, T op) {
auto *node = op->template As<Reduce>(); auto *node = op->template As<Reduce>();
if (node->init.defined()) if (node->init.defined())
IRVisitorBase<void, T>::Visit(&node->init, &node->init); IRVisitorRequireReImpl<void, T>::Visit(&node->init, &node->init);
CHECK(node->body.defined()); CHECK(node->body.defined());
IRVisitorBase<void, T>::Visit(&node->body, &node->body); IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Ramp *expr, T op) { void IRMutator<T>::Visit(const Ramp *expr, T op) {
auto *node = op->template As<Ramp>(); auto *node = op->template As<Ramp>();
IRVisitorBase<void, T>::Visit(&node->base, &node->base); IRVisitorRequireReImpl<void, T>::Visit(&node->base, &node->base);
IRVisitorBase<void, T>::Visit(&node->stride, &node->stride); IRVisitorRequireReImpl<void, T>::Visit(&node->stride, &node->stride);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Broadcast *expr, T op) { void IRMutator<T>::Visit(const Broadcast *expr, T op) {
auto *node = op->template As<Broadcast>(); auto *node = op->template As<Broadcast>();
IRVisitorBase<void, T>::Visit(&node->value, &node->value); IRVisitorRequireReImpl<void, T>::Visit(&node->value, &node->value);
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const FracOp *expr, T op) { void IRMutator<T>::Visit(const FracOp *expr, T op) {
auto *node = op->template As<FracOp>(); auto *node = op->template As<FracOp>();
IRVisitorBase<void, T>::Visit(&node->a(), &node->a()); IRVisitorRequireReImpl<void, T>::Visit(&node->a(), &node->a());
IRVisitorBase<void, T>::Visit(&node->b(), &node->b()); IRVisitorRequireReImpl<void, T>::Visit(&node->b(), &node->b());
} }
template <typename T> template <typename T>
void IRMutator<T>::Visit(const Product *expr, T op) { void IRMutator<T>::Visit(const Product *expr, T op) {
auto *node = op->template As<Product>(); auto *node = op->template As<Product>();
for (auto &x : node->operands()) { for (auto &x : node->operands()) {
IRVisitorBase<void, T>::Visit(&x, &x); IRVisitorRequireReImpl<void, T>::Visit(&x, &x);
} }
} }
...@@ -250,7 +259,7 @@ template <typename T> ...@@ -250,7 +259,7 @@ template <typename T>
void IRMutator<T>::Visit(const Sum *expr, T op) { void IRMutator<T>::Visit(const Sum *expr, T op) {
auto *node = op->template As<Sum>(); auto *node = op->template As<Sum>();
for (auto &x : node->operands()) { for (auto &x : node->operands()) {
IRVisitorBase<void, T>::Visit(&x, &x); IRVisitorRequireReImpl<void, T>::Visit(&x, &x);
} }
} }
template <typename T> template <typename T>
...@@ -258,7 +267,7 @@ void IRMutator<T>::Visit(const PrimitiveNode *expr, T op) { ...@@ -258,7 +267,7 @@ void IRMutator<T>::Visit(const PrimitiveNode *expr, T op) {
auto *node = op->template As<PrimitiveNode>(); auto *node = op->template As<PrimitiveNode>();
for (auto &args : node->arguments) { for (auto &args : node->arguments) {
for (auto &arg : args) { for (auto &arg : args) {
IRVisitorBase<void, T>::Visit(&arg, &arg); IRVisitorRequireReImpl<void, T>::Visit(&arg, &arg);
} }
} }
} }
...@@ -292,13 +301,15 @@ template <typename T> ...@@ -292,13 +301,15 @@ template <typename T>
void IRMutator<T>::Visit(const _BufferRange_ *expr, T op) { void IRMutator<T>::Visit(const _BufferRange_ *expr, T op) {
auto *node = op->template As<_BufferRange_>(); auto *node = op->template As<_BufferRange_>();
CHECK(node); CHECK(node);
IRVisitorBase<void, T>::Visit(&node->buffer, &node->buffer); IRVisitorRequireReImpl<void, T>::Visit(&node->buffer, &node->buffer);
for (auto &var : node->ranges) { for (auto &var : node->ranges) {
if (var->lower_bound.defined()) { if (var->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->lower_bound, &var->lower_bound); IRVisitorRequireReImpl<void, T>::Visit(&var->lower_bound,
&var->lower_bound);
} }
if (var->upper_bound.defined()) { if (var->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->upper_bound, &var->upper_bound); IRVisitorRequireReImpl<void, T>::Visit(&var->upper_bound,
&var->upper_bound);
} }
} }
} }
...@@ -309,19 +320,21 @@ void IRMutator<T>::Visit(const ScheduleBlock *expr, T op) { ...@@ -309,19 +320,21 @@ void IRMutator<T>::Visit(const ScheduleBlock *expr, T op) {
CHECK(node); CHECK(node);
for (auto &var : node->iter_vars) { for (auto &var : node->iter_vars) {
if (var->lower_bound.defined()) { if (var->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->lower_bound, &var->lower_bound); IRVisitorRequireReImpl<void, T>::Visit(&var->lower_bound,
&var->lower_bound);
} }
if (var->upper_bound.defined()) { if (var->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->upper_bound, &var->upper_bound); IRVisitorRequireReImpl<void, T>::Visit(&var->upper_bound,
&var->upper_bound);
} }
} }
for (auto &buffer_region : node->read_buffers) { for (auto &buffer_region : node->read_buffers) {
IRVisitorBase<void, T>::Visit(&buffer_region, &buffer_region); IRVisitorRequireReImpl<void, T>::Visit(&buffer_region, &buffer_region);
} }
for (auto &buffer_region : node->write_buffers) { for (auto &buffer_region : node->write_buffers) {
IRVisitorBase<void, T>::Visit(&buffer_region, &buffer_region); IRVisitorRequireReImpl<void, T>::Visit(&buffer_region, &buffer_region);
} }
IRVisitorBase<void, T>::Visit(&(node->body), &(node->body)); IRVisitorRequireReImpl<void, T>::Visit(&(node->body), &(node->body));
} }
template <typename T> template <typename T>
...@@ -329,9 +342,10 @@ void IRMutator<T>::Visit(const ScheduleBlockRealize *expr, T op) { ...@@ -329,9 +342,10 @@ void IRMutator<T>::Visit(const ScheduleBlockRealize *expr, T op) {
auto *node = op->template As<ScheduleBlockRealize>(); auto *node = op->template As<ScheduleBlockRealize>();
CHECK(node); CHECK(node);
for (auto &value : node->iter_values) { for (auto &value : node->iter_values) {
IRVisitorBase<void, T>::Visit(&value, &value); IRVisitorRequireReImpl<void, T>::Visit(&value, &value);
} }
IRVisitorBase<void, T>::Visit(&node->schedule_block, &node->schedule_block); IRVisitorRequireReImpl<void, T>::Visit(&node->schedule_block,
&node->schedule_block);
} }
} // namespace ir } // namespace ir
......
...@@ -32,7 +32,7 @@ namespace ir { ...@@ -32,7 +32,7 @@ namespace ir {
using common::bfloat16; using common::bfloat16;
using common::float16; using common::float16;
void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); } void IrPrinter::Print(Expr e) { IRVisitorRequireReImpl::Visit(&e); }
void IrPrinter::Print(const std::vector<Expr> &exprs, void IrPrinter::Print(const std::vector<Expr> &exprs,
const std::string &splitter) { const std::string &splitter) {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) { for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
......
...@@ -29,7 +29,7 @@ class LoweredFunc; ...@@ -29,7 +29,7 @@ class LoweredFunc;
namespace ir { namespace ir {
class Module; class Module;
struct IrPrinter : public IRVisitor { struct IrPrinter : public IRVisitorRequireReImpl<void> {
explicit IrPrinter(std::ostream &os) : os_(os) {} explicit IrPrinter(std::ostream &os) : os_(os) {}
//! Emit an expression on the output stream. //! Emit an expression on the output stream.
......
...@@ -34,7 +34,8 @@ struct _Tensor_; ...@@ -34,7 +34,8 @@ struct _Tensor_;
* @param Args type of the extra arguments passed to the all the methods. * @param Args type of the extra arguments passed to the all the methods.
*/ */
template <typename RetTy = void, typename... Args> template <typename RetTy = void, typename... Args>
struct IRVisitorBase { class IRVisitorRequireReImpl {
public:
//! Visit a expression. //! Visit a expression.
// @{ // @{
virtual RetTy Visit(const ir::Expr* expr, Args... args) { virtual RetTy Visit(const ir::Expr* expr, Args... args) {
...@@ -53,7 +54,6 @@ struct IRVisitorBase { ...@@ -53,7 +54,6 @@ struct IRVisitorBase {
return RetTy(); return RetTy();
} }
// @} // @}
protected: protected:
#define __(op__) virtual RetTy Visit(const ir::op__* op, Args... args) = 0; #define __(op__) virtual RetTy Visit(const ir::op__* op, Args... args) = 0;
NODETY_FORALL(__) NODETY_FORALL(__)
...@@ -63,16 +63,19 @@ struct IRVisitorBase { ...@@ -63,16 +63,19 @@ struct IRVisitorBase {
/** /**
* Base of all the Ir readonly visitor. * Base of all the Ir readonly visitor.
*/ */
struct IRVisitor : public IRVisitorBase<void> { struct IRVisitor : public IRVisitorRequireReImpl<void> {
IRVisitor() = default; IRVisitor() = default;
void Visit(const Expr* x) { IRVisitorBase::Visit(x); } void Visit(const Expr* x) { IRVisitorRequireReImpl::Visit(x); }
#define __m(t__) \ #define __m(t__) \
virtual void Visit(const t__* x) {} virtual void Visit(const t__* x) { return VisitDefault(x); }
NODETY_FORALL(__m) NODETY_FORALL(__m)
#undef __m #undef __m
};
virtual void VisitDefault(const Object* obj) {
LOG(FATAL) << "not supported NodeTy";
}
};
// std::set<Expr> CollectIRNodes(Expr expr, std::function<bool(const Expr*)> // std::set<Expr> CollectIRNodes(Expr expr, std::function<bool(const Expr*)>
// teller); // teller);
......
...@@ -32,7 +32,7 @@ namespace { ...@@ -32,7 +32,7 @@ namespace {
struct StoreDebugInfoBuilder : public ir::IRVisitor { struct StoreDebugInfoBuilder : public ir::IRVisitor {
std::tuple<std::string, std::vector<Expr>> operator()(const Expr *e) { std::tuple<std::string, std::vector<Expr>> operator()(const Expr *e) {
ir::IRVisitor::Visit(e); IRVisitor::Visit(e);
return std::make_tuple(format_.str(), args_); return std::make_tuple(format_.str(), args_);
} }
...@@ -40,9 +40,9 @@ struct StoreDebugInfoBuilder : public ir::IRVisitor { ...@@ -40,9 +40,9 @@ struct StoreDebugInfoBuilder : public ir::IRVisitor {
#define _BINARY_OP(Op__, repr__) \ #define _BINARY_OP(Op__, repr__) \
void Visit(const ir::Op__ *x) override { \ void Visit(const ir::Op__ *x) override { \
format_ << "("; \ format_ << "("; \
ir::IRVisitor::Visit(&x->a()); \ IRVisitor::Visit(&x->a()); \
format_ << " " << #repr__ << " "; \ format_ << " " << #repr__ << " "; \
ir::IRVisitor::Visit(&x->b()); \ IRVisitor::Visit(&x->b()); \
format_ << ")"; \ format_ << ")"; \
} }
_BINARY_OP(Add, +); _BINARY_OP(Add, +);
......
...@@ -30,12 +30,14 @@ namespace cinn { ...@@ -30,12 +30,14 @@ namespace cinn {
namespace optim { namespace optim {
using namespace ir; // NOLINT using namespace ir; // NOLINT
struct IRCopyVisitor : public ir::IRVisitorBase<Expr> { struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
// Use maps to unify all the copied tensors and buffers. // Use maps to unify all the copied tensors and buffers.
std::map<std::string, ir::_Tensor_*> tensor_map; std::map<std::string, ir::_Tensor_*> tensor_map;
std::map<std::string, ir::_Buffer_*> buffer_map; std::map<std::string, ir::_Buffer_*> buffer_map;
Expr Visit(const Expr* op) override { return IRVisitorBase::Visit(op); } Expr Visit(const Expr* op) override {
return IRVisitorRequireReImpl::Visit(op);
}
protected: protected:
// The methods of ir nodes follows the order defined in node.h // The methods of ir nodes follows the order defined in node.h
...@@ -421,8 +423,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> { ...@@ -421,8 +423,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
#define OP_BINARY_HANDLE(op__) \ #define OP_BINARY_HANDLE(op__) \
Expr Visit(const ir::op__* op) override { \ Expr Visit(const ir::op__* op) override { \
auto a = IRVisitorBase::Visit(&op->a()); \ auto a = IRVisitorRequireReImpl::Visit(&op->a()); \
auto b = IRVisitorBase::Visit(&op->b()); \ auto b = IRVisitorRequireReImpl::Visit(&op->b()); \
return op__::Make(a, b); \ return op__::Make(a, b); \
} }
NODETY_BINARY_OP_FOR_EACH(OP_BINARY_HANDLE) NODETY_BINARY_OP_FOR_EACH(OP_BINARY_HANDLE)
...@@ -430,7 +432,7 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> { ...@@ -430,7 +432,7 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
#define OP_UNARY_HANDLE(op__) \ #define OP_UNARY_HANDLE(op__) \
Expr Visit(const op__* op) override { \ Expr Visit(const op__* op) override { \
auto v = IRVisitorBase::Visit(&op->v()); \ auto v = IRVisitorRequireReImpl::Visit(&op->v()); \
return op__::Make(v); \ return op__::Make(v); \
} }
NODETY_UNARY_OP_FOR_EACH(OP_UNARY_HANDLE) NODETY_UNARY_OP_FOR_EACH(OP_UNARY_HANDLE)
......
...@@ -256,6 +256,7 @@ void BindNode(py::module *m) { ...@@ -256,6 +256,7 @@ void BindNode(py::module *m) {
}); });
} }
// empty visitor
void BindIrVisitor(py::module *m) { void BindIrVisitor(py::module *m) {
py::class_<ir::IRVisitor> ir_visitor(*m, "IRVisitor"); py::class_<ir::IRVisitor> ir_visitor(*m, "IRVisitor");
ir_visitor.def(py::init<>()) ir_visitor.def(py::init<>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册