diff --git a/src/pass/to_three_address.cc b/src/pass/to_three_address.cc index c63f7a72be5005465c7b2ce498b958ec03ff2c06..670fa6a47ab37c0f360f720ae688f2e60fb3efed 100644 --- a/src/pass/to_three_address.cc +++ b/src/pass/to_three_address.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include #include #include @@ -36,12 +35,6 @@ using VarSet = std::unordered_set; // forward declaration class ThreeAddressExprMutator; -struct ExpressionPattern { - int min_level; // minimal level - std::function score_func; // assign score to a subtree - std::function replace_func; // replace a subtree with this instruction -}; - class ThreeAddressFilter : public IRVisitor { public: bool Find(const Stmt &s) { @@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator { // forward declaration Expr Mutate(Expr expr) override; - // do naive three address translation without instruction selection - Expr MutateWithoutSelection(const Expr expr) { - disable_selection_ = true; - Expr ret = Mutate(expr); - disable_selection_ = false; - return ret; - } - template Expr MutateBinaryOp(const T *op, const Expr &e) { in_call_++; @@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator { Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); } Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); } + void AddBroadCastCallIfNeed(const Call *op, const Expr &e) { + if (broadcast_.find(op) == broadcast_.end()) { + return; + } + const Call *new_call = e.as(); + CHECK_NOTNULL(new_call); + broadcast_.insert(new_call); + } + std::vector assign_stmt; std::vector imm_tensors; std::unordered_set imm_ops; @@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator { Array shape_; std::unordered_map> common_exprs_; // hash value -> - std::unordered_map - imm2hash_; // imm tensor -> hash value of the expr in the tensor + // imm tensor -> hash value of the expr in the tensor + std::unordered_map imm2hash_; int level_{0}; int in_call_{0}; @@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator { ExprHasher hasher_; }; -Expr CallPureIntrinsic(const std::string &name, const Array &args, const Type type) { - return Call::make(type, name, args, Call::CallType::PureIntrinsic); +Expr ThreeAddressExprMutator::Mutate(Expr expr) { + level_++; + expr_stack.push_back(expr); + Expr ret = IRMutator::Mutate(expr); + expr_stack.pop_back(); + level_--; + return ret; } -// Match instructions by dynamic programming on the tree -class InstructionMatcher { +int ThreeAddressExprMutator::ct_ = 0; + +class InstructionSelector { public: - void Match(const Expr value) { - int max_score = -1; - int max_i = -1; - - // try patterns - for (size_t i = 0; i < ins_pattern.size(); ++i) { - int score_ = ins_pattern[i].score_func(value); - if (score_ > max_score) { - max_score = score_; - max_i = static_cast(i); + InstructionSelector(ThreeAddressExprMutator &mutator, std::list &exprs, + std::unordered_map ¬ation_map, + std::unordered_map &sign_map) + : mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {} + ~InstructionSelector() = default; + + Expr Mutate(Expr expr) { + if (const Mul *op = expr.as()) { + return Mutate_(op, expr); + } + if (const Cast *op = expr.as()) { + return Mutate_(op, expr); + } + if (const Select *op = expr.as