diff --git a/src/pass/to_three_address.cc b/src/pass/to_three_address.cc index 670fa6a47ab37c0f360f720ae688f2e60fb3efed..b01fd0696df83c9e764f70771a9094a10d5950e7 100644 --- a/src/pass/to_three_address.cc +++ b/src/pass/to_three_address.cc @@ -35,22 +35,71 @@ using VarSet = std::unordered_set; // forward declaration class ThreeAddressExprMutator; -class ThreeAddressFilter : public IRVisitor { +class ExprArgsFetcher : public IRVisitor { public: - bool Find(const Stmt &s) { - Visit(s); - return need_; + explicit ExprArgsFetcher(Array args) : args_(args), index_(args_.size() - 1) {} + ~ExprArgsFetcher() override = default; + + Array GetArgs(const Expr &e) { + Visit(e); + if (max_dim >= args_.size()) { + return args_; + } + Array args; + while (index_ < args_.size()) { + args.push_back(args_[index_]); + index_++; + } + if (CountVars(args) == CountVars(args_)) { + return args_; + } + return args; + } + + bool MustBroadcast(const Expr &e) { + if (is_constant(e) || CountVars(e) == 0) { + return false; + } + size_t size = GetArgs(e).size(); + return size > max_dim; } void Visit_(const Call *op) override { - if (op->name == "load3d_l1_ub") { - need_ = false; + if (op->call_type == Call::CallType::Halide) { + max_dim = max_dim < op->args.size() ? op->args.size() : max_dim; + for (Expr arg : op->args) { + size_t index = GetIndex(arg); + index_ = index_ > index ? index : index_; + } + } else { + CHECK(op->call_type == Call::CallType::PureIntrinsic); + for (Expr e : op->args) { + Array args = GetArgs(e); + max_dim = max_dim < args.size() ? args.size() : max_dim; + for (Expr arg : args) { + size_t index = GetIndex(arg); + index_ = index_ > index ? index : index_; + } + } } - IRVisitor::Visit_(op); } private: - bool need_{true}; + size_t GetIndex(const Expr &arg) { + if (is_constant(arg)) { + return index_; + } + for (size_t i = 0; i < args_.size(); ++i) { + if (args_[i].same_as(arg)) { + return i; + } + } + return index_; + } + + Array args_; + size_t index_; + size_t max_dim{0}; }; class ScalarOperandFinder : public IRVisitor { @@ -188,50 +237,21 @@ std::unordered_set GetExprTensors(const Expr expr) { return tensors; } -// Replace all instances of a Tensor "from" in an Expr with a new one "to" -class ReplaceProvideTensors : public IRMutator { - public: - ReplaceProvideTensors(const Tensor &from, const Operation &to) : from_(from->op), to_(to) {} - ~ReplaceProvideTensors() override = default; - - Stmt Mutate_(const Provide *op, const Stmt &s) final { - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - CHECK(op); - if (op->func == from_) { - stmt = Provide::make(to_, op->value_index, op->value, op->args); - } - return stmt; - } - - Expr Mutate_(const Call *op, const Expr &e) override { - Expr expr = IRMutator::Mutate_(op, e); - const Call *n = expr.as(); - CHECK(n); - if (n->func == from_) { - expr = Call::make(n->type, to_->name, n->args, n->call_type, to_, n->value_index); - } - return expr; - } - - private: - const Operation from_; - const Operation to_; -}; - // Mutate expression according to selection choices class ThreeAddressExprMutator : public IRMutator { public: - ThreeAddressExprMutator(const Tensor output, const Array &args, const Array &shape, - const std::unordered_set &broadcast, bool IsReductionOp, - bool cross_stmt_simplify) + ThreeAddressExprMutator(const Tensor output, const Array &args, const Array &out_args, + const Array &shape, const std::unordered_set &broadcast, + bool IsReductionOp, bool cross_stmt_simplify, bool is_simple = false) : output_(output), args_(args), + out_args_(out_args), shape_(shape), broadcast_(broadcast), IsReductionOp_(IsReductionOp), cross_simplify_(cross_stmt_simplify), - hasher_(cross_stmt_simplify) { + hasher_(cross_stmt_simplify), + is_simple_(is_simple) { CHECK_EQ(args_.size(), shape_.size()); if (shape_.empty()) { // scalar values should have at least one dimension and contains one element shape_.push_back(1); @@ -246,7 +266,7 @@ class ThreeAddressExprMutator : public IRMutator { common_exprs_.insert(global_common_expr.begin(), global_common_expr.end()); } - Expr AllocateTmp(Expr value) { + Expr AllocateTmp(Expr value, Array args = {}) { // detect common expression size_t hash_value = hasher_(value); auto x = common_exprs_[hash_value]; @@ -263,13 +283,17 @@ class ThreeAddressExprMutator : public IRMutator { // allocate new immediate tensor Tensor imm; - imm = PlaceholderOpNode::make(output_->op->name + "_" + std::to_string(ct_++), shape_, value.type()).output(0); + if (args.empty()) { + args = args_; + } + std::string name = output_->op->name + "_" + std::to_string(ct_++); + imm = PlaceholderOpNode::make(name, GetShape(args), value.type()).output(0); imm_tensors.push_back(imm); imm_ops.insert(imm->op); // update common expr - assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args_)); - Expr ret = Call::make(value.type(), imm->op->name, args_, Call::CallType::Halide, imm->op, imm->value_index); + assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args)); + Expr ret = Call::make(value.type(), imm->op->name, args, Call::CallType::Halide, imm->op, imm->value_index); common_exprs_[hash_value] = std::make_pair(value, ret); imm2hash_[imm->op] = hash_value; return ret; @@ -283,9 +307,13 @@ class ThreeAddressExprMutator : public IRMutator { common_exprs_.erase(old_hash); // update new common expr - assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args_)); + Array args = args_; + if (is_simple_) { + args = ExprArgsFetcher(args_).GetArgs(value); + } + assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args)); size_t hash_value = hasher_(value); - Expr ret = Call::make(value.type(), imm->op->name, args_, Call::CallType::Halide, imm->op, imm->value_index); + Expr ret = Call::make(value.type(), imm->op->name, args, Call::CallType::Halide, imm->op, imm->value_index); common_exprs_[hash_value] = std::make_pair(value, ret); imm2hash_[imm->op] = hash_value; return ret; @@ -324,20 +352,25 @@ class ThreeAddressExprMutator : public IRMutator { Expr r = Mutate(op->b); in_call_--; - bool broadcast_l = !IsReductionOp_ && !is_constant(l) && CountVars(args_) > CountVars(l); - bool broadcast_r = !IsReductionOp_ && !is_constant(r) && CountVars(args_) > CountVars(r); + Array args = args_; + if (is_simple_) { + args = ExprArgsFetcher(args_).GetArgs(T::make(l, r)); + } + bool broadcast_l = !IsReductionOp_ && !is_constant(l) && CountVars(args) > CountVars(l); + bool broadcast_r = !IsReductionOp_ && !is_constant(r) && CountVars(args) > CountVars(r); if (op->template IsInstance() || op->template IsInstance()) { - if (broadcast_l && broadcast_r) { - l = AllocateTmp(l); - } else if (is_constant(r) && broadcast_l) { - l = AllocateTmp(l); - } else if (is_constant(l) && broadcast_r) { - r = AllocateTmp(r); + if (broadcast_l && (broadcast_r || is_constant(r))) { + l = AllocateTmp(l, args); + } else if ((broadcast_r && is_constant(l))) { + r = AllocateTmp(r, args); + } + if (CountVars(args) > CountVars(r) && ExprArgsFetcher(out_args_).MustBroadcast(r)) { + r = AllocateTmp(r, args); } } - return AllocateTmp(T::make(Mutate(l), Mutate(r))); + return AllocateTmp(T::make(Mutate(l), Mutate(r)), args); } Expr Mutate_(const Add *op, const Expr &e) final { return MutateBinaryOp(op, e); } @@ -443,7 +476,14 @@ class ThreeAddressExprMutator : public IRMutator { // broadcast when need if (broadcast_.count(op) && broadcast) { - return AllocateTmp(e); + if (expr_stack.size() >= 2 && expr_stack[expr_stack.size() - 2]->IsInstance
()) { + Array args = ExprArgsFetcher(args_).GetArgs(expr_stack[expr_stack.size() - 2]); + if (CountVars(e) < CountVars(args)) { + return AllocateTmp(e, args); + } + } else { + return AllocateTmp(e); + } } // this is going to generate a tensor of tensor expr, like A(B(i)) return e; @@ -546,8 +586,25 @@ class ThreeAddressExprMutator : public IRMutator { } } + Array GetShape(const Array &args) { + if (CountVars(args) == CountVars(args_)) { + return shape_; + } + const size_t dim = args.size(); + const size_t maxDim = output_->shape.size(); + CHECK_LE(dim, maxDim); + Array shape; + size_t index = maxDim - dim; + while (index < maxDim) { + shape.push_back(output_->shape[index]); + index++; + } + return shape; + } + Tensor output_; Array args_; + Array out_args_; Array shape_; std::unordered_map> common_exprs_; // hash value -> @@ -566,6 +623,7 @@ class ThreeAddressExprMutator : public IRMutator { bool IsReductionOp_{false}; bool cross_simplify_; ExprHasher hasher_; + bool is_simple_; }; Expr ThreeAddressExprMutator::Mutate(Expr expr) { @@ -579,66 +637,78 @@ Expr ThreeAddressExprMutator::Mutate(Expr expr) { int ThreeAddressExprMutator::ct_ = 0; -class InstructionSelector { +class InstructionMutator : IRMutator { public: - InstructionSelector(ThreeAddressExprMutator &mutator, std::list &exprs, - std::unordered_map ¬ation_map, - std::unordered_map &sign_map) - : mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {} - ~InstructionSelector() = default; + explicit InstructionMutator(ThreeAddressExprMutator &mutator, Array &args) : mutator_(mutator), args_(args) {} + ~InstructionMutator() = default; - Expr Mutate(Expr expr) { - if (const Mul *op = expr.as()) { - return Mutate_(op, expr); - } - if (const Cast *op = expr.as()) { - return Mutate_(op, expr); + Expr Mutate(Expr value) { return IRMutator::Mutate(value); } + + // VMADD.type {f16, f32} [Xd], [Xn], [Xm], Xt, MASK + // [Xd] = [Xn] * [Xd] + [Xm] + // VAXPY.type {f16, f32, fmix} [Xd], [Xn], Xm, Xt, MASK + // [Xd] = Xm * [Xn] + [Xd] + Expr Mutate_(const Add *op, const Expr &e) { + Expr l = Mutate(op->a); + Expr r = Mutate(op->b); + if (is_constant(l) && is_constant(r)) { + return ConstantFold(l, r); } - if (const Select *op = expr.as