未验证 提交 ed6b9567 编写于 作者: 傅剑寒 提交者: GitHub

【CINN】refactor ir_visitor (#55171)

This PR delete middle ir_visitor class and thus we can avoid middle virtual function call and codes look more clean
pcard-72718
上级 4543ca91
......@@ -45,7 +45,9 @@ using namespace ::cinn::ir; // NOLINT
FeatureExtractor::FeatureExtractor() {}
void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
void FeatureExtractor::Visit(const Expr *x) {
IRVisitorRequireReImpl::Visit(x);
}
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) {
......
......@@ -37,7 +37,7 @@
namespace cinn {
namespace auto_schedule {
class FeatureExtractor : public ir::IRVisitor {
class FeatureExtractor : public ir::IRVisitorRequireReImpl<void> {
public:
FeatureExtractor();
Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target);
......
......@@ -71,7 +71,7 @@ bool operator<(const SearchState& left, const SearchState& right) {
}
// Visit every node by expanding all of their fields in dfs order
class DfsWithExprsFields : public ir::IRVisitor {
class DfsWithExprsFields : public ir::IRVisitorRequireReImpl<void> {
protected:
#define __m(t__) \
void Visit(const ir::t__* x) override { \
......@@ -85,7 +85,7 @@ class DfsWithExprsFields : public ir::IRVisitor {
NODETY_FORALL(__m)
#undef __m
void Visit(const Expr* expr) override { IRVisitor::Visit(expr); }
void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); }
};
// Generate a reduce hash of a AST tree by combining hash of each AST node
......
......@@ -39,11 +39,11 @@
namespace cinn {
namespace backends {
class LLVMIRVisitor : public ir::IRVisitorBase<llvm::Value *> {
class LLVMIRVisitor : public ir::IRVisitorRequireReImpl<llvm::Value *> {
public:
LLVMIRVisitor() = default;
using ir::IRVisitorBase<llvm::Value *>::Visit;
using ir::IRVisitorRequireReImpl<llvm::Value *>::Visit;
#define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0;
NODETY_FORALL(__m)
#undef __m
......
......@@ -19,13 +19,13 @@
namespace cinn {
namespace backends {
class ModularEvaluator : public ir::IRVisitorBase<ModularEntry> {
class ModularEvaluator : public ir::IRVisitorRequireReImpl<ModularEntry> {
public:
explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map)
: mod_map_(mod_map) {}
ModularEntry Eval(const Expr& e) {
return ir::IRVisitorBase<ModularEntry>::Visit(&e);
return ir::IRVisitorRequireReImpl<ModularEntry>::Visit(&e);
}
ModularEntry Visit(const ir::IntImm* op) {
......
......@@ -24,7 +24,7 @@ namespace ir {
namespace {
struct IrNodesCollector : public IRVisitor {
struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
using teller_t = std::function<bool(const Expr*)>;
using handler_t = std::function<void(const Expr*)>;
......
......@@ -32,7 +32,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
VLOG(5) << "Not equal on Expr, someone not defined";
}
bool equal = lhs->node_type() == rhs->node_type();
equal = equal && IRVisitorBase<bool, const Expr*>::Visit(&lhs, &rhs);
equal = equal && IRVisitorRequireReImpl<bool, const Expr*>::Visit(&lhs, &rhs);
if (!equal) {
VLOG(5) << "Not equal on Expr, lhs:[type:"
......
......@@ -23,7 +23,7 @@ namespace ir {
// Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorBase<bool, const Expr*> {
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {}
......
......@@ -26,7 +26,7 @@ namespace ir {
//! T might be Expr* or const Expr*
template <typename T = Expr *>
class IRMutator : public IRVisitorBase<void, T> {
class IRMutator : public IRVisitorRequireReImpl<void, T> {
public:
void Visit(const Expr *expr, T op) override;
......@@ -37,22 +37,22 @@ class IRMutator : public IRVisitorBase<void, T> {
template <typename T>
void IRMutator<T>::Visit(const Expr *expr, T op) {
IRVisitorBase<void, T>::Visit(expr, op);
IRVisitorRequireReImpl<void, T>::Visit(expr, op);
}
#define UNARY_OP_IMPL(op__) \
template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \
IRVisitorBase<void, T>::Visit(&node->v(), &node->v()); \
#define UNARY_OP_IMPL(op__) \
template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \
IRVisitorRequireReImpl<void, T>::Visit(&node->v(), &node->v()); \
}
#define BINARY_OP_IMPL(op__) \
template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \
IRVisitorBase<void, T>::Visit(&node->a(), &node->a()); \
IRVisitorBase<void, T>::Visit(&node->b(), &node->b()); \
#define BINARY_OP_IMPL(op__) \
template <typename T> \
void IRMutator<T>::Visit(const op__ *expr, T op) { \
auto *node = op->template As<op__>(); \
IRVisitorRequireReImpl<void, T>::Visit(&node->a(), &node->a()); \
IRVisitorRequireReImpl<void, T>::Visit(&node->b(), &node->b()); \
}
NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL)
......@@ -77,172 +77,181 @@ void IRMutator<T>::Visit(const Cast *expr, T op) {
template <typename T>
void IRMutator<T>::Visit(const For *expr, T op) {
auto *node = op->template As<For>();
IRVisitorBase<void, T>::Visit(&node->min, &node->min);
IRVisitorBase<void, T>::Visit(&node->extent, &node->extent);
IRVisitorBase<void, T>::Visit(&node->body, &node->body);
IRVisitorRequireReImpl<void, T>::Visit(&node->min, &node->min);
IRVisitorRequireReImpl<void, T>::Visit(&node->extent, &node->extent);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
}
template <typename T>
void IRMutator<T>::Visit(const PolyFor *expr, T op) {
auto *node = op->template As<PolyFor>();
// IRVisitorBase<void,T>::Visit(&node->iterator, &node->iterator);
IRVisitorBase<void, T>::Visit(&node->init, &node->init);
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->inc, &node->inc);
IRVisitorBase<void, T>::Visit(&node->body, &node->body);
// IRVisitorRequireReImpl<void,T>::Visit(&node->iterator,
// &node->iterator);
IRVisitorRequireReImpl<void, T>::Visit(&node->init, &node->init);
IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorRequireReImpl<void, T>::Visit(&node->inc, &node->inc);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
}
template <typename T>
void IRMutator<T>::Visit(const Select *expr, T op) {
auto *node = op->template As<Select>();
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->true_value, &node->true_value);
IRVisitorBase<void, T>::Visit(&node->false_value, &node->false_value);
IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorRequireReImpl<void, T>::Visit(&node->true_value, &node->true_value);
IRVisitorRequireReImpl<void, T>::Visit(&node->false_value,
&node->false_value);
}
template <typename T>
void IRMutator<T>::Visit(const IfThenElse *expr, T op) {
auto *node = op->template As<IfThenElse>();
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->true_case, &node->true_case);
IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
IRVisitorRequireReImpl<void, T>::Visit(&node->true_case, &node->true_case);
if (node->false_case.defined())
IRVisitorBase<void, T>::Visit(&node->false_case, &node->false_case);
IRVisitorRequireReImpl<void, T>::Visit(&node->false_case,
&node->false_case);
}
template <typename T>
void IRMutator<T>::Visit(const Block *expr, T op) {
auto *node = op->template As<Block>();
for (auto &expr : node->stmts) {
IRVisitorBase<void, T>::Visit(&expr, &expr);
IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
}
}
template <typename T>
void IRMutator<T>::Visit(const Call *expr, T op) {
auto *node = op->template As<Call>();
for (auto &expr : node->read_args) {
IRVisitorBase<void, T>::Visit(&expr, &expr);
IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
}
for (auto &expr : node->write_args) {
IRVisitorBase<void, T>::Visit(&expr, &expr);
IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
}
}
template <typename T>
void IRMutator<T>::Visit(const _Module_ *expr, T op) {
auto *node = op->template As<_Module_>();
for (auto &func : node->functions) {
IRVisitorBase<void, T>::Visit(&func, &func);
IRVisitorRequireReImpl<void, T>::Visit(&func, &func);
}
for (auto &func : node->buffers) {
IRVisitorBase<void, T>::Visit(&func, &func);
IRVisitorRequireReImpl<void, T>::Visit(&func, &func);
}
for (auto &expr : node->submodules) {
IRVisitorBase<void, T>::Visit(&expr, &expr);
IRVisitorRequireReImpl<void, T>::Visit(&expr, &expr);
}
}
template <typename T>
void IRMutator<T>::Visit(const _Var_ *expr, T op) {
auto *node = op->template As<ir::_Var_>();
if (node->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&node->lower_bound, &node->lower_bound);
IRVisitorRequireReImpl<void, T>::Visit(&node->lower_bound,
&node->lower_bound);
}
if (node->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&node->upper_bound, &node->upper_bound);
IRVisitorRequireReImpl<void, T>::Visit(&node->upper_bound,
&node->upper_bound);
}
}
template <typename T>
void IRMutator<T>::Visit(const Load *expr, T op) {
auto *node = op->template As<Load>();
for (auto &idx : node->indices) IRVisitorBase<void, T>::Visit(&idx, &idx);
IRVisitorBase<void, T>::Visit(&node->tensor, &node->tensor);
for (auto &idx : node->indices)
IRVisitorRequireReImpl<void, T>::Visit(&idx, &idx);
IRVisitorRequireReImpl<void, T>::Visit(&node->tensor, &node->tensor);
}
template <typename T>
void IRMutator<T>::Visit(const Store *expr, T op) {
auto *node = op->template As<Store>();
IRVisitorBase<void, T>::Visit(&node->value, &node->value);
IRVisitorBase<void, T>::Visit(&node->tensor, &node->tensor);
for (auto &idx : node->indices) IRVisitorBase<void, T>::Visit(&idx, &idx);
IRVisitorRequireReImpl<void, T>::Visit(&node->value, &node->value);
IRVisitorRequireReImpl<void, T>::Visit(&node->tensor, &node->tensor);
for (auto &idx : node->indices)
IRVisitorRequireReImpl<void, T>::Visit(&idx, &idx);
}
template <typename T>
void IRMutator<T>::Visit(const Alloc *expr, T op) {
auto *node = op->template As<Alloc>();
for (auto &e : node->extents) {
IRVisitorBase<void, T>::Visit(&e, &e);
IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
}
if (node->condition.defined())
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition);
IRVisitorRequireReImpl<void, T>::Visit(&node->condition, &node->condition);
if (node->body.defined()) {
Expr body(node->body);
IRVisitorBase<void, T>::Visit(&node->body, &body);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &body);
}
}
template <typename T>
void IRMutator<T>::Visit(const Free *expr, T op) {
auto *node = op->template As<Free>();
IRVisitorBase<void, T>::Visit(&node->destination, &node->destination);
IRVisitorRequireReImpl<void, T>::Visit(&node->destination,
&node->destination);
}
template <typename T>
void IRMutator<T>::Visit(const _Buffer_ *expr, T op) {
auto *node = op->template As<_Buffer_>();
for (auto &e : node->shape) {
IRVisitorBase<void, T>::Visit(&e, &e);
IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
}
for (auto &e : node->strides) {
IRVisitorBase<void, T>::Visit(&e, &e);
IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
}
IRVisitorBase<void, T>::Visit(&node->elem_offset, &node->elem_offset);
IRVisitorRequireReImpl<void, T>::Visit(&node->elem_offset,
&node->elem_offset);
}
template <typename T>
void IRMutator<T>::Visit(const _Tensor_ *expr, T op) {
auto *node = op->template As<_Tensor_>();
for (auto &e : node->shape) {
IRVisitorBase<void, T>::Visit(&e, &e);
IRVisitorRequireReImpl<void, T>::Visit(&e, &e);
}
}
template <typename T>
void IRMutator<T>::Visit(const _LoweredFunc_ *expr, T op) {
auto *node = op->template As<_LoweredFunc_>();
IRVisitorBase<void, T>::Visit(&node->body, &node->body);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
}
template <typename T>
void IRMutator<T>::Visit(const Let *expr, T op) {
auto *node = op->template As<Let>();
IRVisitorBase<void, T>::Visit(&node->symbol, &node->symbol);
IRVisitorRequireReImpl<void, T>::Visit(&node->symbol, &node->symbol);
if (node->body.defined())
IRVisitorBase<void, T>::Visit(&node->body, &node->body);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
}
template <typename T>
void IRMutator<T>::Visit(const Reduce *expr, T op) {
auto *node = op->template As<Reduce>();
if (node->init.defined())
IRVisitorBase<void, T>::Visit(&node->init, &node->init);
IRVisitorRequireReImpl<void, T>::Visit(&node->init, &node->init);
CHECK(node->body.defined());
IRVisitorBase<void, T>::Visit(&node->body, &node->body);
IRVisitorRequireReImpl<void, T>::Visit(&node->body, &node->body);
}
template <typename T>
void IRMutator<T>::Visit(const Ramp *expr, T op) {
auto *node = op->template As<Ramp>();
IRVisitorBase<void, T>::Visit(&node->base, &node->base);
IRVisitorBase<void, T>::Visit(&node->stride, &node->stride);
IRVisitorRequireReImpl<void, T>::Visit(&node->base, &node->base);
IRVisitorRequireReImpl<void, T>::Visit(&node->stride, &node->stride);
}
template <typename T>
void IRMutator<T>::Visit(const Broadcast *expr, T op) {
auto *node = op->template As<Broadcast>();
IRVisitorBase<void, T>::Visit(&node->value, &node->value);
IRVisitorRequireReImpl<void, T>::Visit(&node->value, &node->value);
}
template <typename T>
void IRMutator<T>::Visit(const FracOp *expr, T op) {
auto *node = op->template As<FracOp>();
IRVisitorBase<void, T>::Visit(&node->a(), &node->a());
IRVisitorBase<void, T>::Visit(&node->b(), &node->b());
IRVisitorRequireReImpl<void, T>::Visit(&node->a(), &node->a());
IRVisitorRequireReImpl<void, T>::Visit(&node->b(), &node->b());
}
template <typename T>
void IRMutator<T>::Visit(const Product *expr, T op) {
auto *node = op->template As<Product>();
for (auto &x : node->operands()) {
IRVisitorBase<void, T>::Visit(&x, &x);
IRVisitorRequireReImpl<void, T>::Visit(&x, &x);
}
}
......@@ -250,7 +259,7 @@ template <typename T>
void IRMutator<T>::Visit(const Sum *expr, T op) {
auto *node = op->template As<Sum>();
for (auto &x : node->operands()) {
IRVisitorBase<void, T>::Visit(&x, &x);
IRVisitorRequireReImpl<void, T>::Visit(&x, &x);
}
}
template <typename T>
......@@ -258,7 +267,7 @@ void IRMutator<T>::Visit(const PrimitiveNode *expr, T op) {
auto *node = op->template As<PrimitiveNode>();
for (auto &args : node->arguments) {
for (auto &arg : args) {
IRVisitorBase<void, T>::Visit(&arg, &arg);
IRVisitorRequireReImpl<void, T>::Visit(&arg, &arg);
}
}
}
......@@ -292,13 +301,15 @@ template <typename T>
void IRMutator<T>::Visit(const _BufferRange_ *expr, T op) {
auto *node = op->template As<_BufferRange_>();
CHECK(node);
IRVisitorBase<void, T>::Visit(&node->buffer, &node->buffer);
IRVisitorRequireReImpl<void, T>::Visit(&node->buffer, &node->buffer);
for (auto &var : node->ranges) {
if (var->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->lower_bound, &var->lower_bound);
IRVisitorRequireReImpl<void, T>::Visit(&var->lower_bound,
&var->lower_bound);
}
if (var->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->upper_bound, &var->upper_bound);
IRVisitorRequireReImpl<void, T>::Visit(&var->upper_bound,
&var->upper_bound);
}
}
}
......@@ -309,19 +320,21 @@ void IRMutator<T>::Visit(const ScheduleBlock *expr, T op) {
CHECK(node);
for (auto &var : node->iter_vars) {
if (var->lower_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->lower_bound, &var->lower_bound);
IRVisitorRequireReImpl<void, T>::Visit(&var->lower_bound,
&var->lower_bound);
}
if (var->upper_bound.defined()) {
IRVisitorBase<void, T>::Visit(&var->upper_bound, &var->upper_bound);
IRVisitorRequireReImpl<void, T>::Visit(&var->upper_bound,
&var->upper_bound);
}
}
for (auto &buffer_region : node->read_buffers) {
IRVisitorBase<void, T>::Visit(&buffer_region, &buffer_region);
IRVisitorRequireReImpl<void, T>::Visit(&buffer_region, &buffer_region);
}
for (auto &buffer_region : node->write_buffers) {
IRVisitorBase<void, T>::Visit(&buffer_region, &buffer_region);
IRVisitorRequireReImpl<void, T>::Visit(&buffer_region, &buffer_region);
}
IRVisitorBase<void, T>::Visit(&(node->body), &(node->body));
IRVisitorRequireReImpl<void, T>::Visit(&(node->body), &(node->body));
}
template <typename T>
......@@ -329,9 +342,10 @@ void IRMutator<T>::Visit(const ScheduleBlockRealize *expr, T op) {
auto *node = op->template As<ScheduleBlockRealize>();
CHECK(node);
for (auto &value : node->iter_values) {
IRVisitorBase<void, T>::Visit(&value, &value);
IRVisitorRequireReImpl<void, T>::Visit(&value, &value);
}
IRVisitorBase<void, T>::Visit(&node->schedule_block, &node->schedule_block);
IRVisitorRequireReImpl<void, T>::Visit(&node->schedule_block,
&node->schedule_block);
}
} // namespace ir
......
......@@ -32,7 +32,7 @@ namespace ir {
using common::bfloat16;
using common::float16;
void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); }
void IrPrinter::Print(Expr e) { IRVisitorRequireReImpl::Visit(&e); }
void IrPrinter::Print(const std::vector<Expr> &exprs,
const std::string &splitter) {
for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) {
......
......@@ -29,7 +29,7 @@ class LoweredFunc;
namespace ir {
class Module;
struct IrPrinter : public IRVisitor {
struct IrPrinter : public IRVisitorRequireReImpl<void> {
explicit IrPrinter(std::ostream &os) : os_(os) {}
//! Emit an expression on the output stream.
......
......@@ -34,7 +34,8 @@ struct _Tensor_;
* @param Args type of the extra arguments passed to the all the methods.
*/
template <typename RetTy = void, typename... Args>
struct IRVisitorBase {
class IRVisitorRequireReImpl {
public:
//! Visit a expression.
// @{
virtual RetTy Visit(const ir::Expr* expr, Args... args) {
......@@ -53,7 +54,6 @@ struct IRVisitorBase {
return RetTy();
}
// @}
protected:
#define __(op__) virtual RetTy Visit(const ir::op__* op, Args... args) = 0;
NODETY_FORALL(__)
......@@ -63,16 +63,19 @@ struct IRVisitorBase {
/**
* Base of all the Ir readonly visitor.
*/
struct IRVisitor : public IRVisitorBase<void> {
struct IRVisitor : public IRVisitorRequireReImpl<void> {
IRVisitor() = default;
void Visit(const Expr* x) { IRVisitorBase::Visit(x); }
void Visit(const Expr* x) { IRVisitorRequireReImpl::Visit(x); }
#define __m(t__) \
virtual void Visit(const t__* x) {}
virtual void Visit(const t__* x) { return VisitDefault(x); }
NODETY_FORALL(__m)
#undef __m
};
virtual void VisitDefault(const Object* obj) {
LOG(FATAL) << "not supported NodeTy";
}
};
// std::set<Expr> CollectIRNodes(Expr expr, std::function<bool(const Expr*)>
// teller);
......
......@@ -32,7 +32,7 @@ namespace {
struct StoreDebugInfoBuilder : public ir::IRVisitor {
std::tuple<std::string, std::vector<Expr>> operator()(const Expr *e) {
ir::IRVisitor::Visit(e);
IRVisitor::Visit(e);
return std::make_tuple(format_.str(), args_);
}
......@@ -40,9 +40,9 @@ struct StoreDebugInfoBuilder : public ir::IRVisitor {
#define _BINARY_OP(Op__, repr__) \
void Visit(const ir::Op__ *x) override { \
format_ << "("; \
ir::IRVisitor::Visit(&x->a()); \
IRVisitor::Visit(&x->a()); \
format_ << " " << #repr__ << " "; \
ir::IRVisitor::Visit(&x->b()); \
IRVisitor::Visit(&x->b()); \
format_ << ")"; \
}
_BINARY_OP(Add, +);
......
......@@ -30,12 +30,14 @@ namespace cinn {
namespace optim {
using namespace ir; // NOLINT
struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
// Use maps to unify all the copied tensors and buffers.
std::map<std::string, ir::_Tensor_*> tensor_map;
std::map<std::string, ir::_Buffer_*> buffer_map;
Expr Visit(const Expr* op) override { return IRVisitorBase::Visit(op); }
Expr Visit(const Expr* op) override {
return IRVisitorRequireReImpl::Visit(op);
}
protected:
// The methods of ir nodes follows the order defined in node.h
......@@ -419,19 +421,19 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
}
}
#define OP_BINARY_HANDLE(op__) \
Expr Visit(const ir::op__* op) override { \
auto a = IRVisitorBase::Visit(&op->a()); \
auto b = IRVisitorBase::Visit(&op->b()); \
return op__::Make(a, b); \
#define OP_BINARY_HANDLE(op__) \
Expr Visit(const ir::op__* op) override { \
auto a = IRVisitorRequireReImpl::Visit(&op->a()); \
auto b = IRVisitorRequireReImpl::Visit(&op->b()); \
return op__::Make(a, b); \
}
NODETY_BINARY_OP_FOR_EACH(OP_BINARY_HANDLE)
#undef OP_BINARY_HANDLE
#define OP_UNARY_HANDLE(op__) \
Expr Visit(const op__* op) override { \
auto v = IRVisitorBase::Visit(&op->v()); \
return op__::Make(v); \
#define OP_UNARY_HANDLE(op__) \
Expr Visit(const op__* op) override { \
auto v = IRVisitorRequireReImpl::Visit(&op->v()); \
return op__::Make(v); \
}
NODETY_UNARY_OP_FOR_EACH(OP_UNARY_HANDLE)
#undef OP_UNARY_HANDLE
......
......@@ -256,6 +256,7 @@ void BindNode(py::module *m) {
});
}
// empty visitor
void BindIrVisitor(py::module *m) {
py::class_<ir::IRVisitor> ir_visitor(*m, "IRVisitor");
ir_visitor.def(py::init<>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册