提交 fae7628a 编写于 作者: Z zhaiyukun

Add Optimization for three address

  1.Arithmetic priority adjustment
  2.Instruction selection
上级 6335daad
......@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <arithmetic/pattern_match.h>
#include <dmlc/common.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
......@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration
class ThreeAddressExprMutator;
struct ExpressionPattern {
int min_level; // minimal level
std::function<int(Expr)> score_func; // assign score to a subtree
std::function<Expr(Expr, ThreeAddressExprMutator &)> replace_func; // replace a subtree with this instruction
};
class ThreeAddressFilter : public IRVisitor {
public:
bool Find(const Stmt &s) {
......@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator {
// forward declaration
Expr Mutate(Expr expr) override;
// do naive three address translation without instruction selection
Expr MutateWithoutSelection(const Expr expr) {
disable_selection_ = true;
Expr ret = Mutate(expr);
disable_selection_ = false;
return ret;
}
template <typename T>
Expr MutateBinaryOp(const T *op, const Expr &e) {
in_call_++;
......@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator {
Expr Mutate_(const FloatImm *op, const Expr &e) final { return MutateConstOp(op, e); }
Expr Mutate_(const IntImm *op, const Expr &e) final { return MutateConstOp(op, e); }
void AddBroadCastCallIfNeed(const Call *op, const Expr &e) {
if (broadcast_.find(op) == broadcast_.end()) {
return;
}
const Call *new_call = e.as<Call>();
CHECK_NOTNULL(new_call);
broadcast_.insert(new_call);
}
std::vector<Stmt> assign_stmt;
std::vector<Tensor> imm_tensors;
std::unordered_set<FunctionRef, air::NodeHash, air::NodeEqual> imm_ops;
......@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator {
Array<Expr> shape_;
std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr>
std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual>
imm2hash_; // imm tensor -> hash value of the expr in the tensor
// imm tensor -> hash value of the expr in the tensor
std::unordered_map<FunctionRef, size_t, air::NodeHash, air::NodeEqual> imm2hash_;
int level_{0};
int in_call_{0};
......@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator {
ExprHasher hasher_;
};
Expr CallPureIntrinsic(const std::string &name, const Array<Expr> &args, const Type type) {
return Call::make(type, name, args, Call::CallType::PureIntrinsic);
Expr ThreeAddressExprMutator::Mutate(Expr expr) {
level_++;
expr_stack.push_back(expr);
Expr ret = IRMutator::Mutate(expr);
expr_stack.pop_back();
level_--;
return ret;
}
// Match instructions by dynamic programming on the tree
class InstructionMatcher {
int ThreeAddressExprMutator::ct_ = 0;
class InstructionSelector {
public:
void Match(const Expr value) {
int max_score = -1;
int max_i = -1;
// try patterns
for (size_t i = 0; i < ins_pattern.size(); ++i) {
int score_ = ins_pattern[i].score_func(value);
if (score_ > max_score) {
max_score = score_;
max_i = static_cast<int>(i);
InstructionSelector(ThreeAddressExprMutator &mutator, std::list<Expr> &exprs,
std::unordered_map<const Object *, std::string> &notation_map,
std::unordered_map<const Object *, bool> &sign_map)
: mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {}
~InstructionSelector() = default;
Expr Mutate(Expr expr) {
if (const Mul *op = expr.as<Mul>()) {
return Mutate_(op, expr);
}
if (const Cast *op = expr.as<Cast>()) {
return Mutate_(op, expr);
}
if (const Select *op = expr.as<Select>()) {
return Mutate_(op, expr);
}
return expr;
}
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vaxpy [Xd] = Xm * [Xn] + [Xd]
Expr Mutate_(const Mul *op, const Expr &e) {
std::string root = notation_map_.at(e.get());
if (root != Add::_type_key && root != Sub::_type_key) {
return e;
}
bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if (is_left_constant && is_right_constant) {
return e;
}
Expr expr = GetIndexOfPairExprForMul(e);
if (expr.same_as(e)) {
return e;
}
Array<Expr> args;
if (!is_left_constant) {
args.push_back(op->a);
} else {
args.push_back(op->b);
}
args.push_back(expr);
if (!is_right_constant) {
args.push_back(op->b);
} else {
args.push_back(op->a);
}
return Call::make(op->type, !is_left_constant && !is_right_constant ? "vmadd" : "vaxpy", args,
Call::CallType::PureIntrinsic);
}
// vrelu [Xd] = max([Xn], 0)
// vmaddrelu [Xd] = max(vmadd [Xd], 0)
Expr Mutate_(const Max *op, const Expr &e) {
bool is_left_zero = isZero(op->a);
bool is_right_zero = IsZero(op->b);
if (!is_left_zero && !is_right_zero) {
return e;
}
Expr expr = op->a;
if (is_left_zero) {
expr = op->b;
}
if (const Call *call = expr.as<Call>()) {
if (call->call_type == Call::CallType::PureIntrinsic && call->name == "vmadd") {
return Call::make(op->type, "vmaddrelu", call->args, Call::CallType::PureIntrinsic);
}
}
return Call::make(op->type, "relu", {expr}, Call::CallType::PureIntrinsic);
}
score = max_score;
choice = max_i;
}
int score;
int choice;
const int NORMAL = 20;
const int PRIOR = 50;
const int UNMATCH = -1;
air::arith::PVar<Expr> x, y, z, w;
air::arith::PVar<Type> pt;
air::arith::PVar<Floating> c1, c2;
std::vector<ExpressionPattern> ins_pattern{
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vmla [Xd] = [Xn] * [Xm] + [Xd]
ExpressionPattern{
2,
[&, this](const Expr &expr) -> int {
if (((x * y + z).Match(expr) || (z + x * y).Match(expr)) &&
(!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr &expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK(((x * y + z)).Match(expr) || (z + x * y).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
Expr z_eval = mutator.Mutate(z.Eval());
// make sure elemwise inside
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
// float(cc1) -> a[i] = cc1; cast(a[i])
Expr Mutate_(const Cast *op, const Expr &e) {
if (op->type.is_int() && op->value->IsInstance<Call>()) {
const Call *call = op->value.as<Call>();
if (call->name != "floor" && call->name != "ceil" && call->name != "round" && call->name != "trunc") {
return e;
}
if (op->type == call->type) {
return op->value;
} else {
return Call::make(op->type, call->name, call->args, call->call_type, call->func, call->value_index);
}
}
if (op->type.is_float() && op->value->IsInstance<Variable>()) {
return Cast::make(op->type, mutator_.AllocateTmp(op->value));
}
return e;
}
if (mutator.IsTmpTensor(x_eval)) {
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmadd", {y_eval, z_eval, x_eval}, x_eval.type()));
} else if (mutator.IsTmpTensor(y_eval)) {
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmadd", {x_eval, z_eval, y_eval}, y_eval.type()));
} else if (mutator.IsTmpTensor(z_eval)) {
return mutator.AssignTmp(z_eval, CallPureIntrinsic("vmla", {x_eval, y_eval, z_eval}, z_eval.type()));
} else {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
}},
// vmaddrelu [Xd] = max([Xn] * [Xd] + [Xm], 0)
ExpressionPattern{
2,
[&, this](const Expr expr) -> int {
if (((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) || (max(c1, x * y + z)).Match(expr) ||
(max(c1, z + x * y)).Match(expr)) &&
c1.Eval()->value == 0.0 && (!is_constant(x.Eval()) && !is_constant(y.Eval()) && !is_constant(z.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK((max(x * y + z, c1)).Match(expr) || (max(z + x * y, c1)).Match(expr) ||
(max(c1, x * y + z)).Match(expr) || (max(c1, z + x * y)).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
Expr z_eval = mutator.Mutate(z.Eval());
// check elemwise
if (CountVars(x_eval) != CountVars(y_eval) || CountVars(x_eval) != CountVars(z_eval)) {
return mutator.MutateWithoutSelection(x_eval * y_eval + z_eval);
}
Expr Mutate_(const Select *op, const Expr &e) {
if (const Not *notCond = op->condition.as<Not>()) {
return Select::make(notCond->a, op->false_value, op->true_value);
}
if (const And *andCond = op->condition.as<And>()) {
Expr tmpExpr = Select::make(andCond->a, op->true_value, op->false_value);
return Select::make(andCond->b, tmpExpr, op->false_value);
}
if (const Or *orCond = op->condition.as<Or>()) {
Expr tmpExpr = Select::make(orCond->a, op->true_value, op->false_value);
return Select::make(orCond->b, op->true_value, tmpExpr);
}
return e;
}
if (mutator.IsTmpTensor(x_eval) || x_eval.same_as(x.Eval())) {
return mutator.AssignTmp(x_eval, CallPureIntrinsic("vmaddrelu", {y_eval, z_eval, x_eval}, x_eval.type()));
} else if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) {
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vmaddrelu", {x_eval, z_eval, y_eval}, y_eval.type()));
} else {
return mutator.MutateWithoutSelection(max(x_eval * y_eval + z_eval, c1.Eval()));
}
}},
// vaxpy [Xd] = Xm * [Xn] + [Xd]
ExpressionPattern{
2,
[&, this](const Expr expr) -> int {
if (((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) ||
(y + c1 * x).Match(expr)) &&
(!is_constant(x.Eval()) && !is_constant(y.Eval()))) {
return PRIOR;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK((c1 * x + y).Match(expr) || (x * c1 + y).Match(expr) || (y + c1 * x).Match(expr) ||
(y + c1 * x).Match(expr));
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
// check elemwise
if (CountVars(x_eval) != CountVars(y_eval)) {
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval);
}
private:
Expr GetIndexOfPairExprForMul(const Expr &expr) {
Expr ret_expr = expr;
bool pos = sign_map_.at(expr.get());
int dim = CountVars(expr);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
if ((sign_map_.at((*iter).get()) != pos) || is_constant(*iter) || (iter->same_as(expr))) {
continue;
}
if (CountVars(*iter) > dim) {
continue;
}
ret_expr = *iter;
exprs_.remove_if([&ret_expr](Expr e) { return e.same_as(ret_expr); });
break;
}
return ret_expr;
}
if (mutator.IsTmpTensor(y_eval) || y_eval.same_as(y.Eval())) {
return mutator.AssignTmp(y_eval, CallPureIntrinsic("vaxpy", {x_eval, y_eval, c1.Eval()}, y_eval.type()));
} else {
return mutator.MutateWithoutSelection(c1.Eval() * x_eval + y_eval);
}
}},
// vrelu [Xd] = max([Xn], 0)
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if (((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr)) && c1.Eval()->value == 0.0 &&
!is_constant(x.Eval()) && x.Eval().type() == Float(16, 1)) {
return NORMAL;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK(((max(x, c1)).Match(expr) || (max(c1, x)).Match(expr)));
Expr x_eval = mutator.Mutate(x.Eval());
return mutator.Mutate(CallPureIntrinsic("relu", {x_eval}, x_eval.type()));
}},
// adds [Xd] = ([Xn] + [Yn]) + imm -> [Xn] + ([Yn] + imm)
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if ((((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr) || ((x + y) + c1).Match(expr) ||
(c1 + (x + y)).Match(expr)) &&
!is_constant(x.Eval()) && !is_constant(y.Eval())) {
return NORMAL;
}
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
if (((x - y) + c1).Match(expr) || (c1 + (x - y)).Match(expr)) {
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
return mutator.Mutate(x_eval + (c1.Eval() - y_eval));
}
if (((x + y) + c1).Match(expr) || (c1 + (x + y)).Match(expr)) {
Expr x_eval = mutator.Mutate(x.Eval());
Expr y_eval = mutator.Mutate(y.Eval());
return mutator.Mutate(x_eval + (y_eval + c1.Eval()));
}
return expr;
}},
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
ExpressionPattern{
1,
[&, this](const Expr expr) -> int {
if (((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) ||
((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int())) {
return NORMAL;
}
return UNMATCH;
},
ThreeAddressExprMutator &mutator_;
std::list<Expr> &exprs_;
std::unordered_map<const Object *, std::string> &notation_map_;
std::unordered_map<const Object *, bool> &sign_map_;
};
class ExprOptMutator : public IRMutator {
public:
ExprOptMutator(ThreeAddressExprMutator &mutator) : mutator_(mutator) {}
~ExprOptMutator() override = default;
Expr Mutate(Expr expr) {
expr = IRMutator::Mutate(expr);
exprs_.sort([](Expr &e1, Expr &e2) -> bool {
int dim1 = CountVars(e1);
int dim2 = CountVars(e2);
if (dim1 == dim2) {
return !e1->IsInstance<Mul>();
}
return dim1 < dim2;
});
InstructionSelector selector(mutator_, exprs_, notation_map_, sign_map_);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
*iter = selector.Mutate(*iter);
}
expr = RebuildExpr();
return expr;
}
Expr Mutate_(const Select *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Select::make(op->condition, ExprOptMutator(mutator_).Mutate(op->true_value),
ExprOptMutator(mutator_).Mutate(op->false_value));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Add *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
if ((cast(pt, call_floor(x))).Match(expr) && pt.Eval().is_int()) {
Expr x_eval = mutator.Mutate(x.Eval());
return mutator.Mutate(Call::make(expr.type(), "floor", {x_eval}, Call::CallType::PureIntrinsic));
Expr Mutate_(const Sub *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Mul *op, const Expr &e) {
bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if ((is_left_constant && is_left_constant) || (!is_left_constant && !is_right_constant)) {
return AnalyzeBinaryOpExpr(op, e);
}
Expr non_constant_expr = is_left_constant ? op->b : op->a;
Expr constant_expr = is_left_constant ? op->a : op->b;
if (non_constant_expr->IsInstance<Add>()) {
const Add *add = non_constant_expr.as<Add>();
if (is_constant(add->a) || is_constant(add->b)) {
Expr expr = Add::make(Mul::make(constant_expr, add->a), Mul::make(constant_expr, add->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
}
if ((cast(pt, call_ceil(x))).Match(expr) && pt.Eval().is_int()) {
Expr x_eval = mutator.Mutate(x.Eval());
return mutator.Mutate(Call::make(expr.type(), "ceil", {x_eval}, Call::CallType::PureIntrinsic));
if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
}
if ((cast(pt, call_round(x))).Match(expr) && pt.Eval().is_int()) {
Expr x_eval = mutator.Mutate(x.Eval());
return mutator.Mutate(Call::make(expr.type(), "round", {x_eval}, Call::CallType::PureIntrinsic));
return IRMutator::Mutate(expr);
}
}
if (non_constant_expr->IsInstance<Sub>()) {
const Sub *sub = non_constant_expr.as<Sub>();
if (is_constant(sub->a) || is_constant(sub->b)) {
Expr expr = Sub::make(Mul::make(constant_expr, sub->a), Mul::make(constant_expr, sub->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
}
if ((cast(pt, call_trunc(x))).Match(expr) && pt.Eval().is_int()) {
Expr x_eval = mutator.Mutate(x.Eval());
return mutator.Mutate(Call::make(expr.type(), "trunc", {x_eval}, Call::CallType::PureIntrinsic));
if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
}
return expr;
}},
// float(cc1) -> a[i] = cc1; cast(a[i])
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) {
return NORMAL;
}
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
if ((cast(pt, x)).Match(expr) && pt.Eval().is_float() && x.Eval().as<Variable>()) {
Expr tmp = mutator.AllocateTmp(x.Eval());
return mutator.Mutate(Cast::make(expr.type(), tmp));
}
return expr;
}},
// Imm / x -> y = Imm; y/x
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if (div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval())) {
return NORMAL;
}
return UNMATCH;
},
[&, this](const Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
CHECK(div(c1, y).Match(expr) && is_constant(c1.Eval()) && !is_constant(y.Eval()));
Expr x_eval = mutator.AllocateTmp(c1.Eval());
return mutator.Mutate(Div::make(x_eval, y.Eval()));
}},
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if ((c1 * (c2 + x)).Match(expr) || (c1 * (c2 - x)).Match(expr)) {
return NORMAL;
}
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
if ((c1 * (c2 + x)).Match(expr)) {
return mutator.Mutate(Simplify_cce(x.Eval() * c1.Eval() + c1.Eval() * c2.Eval()));
}
if ((c1 * (c2 - x)).Match(expr)) {
return mutator.Mutate(Simplify_cce(c1.Eval() * c2.Eval() - x.Eval() * c1.Eval()));
}
return expr;
}},
ExpressionPattern{1,
[&, this](const Expr expr) -> int {
if ((select((z || w), x, y)).Match(expr) || (select((z && w), x, y)).Match(expr) ||
(select((!z), x, y)).Match(expr)) {
return NORMAL;
}
return UNMATCH;
},
[&, this](Expr expr, ThreeAddressExprMutator &mutator) -> Expr {
if ((select((z || w), x, y)).Match(expr)) {
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval()));
return mutator.Mutate(Select::make(w.Eval(), x.Eval(), temp_eval));
}
if ((select((z && w), x, y)).Match(expr)) {
Expr temp_eval = mutator.Mutate(Select::make(z.Eval(), x.Eval(), y.Eval()));
return mutator.Mutate(Select::make(w.Eval(), temp_eval, y.Eval()));
}
if ((select((!z), x, y)).Match(expr)) {
return mutator.Mutate(Select::make(z.Eval(), y.Eval(), x.Eval()));
}
return expr;
}}};
return IRMutator::Mutate(expr);
}
}
return AnalyzeBinaryOpExpr(op, e);
}
// Imm / x -> y = Imm; y/x
Expr Mutate_(const Div *op, const Expr &e) {
if (is_constant(op->a) && !is_constant(op->b)) {
Expr expr = Div::make(mutator_.AllocateTmp(op->a), op->b);
const Div *div = expr.as<Div>();
return AnalyzeBinaryOpExpr(div, expr);
}
return AnalyzeBinaryOpExpr(op, e);
}
Expr Mutate_(const Mod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const FloorDiv *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const FloorMod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Min *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Max *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const EQ *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const NE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const LT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const LE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const GT *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const GE *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const And *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Or *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Let *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Let::make(op->var, ExprOptMutator(mutator_).Mutate(op->value), ExprOptMutator(mutator_).Mutate(op->body));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Cast *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Cast::make(op->type, ExprOptMutator(mutator_).Mutate(op->value));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Not *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Not::make(ExprOptMutator(mutator_).Mutate(op->a));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Load *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Load::make(op->type, op->buffer_var, ExprOptMutator(mutator_).Mutate(op->index),
ExprOptMutator(mutator_).Mutate(op->predicate));
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Reduce *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> source;
for (Expr src : op->source) {
source.push_back(ExprOptMutator(mutator_).Mutate(src));
}
Expr expr =
Reduce::make(op->combiner, source, op->axis, ExprOptMutator(mutator_).Mutate(op->condition), op->value_index);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Shuffle *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> vectors;
for (Expr v : op->vectors) {
vectors.push_back(ExprOptMutator(mutator_).Mutate(v));
}
Array<Expr> indices;
for (Expr indic : op->indices) {
indices.push_back(ExprOptMutator(mutator_).Mutate(indic));
}
Expr expr = Shuffle::make(vectors, indices);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Call *op, const Expr &e) {
InitExprStatusIfNeed(e);
Array<Expr> args;
for (Expr arg : op->args) {
args.push_back(ExprOptMutator(mutator_).Mutate(arg));
}
Expr expr = Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index);
mutator_.AddBroadCastCallIfNeed(op, expr);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Ramp *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Ramp::make(ExprOptMutator(mutator_).Mutate(op->base), ExprOptMutator(mutator_).Mutate(op->stride), op->lanes);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const Broadcast *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Broadcast::make(ExprOptMutator(mutator_).Mutate(op->value), op->lanes);
exprs_.push_back(expr);
return expr;
}
Expr Mutate_(const IntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const UIntImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const FloatImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const StringImm *op, const Expr &e) { return SaveAutomicExpr(e); }
Expr Mutate_(const Variable *op, const Expr &e) { return SaveAutomicExpr(e); }
private:
void InitExprStatusIfNeed(const Expr &e) {
const Object *object_e = e.get();
if (notation_map_.find(object_e) == notation_map_.end()) {
notation_map_[object_e] = e->GetTypeKey();
}
if (sign_map_.find(object_e) == sign_map_.end()) {
sign_map_[object_e] = true;
}
}
bool IsNewRoot(const Expr &e) {
CHECK(notation_map_.find(e.get()) != notation_map_.end());
std::string root = notation_map_[e.get()];
std::string type_key = e->GetTypeKey();
return !((root == Add::_type_key || root == Sub::_type_key) &&
(type_key == Add::_type_key || type_key == Sub::_type_key)) ||
!((root == Mul::_type_key || root == Div::_type_key) &&
(type_key == Mul::_type_key || type_key == Div::_type_key));
}
template <typename T>
Expr AnalyzeBinaryOpExpr(const T *op, const Expr &e) {
InitExprStatusIfNeed(e);
const Object *object_e = e.get();
std::string root_of_e = notation_map_[object_e];
bool pos_of_e = sign_map_[object_e];
std::string type_key = e->GetTypeKey();
Expr expr = e;
if (IsNewRoot(e)) {
expr = T::make(ExprOptMutator(mutator_).Mutate(op->a), ExprOptMutator(mutator_).Mutate(op->b));
notation_map_[expr.get()] = root_of_e;
sign_map_[expr.get()] = pos_of_e;
exprs_.push_back(expr);
} else {
notation_map_[op->a.get()] = root_of_e;
notation_map_[op->b.get()] = root_of_e;
sign_map_[op->a.get()] = pos_of_e;
sign_map_[op->b.get()] = (type_key == Sub::_type_key || type_key == Div::_type_key) ? !pos_of_e : pos_of_e;
expr = T::make(IRMutator::Mutate(op->a), IRMutator::Mutate(op->b));
}
return expr;
}
Expr SaveAutomicExpr(const Expr &e) {
InitExprStatusIfNeed(e);
exprs_.push_back(e);
return e;
}
Expr RebuildExpr() {
CHECK(!exprs_.empty());
Expr expr = exprs_.front();
exprs_.pop_front();
while (!exprs_.empty()) {
expr = RebuildExpr(expr, exprs_.front());
exprs_.pop_front();
}
return expr;
}
Expr RebuildExpr(const Expr &expr1, const Expr &expr2) {
Expr expr = expr1;
Expr opnd = expr2;
if (sign_map_[expr2.get()] && !sign_map_[expr1.get()]) {
expr = expr2;
opnd = expr1;
}
if ((sign_map_[expr1.get()] && sign_map_[expr2.get()]) || (!sign_map_[expr1.get()] && !sign_map_[expr2.get()])) {
if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Add::make(expr, opnd);
} else {
expr = Mul::make(expr, opnd);
}
} else {
if (notation_map_[expr1.get()] == Add::_type_key || notation_map_[expr1.get()] == Sub::_type_key) {
expr = Sub::make(expr, opnd);
} else {
expr = Div::make(expr, opnd);
}
}
notation_map_[expr.get()] = notation_map_[expr1.get()];
sign_map_[expr.get()] = sign_map_[expr1.get()] || sign_map_[expr2.get()];
return expr;
}
ThreeAddressExprMutator &mutator_;
std::list<Expr> exprs_;
std::unordered_map<const Object *, std::string> notation_map_;
std::unordered_map<const Object *, bool> sign_map_;
};
Expr ThreeAddressExprMutator::Mutate(Expr expr) {
// select instructions
InstructionMatcher matcher;
matcher.Match(expr);
int idx = matcher.choice;
Expr ret;
level_++;
if (idx < 0 || disable_selection_ || level_ < matcher.ins_pattern[idx].min_level) {
expr_stack.push_back(expr);
ret = IRMutator::Mutate(expr);
expr_stack.pop_back();
} else { // match an intrinsic
ret = matcher.ins_pattern[idx].replace_func(expr, *this);
class LoopMutator : public IRMutator {
public:
LoopMutator() : loop_level_(0) {}
~LoopMutator() override = default;
Stmt Mutate_(const For *op, const Stmt &s) final {
loop_level_++;
loop_vars_.push_front(op);
Stmt stmt = IRMutator::Mutate(op->body);
if (provides_.size() == 1 || provides_.front()->args.size() == provides_.front()->args.size()) {
return s;
}
if (!stmt->IsInstance<For>()) {
provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); });
const Provide *provide = provides_.back();
stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args);
provides_.pop_back();
}
stmt = RebuildForStmt(loop_vars_.back(), stmt);
loop_vars_.pop_back();
loop_level_--;
return stmt;
}
level_--;
return ret;
}
int ThreeAddressExprMutator::ct_ = 0;
Stmt Mutate_(const Provide *op, const Stmt &s) final {
if (!loop_vars_.empty()) {
provides_.push_back(op);
}
return IRMutator::Mutate_(op, s);
}
private:
Stmt RebuildForStmt(const For *op, Stmt &body) {
Stmt stmt = body;
while (!provides_.empty() && op->loop_var.same_as(provides_.back()->args[0])) {
const Provide *second = provides_.back();
stmt = Block::make(Provide::make(second->func, second->value_index, second->value, second->args), stmt);
provides_.pop_back();
}
return For::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, stmt);
}
size_t loop_level_;
std::list<const Provide *> provides_;
std::list<const For *> loop_vars_;
};
class InferUpperBound {
private:
......@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator {
// Bring over the common exprs from previous stage
mutator.SetCommonExpr(global_common_expr_);
}
value = ExprOptMutator(mutator).Mutate(value);
value = mutator.Mutate(value);
if (cross_stmt_simplify_) {
// Take back the common exprs for next stages
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册