提交 fae7628a 编写于 作者: Z zhaiyukun

Add Optimization for three address

  1.Arithmetic priority adjustment
  2.Instruction selection
上级 6335daad
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <arithmetic/pattern_match.h>
#include <dmlc/common.h> #include <dmlc/common.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
...@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>; ...@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration // forward declaration
class ThreeAddressExprMutator; class ThreeAddressExprMutator;
struct ExpressionPattern {
int min_level; // minimal level
std::function<int(Expr)> score_func; // assign score to a subtree
std::function<Expr(Expr, ThreeAddressExprMutator &)> replace_func; // replace a subtree with this instruction
};
class ThreeAddressFilter : public IRVisitor { class ThreeAddressFilter : public IRVisitor {
public: public:
bool Find(const Stmt &s) { bool Find(const Stmt &s) {
...@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator {
// forward declaration // forward declaration
Expr Mutate(Expr expr) override; Expr Mutate(Expr expr) override;
// do naive three address translation without instruction selection
Expr MutateWithoutSelection(const Expr expr) {
disable_selection_ = true;
Expr ret = Mutate(expr);
disable_selection_ = false;
return ret;
}
template <typename T> template <typename T>
Expr MutateBinaryOp(const T *op, const Expr &e) { Expr MutateBinaryOp(const T *op, const Expr &e) {
in_call_++; in_call_++;
...@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator {
Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); } Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); }
Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); } Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); }
void AddBroadCastCallIfNeed(const Call *op, const Expr &e) {
if (broadcast_.find(op) == broadcast_.end()) {
return;
}
const Call *new_call = e.as<Call>();
CHECK_NOTNULL(new_call);
broadcast_.insert(new_call);
}
std::vector<Stmt> assign_stmt; std::vector<Stmt> assign_stmt;
std::vector<Tensor> imm_tensors; std::vector<Tensor> imm_tensors;
std::unordered_set<FunctionRef, air::NodeHash, air::NodeEqual> imm_ops; std::unordered_set<FunctionRef, air::NodeHash, air::NodeEqual> imm_ops;
...@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator {
Array<Expr> shape_; Array<Expr> shape_;
std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr> std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr>
std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual> // imm tensor -> hash value of the expr in the tensor
imm2hash_; // imm tensor -> hash value of the expr in the tensor std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual> imm2hash_;
int level_{0}; int level_{0};
int in_call_{0}; int in_call_{0};
...@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator {
ExprHasher hasher_; ExprHasher hasher_;
}; };
Expr CallPureIntrinsic(const std::string &name, const Array<Expr> &args, const Type type) { Expr ThreeAddressExprMutator::Mutate(Expr expr) {
return Call::make(type, name, args, Call::CallType::PureIntrinsic); level_++;
expr_stack.push_back(expr);
Expr ret = IRMutator::Mutate(expr);
expr_stack.pop_back();
level_--;
return ret;
} }
// Match instructions by dynamic programming on the tree int ThreeAddressExprMutator::ct_ = 0;
class InstructionMatcher {
class InstructionSelector {
public: public:
void Match(const Expr value) { InstructionSelector(ThreeAddressExprMutator &mutator, std::list<Expr> &exprs,
int max_score = -1; std::unordered_map<const Object *, std::string> &notation_map,
int max_i = -1; std::unordered_map<const Object *, bool> &sign_map)
: mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {}
~InstructionSelector() = default;
// try patterns Expr Mutate(Expr expr) {
for (size_t i = 0; i < ins_pattern.size(); ++i) { if (const Mul *op = expr.as<Mul>()) {
int score_ = ins_pattern[i].score_func(value); return Mutate_(op, expr);
if (score_ > max_score) {
max_score = score_;
max_i = static_cast<int>(i);
} }
if (const Cast *op = expr.as<Cast>()) {
return Mutate_(op, expr);
} }
if (const Select *op = expr.as<Select>()) {
score = max_score; return Mutate_(op, expr);
choice = max_i; }
return expr;
} }
int score;
int choice;
const int NORMAL = 20;
const int PRIOR = 50;
const int UNMATCH = -1;
air::arith::PVar<Expr> x, y, z, w;
air::arith::PVar<Type> pt;
air::arith::PVar<Floating> c1, c2;
std::vector<ExpressionPattern> ins_pattern{
// vmadd [Xd] = [Xn] * [Xd] + [Xm] // vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vmla [Xd] = [Xn] * [Xm] + [Xd] // vaxpy [Xd] = Xm * [Xn] + [Xd]
ExpressionPattern{ Expr Mutate_(const Mul *op, const Expr &e) {
2, std::string root = notation_map_.at(e.get());
[&, this](const Expr &expr) -> int { if (root != Add::_type_key && root != Sub::_type_key) {
if (((x * y + z).Match(expr) || (z + x * y).Match(expr)) && return e;
(!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) { }
return PRIOR; bool is_left_constant = is_constant(op->a);
} bool is_right_constant = is_constant(op->b);
return UNMATCH; if (is_left_constant && is_right_constant) {
}, return e;
[&, this](const Expr &expr, ThreeAddressExprMutator &mutator) -> Expr { }
CHECK(((x * y + z)).Match(expr) || (z + x * y).Match(expr)); Expr expr = GetIndexOfPairExprForMul(e);
if (expr.same_as(e)) {
Expr x_eval = mutator.Mutate(x.Eval()); return e;
Expr y_eval = mutator.Mutate(y.Eval()); }
Expr z_eval = mutator.Mutate(z.Eval()); Array<Expr> args;
// make sure elemwise inside if (!is_left_constant) {
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) { args.push_back(op->a);
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
if (mutator.IsTmpTensor(x_eval)) {
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmadd", {y_eval, z_eval, x_eval}, x_eval.type()));
} else if (mutator.IsTmpTensor(y_eval)) {
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmadd", {x_eval, z_eval, y_eval}, y_eval.type()));
} else if (mutator.IsTmpTensor(z_eval)) {
return mutator.AssignTmp(z_eval, CallPureIntrinsic("vmla", {x_eval, y_eval, z_eval}, z_eval.type()));
} else { } else {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval); args.push_back(op->b);
} }
}}, args.push_back(expr);
if (!is_right_constant) {
// vmaddrelu [Xd] = max([Xn] * [Xd] + [Xm], 0) args.push_back(op->b);
ExpressionPattern{
2,
[&, this](const Expr expr) -> int {
if (((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) || (max(c1, x * y + z)).Match(expr) ||
(max(c1, z + x * y)).Match(expr)) &&
c1.Eval()->value == 0.0 && (!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) ||
(max(c1, x * y + z)).Match(expr) || (max(c1, z + x * y)).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
Expr z_eval = mutator.Mutate(z.Eval());
// check elemwise
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
if (mutator.IsTmpTensor(x_eval) || x_eval.same_as(x.Eval())) {
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmaddrelu", {y_eval, z_eval, x_eval}, x_eval.type()));
} else if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) {
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmaddrelu", {x_eval, z_eval, y_eval}, y_eval.type()));
} else { } else {
return mutator.MutateWithoutSelection(max(x_eval * y_eval + z_eval, c1.Eval())); args.push_back(op->a);
}
return Call::make(op->type, !is_left_constant && !is_right_constant ? "vmadd" : "vaxpy", args,
Call::CallType::PureIntrinsic);
} }
}},
// vaxpy [Xd] = Xm * [Xn] + [Xd] // vrelu [Xd] = max([Xn], 0)
ExpressionPattern{ // vmaddrelu [Xd] = max(vmadd [Xd], 0)
2, Expr Mutate_(const Max *op, const Expr &e) {
[&, this](const Expr expr) -> int { bool is_left_zero = isZero(op->a);
if (((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) || bool is_right_zero = IsZero(op->b);
(y + c1 * x).Match(expr)) && if (!is_left_zero && !is_right_zero) {
(!is_constant(x.Eval()) && !is_constant(y.Eval()))) { return e;
return PRIOR; }
} Expr expr = op->a;
return UNMATCH; if (is_left_zero) {
}, expr = op->b;
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr { }
CHECK((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) ||
(y + c1 * x).Match(expr)); if (const Call *call = expr.as<Call>()) {
Expr x_eval = mutator.Mutate(x.Eval()); if (call->call_type == Call::CallType::PureIntrinsic && call->name == "vmadd") {
Expr y_eval = mutator.Mutate(y.Eval()); return Call::make(op->type, "vmaddrelu", call->args, Call::CallType::PureIntrinsic);
// check elemwise }
if (CountVars(x_eval) != CountVars(y_eval)) { }
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval); return Call::make(op->type, "relu", {expr}, Call::CallType::PureIntrinsic);
} }
if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) { // int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vaxpy", {x_eval, y_eval, c1.Eval()}, y_eval.type())); // float(cc1) -> a[i] = cc1; cast(a[i])
Expr Mutate_(const Cast *op, const Expr &e) {
if (op->type.is_int() && op->value->IsInstance<Call>()) {
const Call *call = op->value.as<Call>();
if (call->name != "floor" && call->name != "ceil" && call->name != "round" && call->name != "trunc") {
return e;
}
if (op->type == call->type) {
return op->value;
} else { } else {
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval); return Call::make(op->type, call->name, call->args, call->call_type, call->func, call->value_index);
}
}
if (op->type.is_float() && op->value->IsInstance<Variable>()) {
return Cast::make(op->type, mutator_.AllocateTmp(op->value));
}
return e;
} }
}},
// vrelu [Xd] = max([Xn], 0) Expr Mutate_(const Select *op, const Expr &e) {
ExpressionPattern{1, if (const Not *notCond = op->condition.as<Not>()) {
[&, this](const Expr expr) -> int { return Select::make(notCond->a, op->false_value, op->true_value);
if (((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr)) && c1.Eval()->value == 0.0 && }
!is_constant(x.Eval()) && x.Eval().type() == Float(16, 1)) { if (const And *andCond = op->condition.as<And>()) {
return NORMAL; Expr tmpExpr = Select::make(andCond->a, op->true_value, op->false_value);
} return Select::make(andCond->b, tmpExpr, op->false_value);
return UNMATCH; }
}, if (const Or *orCond = op->condition.as<Or>()) {
Expr tmpExpr = Select::make(orCond->a, op->true_value, op->false_value);
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr { return Select::make(orCond->b, op->true_value, tmpExpr);
CHECK(((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr))); }
Expr x_eval = mutator.Mutate(x.Eval()); return e;
return mutator.Mutate(CallPureIntrinsic("relu", {x_eval}, x_eval.type())); }
}},
private:
// adds [Xd] = ([Xn] + [Yn]) + imm -> [Xn] + ([Yn] + imm) Expr GetIndexOfPairExprForMul(const Expr &expr) {
ExpressionPattern{1, Expr ret_expr = expr;
[&, this](const Expr expr) -> int { bool pos = sign_map_.at(expr.get());
if ((((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr) || ((x + y) + c1).Match(expr) || int dim = CountVars(expr);
(c1 + (x + y)).Match(expr)) && for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
!is_constant(x.Eval()) && !is_constant(y.Eval())) { if ((sign_map_.at((*iter).get()) != pos) || is_constant(*iter) || (iter->same_as(expr))) {
return NORMAL; continue;
} }
return UNMATCH; if (CountVars(*iter) > dim) {
}, continue;
}
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { ret_expr = *iter;
if (((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr)) { exprs_.remove_if([&ret_expr](Expr e) { return e.same_as(ret_expr); });
Expr x_eval = mutator.Mutate(x.Eval()); break;
Expr y_eval = mutator.Mutate(y.Eval()); }
return mutator.Mutate(x_eval + (c1.Eval() - y_eval)); return ret_expr;
} }
if (((x + y) + c1).Match(expr) || (c1 + (x + y)).Match(expr)) {
Expr x_eval = mutator.Mutate(x.Eval()); ThreeAddressExprMutator &mutator_;
Expr y_eval = mutator.Mutate(y.Eval()); std::list<Expr> &exprs_;
return mutator.Mutate(x_eval + (y_eval + c1.Eval())); std::unordered_map<const Object *, std::string> &notation_map_;
std::unordered_map<const Object *, bool> &sign_map_;
};
class ExprOptMutator : public IRMutator {
public:
ExprOptMutator(ThreeAddressExprMutator &mutator) : mutator_(mutator) {}
~ExprOptMutator() override = default;
Expr Mutate(Expr expr) {
expr = IRMutator::Mutate(expr);
exprs_.sort([](Expr &e1, Expr &e2) -> bool {
int dim1 = CountVars(e1);
int dim2 = CountVars(e2);
if (dim1 == dim2) {
return !e1->IsInstance<Mul>();
}
return dim1 < dim2;
});
InstructionSelector selector(mutator_, exprs_, notation_map_, sign_map_);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
*iter = selector.Mutate(*iter);
} }
expr = RebuildExpr();
return expr; return expr;
}}, }
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc() Expr Mutate_(const Select *op, const Expr &e) {
ExpressionPattern{ InitExprStatusIfNeed(e);
1, Expr expr = Select::make(op->condition, ExprOptMutator(mutator_).Mutate(op->true_value),
[&, this](const Expr expr) -> int { ExprOptMutator(mutator_).Mutate(op->false_value));
if (((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) || exprs_.push_back(expr);
((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) || return expr;
((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) || }
((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int())) {
return NORMAL; Expr Mutate_(const Add *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
}
return UNMATCH; Expr Mutate_(const Sub *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
},
Expr Mutate_(const Mul *op, const Expr &e) {
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { bool is_left_constant = is_constant(op->a);
if ((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) { bool is_right_constant = is_constant(op->b);
Expr x_eval = mutator.Mutate(x.Eval()); if ((is_left_constant && is_left_constant) || (!is_left_constant && !is_right_constant)) {
return mutator.Mutate(Call::make(expr.type(), "floor", {x_eval}, Call::CallType::PureIntrinsic)); return AnalyzeBinaryOpExpr(op, e);
} }
if ((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) { Expr non_constant_expr = is_left_constant ? op->b : op->a;
Expr x_eval = mutator.Mutate(x.Eval()); Expr constant_expr = is_left_constant ? op->a : op->b;
return mutator.Mutate(Call::make(expr.type(), "ceil", {x_eval}, Call::CallType::PureIntrinsic));
} if (non_constant_expr->IsInstance<Add>()) {
if ((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) { const Add *add = non_constant_expr.as<Add>();
Expr x_eval = mutator.Mutate(x.Eval()); if (is_constant(add->a) || is_constant(add->b)) {
return mutator.Mutate(Call::make(expr.type(), "round", {x_eval}, Call::CallType::PureIntrinsic)); Expr expr = Add::make(Mul::make(constant_expr, add->a), Mul::make(constant_expr, add->b));
} if (notation_map_.find(e.get()) == notation_map_.end()) {
if ((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int()) { notation_map_[expr.get()] = notation_map_[e.get()];
Expr x_eval = mutator.Mutate(x.Eval()); }
return mutator.Mutate(Call::make(expr.type(), "trunc", {x_eval}, Call::CallType::PureIntrinsic)); if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
}
return IRMutator::Mutate(expr);
}
}
if (non_constant_expr->IsInstance<Sub>()) {
const Sub *sub = non_constant_expr.as<Sub>();
if (is_constant(sub->a) || is_constant(sub->b)) {
Expr expr = Sub::make(Mul::make(constant_expr, sub->a), Mul::make(constant_expr, sub->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
}
if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
} }
return IRMutator::Mutate(expr);
}
}
return AnalyzeBinaryOpExpr(op, e);
}
// Imm / x -> y = Imm; y/x
Expr Mutate_(const Div *op, const Expr &e) {
if (is_constant(op->a) && !is_constant(op->b)) {
Expr expr = Div::make(mutator_.AllocateTmp(op->a), op->b);
const Div *div = expr.as<Div>();
return AnalyzeBinaryOpExpr(div, expr);
}
return AnalyzeBinaryOpExpr(op, e);
}
Expr Mutate_(const Mod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const FloorDiv *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const FloorMod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Min *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Max *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const EQ *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const NE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const LT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const LE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const GT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const GE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const And *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Or *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Let *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Let::make(op->var, ExprOptMutator(mutator_).Mutate(op->value), ExprOptMutator(mutator_).Mutate(op->body));
exprs_.push_back(expr);
return expr; return expr;
}}, }
// float(cc1) -> a[i] = cc1; cast(a[i]) Expr Mutate_(const Cast *op, const Expr &e) {
ExpressionPattern{1, InitExprStatusIfNeed(e);
[&, this](const Expr expr) -> int { Expr expr = Cast::make(op->type, ExprOptMutator(mutator_).Mutate(op->value));
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) { exprs_.push_back(expr);
return NORMAL; return expr;
}
Expr Mutate_(const Not *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Not::make(ExprOptMutator(mutator_).Mutate(op->a));
exprs_.push_back(expr);
return expr;
} }
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { Expr Mutate_(const Load *op, const Expr &e) {
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) { InitExprStatusIfNeed(e);
Expr tmp = mutator.AllocateTmp(x.Eval()); Expr expr = Load::make(op->type, op->buffer_var, ExprOptMutator(mutator_).Mutate(op->index),
return mutator.Mutate(Cast::make(expr.type(), tmp)); ExprOptMutator(mutator_).Mutate(op->predicate));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Reduce *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> source;
for (Expr src : op->source) {
source.push_back(ExprOptMutator(mutator_).Mutate(src));
} }
Expr expr =
Reduce::make(op->combiner, source, op->axis, ExprOptMutator(mutator_).Mutate(op->condition), op->value_index);
exprs_.push_back(expr);
return expr; return expr;
}}, }
// Imm / x -> y = Imm; y/x Expr Mutate_(const Shuffle *op, const Expr &e) {
ExpressionPattern{1, InitExprStatusIfNeed(e);
[&, this](const Expr expr) -> int { Array<Expr> vectors;
if (div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval())) { for (Expr v : op->vectors) {
return NORMAL; vectors.push_back(ExprOptMutator(mutator_).Mutate(v));
}
Array<Expr> indices;
for (Expr indic : op->indices) {
indices.push_back(ExprOptMutator(mutator_).Mutate(indic));
}
Expr expr = Shuffle::make(vectors, indices);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Call *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> args;
for (Expr arg : op->args) {
args.push_back(ExprOptMutator(mutator_).Mutate(arg));
}
Expr expr = Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index);
mutator_.AddBroadCastCallIfNeed(op, expr);
exprs_.push_back(expr);
return expr;
} }
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr { Expr Mutate_(const Ramp *op, const Expr &e) {
CHECK(div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval())); InitExprStatusIfNeed(e);
Expr x_eval = mutator.AllocateTmp(c1.Eval()); Expr expr =
return mutator.Mutate(Div::make(x_eval, y.Eval())); Ramp::make(ExprOptMutator(mutator_).Mutate(op->base), ExprOptMutator(mutator_).Mutate(op->stride), op->lanes);
}}, exprs_.push_back(expr);
return expr;
}
ExpressionPattern{1, Expr Mutate_(const Broadcast *op, const Expr &e) {
[&, this](const Expr expr) -> int { InitExprStatusIfNeed(e);
if ((c1 * (c2 + x)).Match(expr) || (c1 * (c2 - x)).Match(expr)) { Expr expr = Broadcast::make(ExprOptMutator(mutator_).Mutate(op->value), op->lanes);
return NORMAL; exprs_.push_back(expr);
return expr;
} }
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { Expr Mutate_(const IntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
if ((c1 * (c2 + x)).Match(expr)) {
return mutator.Mutate(Simplify_cce(x.Eval() * c1.Eval() + c1.Eval() * c2.Eval())); Expr Mutate_(const UIntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const FloatImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const StringImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const Variable *op, const Expr &e) { return SaveAutomicExpr(e); }
private:
void InitExprStatusIfNeed(const Expr &e) {
const Object *object_e = e.get();
if (notation_map_.find(object_e) == notation_map_.end()) {
notation_map_[object_e] = e->GetTypeKey();
} }
if ((c1 * (c2 - x)).Match(expr)) { if (sign_map_.find(object_e) == sign_map_.end()) {
return mutator.Mutate(Simplify_cce(c1.Eval() * c2.Eval() - x.Eval() * c1.Eval())); sign_map_[object_e] = true;
}
}
bool IsNewRoot(const Expr &e) {
CHECK(notation_map_.find(e.get()) != notation_map_.end());
std::string root = notation_map_[e.get()];
std::string type_key = e->GetTypeKey();
return !((root == Add::_type_key || root == Sub::_type_key) &&
(type_key == Add::_type_key || type_key == Sub::_type_key)) ||
!((root == Mul::_type_key || root == Div::_type_key) &&
(type_key == Mul::_type_key || type_key == Div::_type_key));
}
template <typename T>
Expr AnalyzeBinaryOpExpr(const T *op, const Expr &e) {
InitExprStatusIfNeed(e);
const Object *object_e = e.get();
std::string root_of_e = notation_map_[object_e];
bool pos_of_e = sign_map_[object_e];
std::string type_key = e->GetTypeKey();
Expr expr = e;
if (IsNewRoot(e)) {
expr = T::make(ExprOptMutator(mutator_).Mutate(op->a), ExprOptMutator(mutator_).Mutate(op->b));
notation_map_[expr.get()] = root_of_e;
sign_map_[expr.get()] = pos_of_e;
exprs_.push_back(expr);
} else {
notation_map_[op->a.get()] = root_of_e;
notation_map_[op->b.get()] = root_of_e;
sign_map_[op->a.get()] = pos_of_e;
sign_map_[op->b.get()] = (type_key == Sub::_type_key || type_key == Div::_type_key) ? !pos_of_e : pos_of_e;
expr = T::make(IRMutator::Mutate(op->a), IRMutator::Mutate(op->b));
}
return expr;
}
Expr SaveAutomicExpr(const Expr &e) {
InitExprStatusIfNeed(e);
exprs_.push_back(e);
return e;
}
Expr RebuildExpr() {
CHECK(!exprs_.empty());
Expr expr = exprs_.front();
exprs_.pop_front();
while (!exprs_.empty()) {
expr = RebuildExpr(expr, exprs_.front());
exprs_.pop_front();
} }
return expr; return expr;
}},
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if ((select((z || w), x, y)).Match(expr) || (select((z && w), x, y)).Match(expr) ||
(select((!z), x, y)).Match(expr)) {
return NORMAL;
} }
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { Expr RebuildExpr(const Expr &expr1, const Expr &expr2) {
if ((select((z || w), x, y)).Match(expr)) { Expr expr = expr1;
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval())); Expr opnd = expr2;
return mutator.Mutate(Select::make(w.Eval(), x.Eval(), temp_eval)); if (sign_map_[expr2.get()] && !sign_map_[expr1.get()]) {
expr = expr2;
opnd = expr1;
} }
if ((select((z && w), x, y)).Match(expr)) {
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval())); if ((sign_map_[expr1.get()] && sign_map_[expr2.get()]) || (!sign_map_[expr1.get()] && !sign_map_[expr2.get()])) {
return mutator.Mutate(Select::make(w.Eval(), temp_eval, y.Eval())); if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Add::make(expr, opnd);
} else {
expr = Mul::make(expr, opnd);
}
} else {
if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Sub::make(expr, opnd);
} else {
expr = Div::make(expr, opnd);
} }
if ((select((!z), x, y)).Match(expr)) {
return mutator.Mutate(Select::make(z.Eval(), y.Eval(), x.Eval()));
} }
notation_map_[expr.get()] = notation_map_[expr1.get()];
sign_map_[expr.get()] = sign_map_[expr1.get()] || sign_map_[expr2.get()];
return expr; return expr;
}}}; }
ThreeAddressExprMutator &mutator_;
std::list<Expr> exprs_;
std::unordered_map<const Object *, std::string> notation_map_;
std::unordered_map<const Object *, bool> sign_map_;
}; };
Expr ThreeAddressExprMutator::Mutate(Expr expr) { class LoopMutator : public IRMutator {
// select instructions public:
InstructionMatcher matcher; LoopMutator() : loop_level_(0) {}
matcher.Match(expr); ~LoopMutator() override = default;
int idx = matcher.choice;
Expr ret; Stmt Mutate_(const For *op, const Stmt &s) final {
level_++; loop_level_++;
if (idx < 0 || disable_selection_ || level_ < matcher.ins_pattern[idx].min_level) { loop_vars_.push_front(op);
expr_stack.push_back(expr); Stmt stmt = IRMutator::Mutate(op->body);
ret = IRMutator::Mutate(expr); if (provides_.size() == 1 || provides_.front()->args.size() == provides_.front()->args.size()) {
expr_stack.pop_back(); return s;
} else { // match an intrinsic }
ret = matcher.ins_pattern[idx].replace_func(expr, *this); if (!stmt->IsInstance<For>()) {
provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); });
const Provide *provide = provides_.back();
stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args);
provides_.pop_back();
}
stmt = RebuildForStmt(loop_vars_.back(), stmt);
loop_vars_.pop_back();
loop_level_--;
return stmt;
} }
level_--;
return ret;
}
int ThreeAddressExprMutator::ct_ = 0; Stmt Mutate_(const Provide *op, const Stmt &s) final {
if (!loop_vars_.empty()) {
provides_.push_back(op);
}
return IRMutator::Mutate_(op, s);
}
private:
Stmt RebuildForStmt(const For *op, Stmt &body) {
Stmt stmt = body;
while (!provides_.empty() && op->loop_var.same_as(provides_.back()->args[0])) {
const Provide *second = provides_.back();
stmt = Block::make(Provide::make(second->func, second->value_index, second->value, second->args), stmt);
provides_.pop_back();
}
return For::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, stmt);
}
size_t loop_level_;
std::list<const Provide *> provides_;
std::list<const For *> loop_vars_;
};
class InferUpperBound { class InferUpperBound {
private: private:
...@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator { ...@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator {
// Bring over the common exprs from previous stage // Bring over the common exprs from previous stage
mutator.SetCommonExpr(global_common_expr_); mutator.SetCommonExpr(global_common_expr_);
} }
value = ExprOptMutator(mutator).Mutate(value);
value = mutator.Mutate(value); value = mutator.Mutate(value);
if (cross_stmt_simplify_) { if (cross_stmt_simplify_) {
// Take back the common exprs for next stages // Take back the common exprs for next stages
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册