提交 fae7628a 编写于 作者: Z zhaiyukun

Add Optimization for three address

  1.Arithmetic priority adjustment
  2.Instruction selection
上级 6335daad
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <arithmetic/pattern_match.h>
#include <dmlc/common.h> #include <dmlc/common.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
...@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>; ...@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration // forward declaration
class ThreeAddressExprMutator; class ThreeAddressExprMutator;
struct ExpressionPattern {
int min_level; // minimal level
std::function<int(Expr)> score_func; // assign score to a subtree
std::function<Expr(Expr, ThreeAddressExprMutator &)> replace_func; // replace a subtree with this instruction
};
class ThreeAddressFilter : public IRVisitor { class ThreeAddressFilter : public IRVisitor {
public: public:
bool Find(const Stmt &s) { bool Find(const Stmt &s) {
...@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator {
// forward declaration // forward declaration
Expr Mutate(Expr expr) override; Expr Mutate(Expr expr) override;
// do naive three address translation without instruction selection
Expr MutateWithoutSelection(const Expr expr) {
disable_selection_ = true;
Expr ret = Mutate(expr);
disable_selection_ = false;
return ret;
}
template <typename T> template <typename T>
Expr MutateBinaryOp(const T *op, const Expr &e) { Expr MutateBinaryOp(const T *op, const Expr &e) {
in_call_++; in_call_++;
...@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator {
Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); } Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); }
Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); } Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); }
void AddBroadCastCallIfNeed(const Call *op, const Expr &e) {
if (broadcast_.find(op) == broadcast_.end()) {
return;
}
const Call *new_call = e.as<Call>();
CHECK_NOTNULL(new_call);
broadcast_.insert(new_call);
}
std::vector<Stmt> assign_stmt; std::vector<Stmt> assign_stmt;
std::vector<Tensor> imm_tensors; std::vector<Tensor> imm_tensors;
std::unordered_set<FunctionRef, air::NodeHash, air::NodeEqual> imm_ops; std::unordered_set<FunctionRef, air::NodeHash, air::NodeEqual> imm_ops;
...@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator {
Array<Expr> shape_; Array<Expr> shape_;
std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr> std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr>
std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual> // imm tensor -> hash value of the expr in the tensor
imm2hash_; // imm tensor -> hash value of the expr in the tensor std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual> imm2hash_;
int level_{0}; int level_{0};
int in_call_{0}; int in_call_{0};
...@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator { ...@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator {
ExprHasher hasher_; ExprHasher hasher_;
}; };
Expr CallPureIntrinsic(const std::string &name, const Array<Expr> &args, const Type type) { Expr ThreeAddressExprMutator::Mutate(Expr expr) {
return Call::make(type, name, args, Call::CallType::PureIntrinsic); level_++;
expr_stack.push_back(expr);
Expr ret = IRMutator::Mutate(expr);
expr_stack.pop_back();
level_--;
return ret;
} }
// Match instructions by dynamic programming on the tree int ThreeAddressExprMutator::ct_ = 0;
class InstructionMatcher {
class InstructionSelector {
public: public:
void Match(const Expr value) { InstructionSelector(ThreeAddressExprMutator &mutator, std::list<Expr> &exprs,
int max_score = -1; std::unordered_map<const Object *, std::string> &notation_map,
int max_i = -1; std::unordered_map<const Object *, bool> &sign_map)
: mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {}
// try patterns ~InstructionSelector() = default;
for (size_t i = 0; i < ins_pattern.size(); ++i) {
int score_ = ins_pattern[i].score_func(value); Expr Mutate(Expr expr) {
if (score_ > max_score) { if (const Mul *op = expr.as<Mul>()) {
max_score = score_; return Mutate_(op, expr);
max_i = static_cast<int>(i); }
if (const Cast *op = expr.as<Cast>()) {
return Mutate_(op, expr);
}
if (const Select *op = expr.as<Select>()) {
return Mutate_(op, expr);
}
return expr;
}
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vaxpy [Xd] = Xm * [Xn] + [Xd]
Expr Mutate_(const Mul *op, const Expr &e) {
std::string root = notation_map_.at(e.get());
if (root != Add::_type_key && root != Sub::_type_key) {
return e;
}
bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if (is_left_constant && is_right_constant) {
return e;
}
Expr expr = GetIndexOfPairExprForMul(e);
if (expr.same_as(e)) {
return e;
}
Array<Expr> args;
if (!is_left_constant) {
args.push_back(op->a);
} else {
args.push_back(op->b);
}
args.push_back(expr);
if (!is_right_constant) {
args.push_back(op->b);
} else {
args.push_back(op->a);
}
return Call::make(op->type, !is_left_constant && !is_right_constant ? "vmadd" : "vaxpy", args,
Call::CallType::PureIntrinsic);
}
// vrelu [Xd] = max([Xn], 0)
// vmaddrelu [Xd] = max(vmadd [Xd], 0)
Expr Mutate_(const Max *op, const Expr &e) {
bool is_left_zero = isZero(op->a);
bool is_right_zero = IsZero(op->b);
if (!is_left_zero && !is_right_zero) {
return e;
}
Expr expr = op->a;
if (is_left_zero) {
expr = op->b;
}
if (const Call *call = expr.as<Call>()) {
if (call->call_type == Call::CallType::PureIntrinsic && call->name == "vmadd") {
return Call::make(op->type, "vmaddrelu", call->args, Call::CallType::PureIntrinsic);
} }
} }
return Call::make(op->type, "relu", {expr}, Call::CallType::PureIntrinsic);
}
score = max_score; // int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
choice = max_i; // float(cc1) -> a[i] = cc1; cast(a[i])
} Expr Mutate_(const Cast *op, const Expr &e) {
if (op->type.is_int() && op->value->IsInstance<Call>()) {
int score; const Call *call = op->value.as<Call>();
int choice; if (call->name != "floor" && call->name != "ceil" && call->name != "round" && call->name != "trunc") {
const int NORMAL = 20; return e;
const int PRIOR = 50; }
const int UNMATCH = -1; if (op->type == call->type) {
air::arith::PVar<Expr> x, y, z, w; return op->value;
air::arith::PVar<Type> pt; } else {
air::arith::PVar<Floating> c1, c2; return Call::make(op->type, call->name, call->args, call->call_type, call->func, call->value_index);
}
std::vector<ExpressionPattern> ins_pattern{ }
// vmadd [Xd] = [Xn] * [Xd] + [Xm] if (op->type.is_float() && op->value->IsInstance<Variable>()) {
// vmla [Xd] = [Xn] * [Xm] + [Xd] return Cast::make(op->type, mutator_.AllocateTmp(op->value));
ExpressionPattern{ }
2, return e;
[&, this](const Expr &expr) -> int { }
if (((x * y + z).Match(expr) || (z + x * y).Match(expr)) &&
(!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr &expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK(((x * y + z)).Match(expr) || (z + x * y).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
Expr z_eval = mutator.Mutate(z.Eval());
// make sure elemwise inside
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
if (mutator.IsTmpTensor(x_eval)) { Expr Mutate_(const Select *op, const Expr &e) {
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmadd", {y_eval, z_eval, x_eval}, x_eval.type())); if (const Not *notCond = op->condition.as<Not>()) {
} else if (mutator.IsTmpTensor(y_eval)) { return Select::make(notCond->a, op->false_value, op->true_value);
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmadd", {x_eval, z_eval, y_eval}, y_eval.type())); }
} else if (mutator.IsTmpTensor(z_eval)) { if (const And *andCond = op->condition.as<And>()) {
return mutator.AssignTmp(z_eval, CallPureIntrinsic("vmla", {x_eval, y_eval, z_eval}, z_eval.type())); Expr tmpExpr = Select::make(andCond->a, op->true_value, op->false_value);
} else { return Select::make(andCond->b, tmpExpr, op->false_value);
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval); }
} if (const Or *orCond = op->condition.as<Or>()) {
}}, Expr tmpExpr = Select::make(orCond->a, op->true_value, op->false_value);
return Select::make(orCond->b, op->true_value, tmpExpr);
// vmaddrelu [Xd] = max([Xn] * [Xd] + [Xm], 0) }
ExpressionPattern{ return e;
2, }
[&, this](const Expr expr) -> int {
if (((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) || (max(c1, x * y + z)).Match(expr) ||
(max(c1, z + x * y)).Match(expr)) &&
c1.Eval()->value == 0.0 && (!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) ||
(max(c1, x * y + z)).Match(expr) || (max(c1, z + x * y)).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
Expr z_eval = mutator.Mutate(z.Eval());
// check elemwise
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
if (mutator.IsTmpTensor(x_eval) || x_eval.same_as(x.Eval())) { private:
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmaddrelu", {y_eval, z_eval, x_eval}, x_eval.type())); Expr GetIndexOfPairExprForMul(const Expr &expr) {
} else if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) { Expr ret_expr = expr;
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmaddrelu", {x_eval, z_eval, y_eval}, y_eval.type())); bool pos = sign_map_.at(expr.get());
} else { int dim = CountVars(expr);
return mutator.MutateWithoutSelection(max(x_eval * y_eval + z_eval, c1.Eval())); for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
} if ((sign_map_.at((*iter).get()) != pos) || is_constant(*iter) || (iter->same_as(expr))) {
}}, continue;
}
// vaxpy [Xd] = Xm * [Xn] + [Xd] if (CountVars(*iter) > dim) {
ExpressionPattern{ continue;
2, }
[&, this](const Expr expr) -> int { ret_expr = *iter;
if (((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) || exprs_.remove_if([&ret_expr](Expr e) { return e.same_as(ret_expr); });
(y + c1 * x).Match(expr)) && break;
(!is_constant(x.Eval()) && !is_constant(y.Eval()))) { }
return PRIOR; return ret_expr;
} }
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) ||
(y + c1 * x).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
// check elemwise
if (CountVars(x_eval) != CountVars(y_eval)) {
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval);
}
if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) { ThreeAddressExprMutator &mutator_;
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vaxpy", {x_eval, y_eval, c1.Eval()}, y_eval.type())); std::list<Expr> &exprs_;
} else { std::unordered_map<const Object *, std::string> &notation_map_;
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval); std::unordered_map<const Object *, bool> &sign_map_;
} };
}},
class ExprOptMutator : public IRMutator {
// vrelu [Xd] = max([Xn], 0) public:
ExpressionPattern{1, ExprOptMutator(ThreeAddressExprMutator &mutator) : mutator_(mutator) {}
[&, this](const Expr expr) -> int { ~ExprOptMutator() override = default;
if (((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr)) && c1.Eval()->value == 0.0 &&
!is_constant(x.Eval()) && x.Eval().type() == Float(16, 1)) { Expr Mutate(Expr expr) {
return NORMAL; expr = IRMutator::Mutate(expr);
} exprs_.sort([](Expr &e1, Expr &e2) -> bool {
return UNMATCH; int dim1 = CountVars(e1);
}, int dim2 = CountVars(e2);
if (dim1 == dim2) {
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr { return !e1->IsInstance<Mul>();
CHECK(((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr))); }
Expr x_eval = mutator.Mutate(x.Eval()); return dim1 < dim2;
return mutator.Mutate(CallPureIntrinsic("relu", {x_eval}, x_eval.type())); });
}}, InstructionSelector selector(mutator_, exprs_, notation_map_, sign_map_);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
// adds [Xd] = ([Xn] + [Yn]) + imm -> [Xn] + ([Yn] + imm) *iter = selector.Mutate(*iter);
ExpressionPattern{1, }
[&, this](const Expr expr) -> int { expr = RebuildExpr();
if ((((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr) || ((x + y) + c1).Match(expr) || return expr;
(c1 + (x + y)).Match(expr)) && }
!is_constant(x.Eval()) && !is_constant(y.Eval())) {
return NORMAL; Expr Mutate_(const Select *op, const Expr &e) {
} InitExprStatusIfNeed(e);
return UNMATCH; Expr expr = Select::make(op->condition, ExprOptMutator(mutator_).Mutate(op->true_value),
}, ExprOptMutator(mutator_).Mutate(op->false_value));
exprs_.push_back(expr);
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { return expr;
if (((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr)) { }
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval()); Expr Mutate_(const Add *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
return mutator.Mutate(x_eval + (c1.Eval() - y_eval));
}
if (((x + y) + c1).Match(expr) || (c1 + (x + y)).Match(expr)) {
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
return mutator.Mutate(x_eval + (y_eval + c1.Eval()));
}
return expr;
}},
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
ExpressionPattern{
1,
[&, this](const Expr expr) -> int {
if (((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int())) {
return NORMAL;
}
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { Expr Mutate_(const Sub *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
if ((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) {
Expr x_eval = mutator.Mutate(x.Eval()); Expr Mutate_(const Mul *op, const Expr &e) {
return mutator.Mutate(Call::make(expr.type(), "floor", {x_eval}, Call::CallType::PureIntrinsic)); bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if ((is_left_constant && is_left_constant) || (!is_left_constant && !is_right_constant)) {
return AnalyzeBinaryOpExpr(op, e);
}
Expr non_constant_expr = is_left_constant ? op->b : op->a;
Expr constant_expr = is_left_constant ? op->a : op->b;
if (non_constant_expr->IsInstance<Add>()) {
const Add *add = non_constant_expr.as<Add>();
if (is_constant(add->a) || is_constant(add->b)) {
Expr expr = Add::make(Mul::make(constant_expr, add->a), Mul::make(constant_expr, add->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
} }
if ((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) { if (sign_map_.find(e.get()) != sign_map_.end()) {
Expr x_eval = mutator.Mutate(x.Eval()); sign_map_[expr.get()] = sign_map_[e.get()];
return mutator.Mutate(Call::make(expr.type(), "ceil", {x_eval}, Call::CallType::PureIntrinsic));
} }
if ((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) { return IRMutator::Mutate(expr);
Expr x_eval = mutator.Mutate(x.Eval()); }
return mutator.Mutate(Call::make(expr.type(), "round", {x_eval}, Call::CallType::PureIntrinsic)); }
if (non_constant_expr->IsInstance<Sub>()) {
const Sub *sub = non_constant_expr.as<Sub>();
if (is_constant(sub->a) || is_constant(sub->b)) {
Expr expr = Sub::make(Mul::make(constant_expr, sub->a), Mul::make(constant_expr, sub->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
} }
if ((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int()) { if (sign_map_.find(e.get()) != sign_map_.end()) {
Expr x_eval = mutator.Mutate(x.Eval()); sign_map_[expr.get()] = sign_map_[e.get()];
return mutator.Mutate(Call::make(expr.type(), "trunc", {x_eval}, Call::CallType::PureIntrinsic));
} }
return expr; return IRMutator::Mutate(expr);
}}, }
}
// float(cc1) -> a[i] = cc1; cast(a[i]) return AnalyzeBinaryOpExpr(op, e);
ExpressionPattern{1, }
[&, this](const Expr expr) -> int {
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) { // Imm / x -> y = Imm; y/x
return NORMAL; Expr Mutate_(const Div *op, const Expr &e) {
} if (is_constant(op->a) && !is_constant(op->b)) {
return UNMATCH; Expr expr = Div::make(mutator_.AllocateTmp(op->a), op->b);
}, const Div *div = expr.as<Div>();
return AnalyzeBinaryOpExpr(div, expr);
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { }
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) { return AnalyzeBinaryOpExpr(op, e);
Expr tmp = mutator.AllocateTmp(x.Eval()); }
return mutator.Mutate(Cast::make(expr.type(), tmp));
} Expr Mutate_(const Mod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
return expr;
}}, Expr Mutate_(const FloorDiv *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
// Imm / x -> y = Imm; y/x Expr Mutate_(const FloorMod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
ExpressionPattern{1,
[&, this](const Expr expr) -> int { Expr Mutate_(const Min *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
if (div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval())) {
return NORMAL; Expr Mutate_(const Max *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
}
return UNMATCH; Expr Mutate_(const EQ *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
},
Expr Mutate_(const NE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK(div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval())); Expr Mutate_(const LT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr x_eval = mutator.AllocateTmp(c1.Eval());
return mutator.Mutate(Div::make(x_eval, y.Eval())); Expr Mutate_(const LE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
}},
Expr Mutate_(const GT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
ExpressionPattern{1,
[&, this](const Expr expr) -> int { Expr Mutate_(const GE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
if ((c1 * (c2 + x)).Match(expr) || (c1 * (c2 - x)).Match(expr)) {
return NORMAL; Expr Mutate_(const And *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
}
return UNMATCH; Expr Mutate_(const Or *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
},
Expr Mutate_(const Let *op, const Expr &e) {
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { InitExprStatusIfNeed(e);
if ((c1 * (c2 + x)).Match(expr)) { Expr expr =
return mutator.Mutate(Simplify_cce(x.Eval() * c1.Eval() + c1.Eval() * c2.Eval())); Let::make(op->var, ExprOptMutator(mutator_).Mutate(op->value), ExprOptMutator(mutator_).Mutate(op->body));
} exprs_.push_back(expr);
if ((c1 * (c2 - x)).Match(expr)) { return expr;
return mutator.Mutate(Simplify_cce(c1.Eval() * c2.Eval() - x.Eval() * c1.Eval())); }
}
return expr; Expr Mutate_(const Cast *op, const Expr &e) {
}}, InitExprStatusIfNeed(e);
ExpressionPattern{1, Expr expr = Cast::make(op->type, ExprOptMutator(mutator_).Mutate(op->value));
[&, this](const Expr expr) -> int { exprs_.push_back(expr);
if ((select((z || w), x, y)).Match(expr) || (select((z && w), x, y)).Match(expr) || return expr;
(select((!z), x, y)).Match(expr)) { }
return NORMAL;
} Expr Mutate_(const Not *op, const Expr &e) {
return UNMATCH; InitExprStatusIfNeed(e);
}, Expr expr = Not::make(ExprOptMutator(mutator_).Mutate(op->a));
exprs_.push_back(expr);
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr { return expr;
if ((select((z || w), x, y)).Match(expr)) { }
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval()));
return mutator.Mutate(Select::make(w.Eval(), x.Eval(), temp_eval)); Expr Mutate_(const Load *op, const Expr &e) {
} InitExprStatusIfNeed(e);
if ((select((z && w), x, y)).Match(expr)) { Expr expr = Load::make(op->type, op->buffer_var, ExprOptMutator(mutator_).Mutate(op->index),
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval())); ExprOptMutator(mutator_).Mutate(op->predicate));
return mutator.Mutate(Select::make(w.Eval(), temp_eval, y.Eval())); exprs_.push_back(expr);
} return expr;
if ((select((!z), x, y)).Match(expr)) { }
return mutator.Mutate(Select::make(z.Eval(), y.Eval(), x.Eval()));
} Expr Mutate_(const Reduce *op, const Expr &e) {
return expr; InitExprStatusIfNeed(e);
}}}; Array<Expr> source;
for (Expr src : op->source) {
source.push_back(ExprOptMutator(mutator_).Mutate(src));
}
Expr expr =
Reduce::make(op->combiner, source, op->axis, ExprOptMutator(mutator_).Mutate(op->condition), op->value_index);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Shuffle *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> vectors;
for (Expr v : op->vectors) {
vectors.push_back(ExprOptMutator(mutator_).Mutate(v));
}
Array<Expr> indices;
for (Expr indic : op->indices) {
indices.push_back(ExprOptMutator(mutator_).Mutate(indic));
}
Expr expr = Shuffle::make(vectors, indices);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Call *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> args;
for (Expr arg : op->args) {
args.push_back(ExprOptMutator(mutator_).Mutate(arg));
}
Expr expr = Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index);
mutator_.AddBroadCastCallIfNeed(op, expr);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Ramp *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Ramp::make(ExprOptMutator(mutator_).Mutate(op->base), ExprOptMutator(mutator_).Mutate(op->stride), op->lanes);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Broadcast *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Broadcast::make(ExprOptMutator(mutator_).Mutate(op->value), op->lanes);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const IntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const UIntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const FloatImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const StringImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const Variable *op, const Expr &e) { return SaveAutomicExpr(e); }
private:
void InitExprStatusIfNeed(const Expr &e) {
const Object *object_e = e.get();
if (notation_map_.find(object_e) == notation_map_.end()) {
notation_map_[object_e] = e->GetTypeKey();
}
if (sign_map_.find(object_e) == sign_map_.end()) {
sign_map_[object_e] = true;
}
}
bool IsNewRoot(const Expr &e) {
CHECK(notation_map_.find(e.get()) != notation_map_.end());
std::string root = notation_map_[e.get()];
std::string type_key = e->GetTypeKey();
return !((root == Add::_type_key || root == Sub::_type_key) &&
(type_key == Add::_type_key || type_key == Sub::_type_key)) ||
!((root == Mul::_type_key || root == Div::_type_key) &&
(type_key == Mul::_type_key || type_key == Div::_type_key));
}
template <typename T>
Expr AnalyzeBinaryOpExpr(const T *op, const Expr &e) {
InitExprStatusIfNeed(e);
const Object *object_e = e.get();
std::string root_of_e = notation_map_[object_e];
bool pos_of_e = sign_map_[object_e];
std::string type_key = e->GetTypeKey();
Expr expr = e;
if (IsNewRoot(e)) {
expr = T::make(ExprOptMutator(mutator_).Mutate(op->a), ExprOptMutator(mutator_).Mutate(op->b));
notation_map_[expr.get()] = root_of_e;
sign_map_[expr.get()] = pos_of_e;
exprs_.push_back(expr);
} else {
notation_map_[op->a.get()] = root_of_e;
notation_map_[op->b.get()] = root_of_e;
sign_map_[op->a.get()] = pos_of_e;
sign_map_[op->b.get()] = (type_key == Sub::_type_key || type_key == Div::_type_key) ? !pos_of_e : pos_of_e;
expr = T::make(IRMutator::Mutate(op->a), IRMutator::Mutate(op->b));
}
return expr;
}
Expr SaveAutomicExpr(const Expr &e) {
InitExprStatusIfNeed(e);
exprs_.push_back(e);
return e;
}
Expr RebuildExpr() {
CHECK(!exprs_.empty());
Expr expr = exprs_.front();
exprs_.pop_front();
while (!exprs_.empty()) {
expr = RebuildExpr(expr, exprs_.front());
exprs_.pop_front();
}
return expr;
}
Expr RebuildExpr(const Expr &expr1, const Expr &expr2) {
Expr expr = expr1;
Expr opnd = expr2;
if (sign_map_[expr2.get()] && !sign_map_[expr1.get()]) {
expr = expr2;
opnd = expr1;
}
if ((sign_map_[expr1.get()] && sign_map_[expr2.get()]) || (!sign_map_[expr1.get()] && !sign_map_[expr2.get()])) {
if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Add::make(expr, opnd);
} else {
expr = Mul::make(expr, opnd);
}
} else {
if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Sub::make(expr, opnd);
} else {
expr = Div::make(expr, opnd);
}
}
notation_map_[expr.get()] = notation_map_[expr1.get()];
sign_map_[expr.get()] = sign_map_[expr1.get()] || sign_map_[expr2.get()];
return expr;
}
ThreeAddressExprMutator &mutator_;
std::list<Expr> exprs_;
std::unordered_map<const Object *, std::string> notation_map_;
std::unordered_map<const Object *, bool> sign_map_;
}; };
Expr ThreeAddressExprMutator::Mutate(Expr expr) { class LoopMutator : public IRMutator {
// select instructions public:
InstructionMatcher matcher; LoopMutator() : loop_level_(0) {}
matcher.Match(expr); ~LoopMutator() override = default;
int idx = matcher.choice;
Expr ret; Stmt Mutate_(const For *op, const Stmt &s) final {
level_++; loop_level_++;
if (idx < 0 || disable_selection_ || level_ < matcher.ins_pattern[idx].min_level) { loop_vars_.push_front(op);
expr_stack.push_back(expr); Stmt stmt = IRMutator::Mutate(op->body);
ret = IRMutator::Mutate(expr); if (provides_.size() == 1 || provides_.front()->args.size() == provides_.front()->args.size()) {
expr_stack.pop_back(); return s;
} else { // match an intrinsic }
ret = matcher.ins_pattern[idx].replace_func(expr, *this); if (!stmt->IsInstance<For>()) {
provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); });
const Provide *provide = provides_.back();
stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args);
provides_.pop_back();
}
stmt = RebuildForStmt(loop_vars_.back(), stmt);
loop_vars_.pop_back();
loop_level_--;
return stmt;
} }
level_--;
return ret;
}
int ThreeAddressExprMutator::ct_ = 0; Stmt Mutate_(const Provide *op, const Stmt &s) final {
if (!loop_vars_.empty()) {
provides_.push_back(op);
}
return IRMutator::Mutate_(op, s);
}
private:
Stmt RebuildForStmt(const For *op, Stmt &body) {
Stmt stmt = body;
while (!provides_.empty() && op->loop_var.same_as(provides_.back()->args[0])) {
const Provide *second = provides_.back();
stmt = Block::make(Provide::make(second->func, second->value_index, second->value, second->args), stmt);
provides_.pop_back();
}
return For::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, stmt);
}
size_t loop_level_;
std::list<const Provide *> provides_;
std::list<const For *> loop_vars_;
};
class InferUpperBound { class InferUpperBound {
private: private:
...@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator { ...@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator {
// Bring over the common exprs from previous stage // Bring over the common exprs from previous stage
mutator.SetCommonExpr(global_common_expr_); mutator.SetCommonExpr(global_common_expr_);
} }
value = ExprOptMutator(mutator).Mutate(value);
value = mutator.Mutate(value); value = mutator.Mutate(value);
if (cross_stmt_simplify_) { if (cross_stmt_simplify_) {
// Take back the common exprs for next stages // Take back the common exprs for next stages
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册