提交 05313741 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!102 Add Loop Mutator & Relu to pass three address

Merge pull request !102 from ConnZhai/conn
......@@ -35,22 +35,71 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration
class ThreeAddressExprMutator;
class ThreeAddressFilter : public IRVisitor {
class ExprArgsFetcher : public IRVisitor {
public:
bool Find(const Stmt &s) {
Visit(s);
return need_;
explicit ExprArgsFetcher(Array<Expr> args) : args_(args), index_(args_.size() - 1) {}
~ExprArgsFetcher() override = default;
Array<Expr> GetArgs(const Expr &e) {
Visit(e);
if (max_dim >= args_.size()) {
return args_;
}
Array<Expr> args;
while (index_ < args_.size()) {
args.push_back(args_[index_]);
index_++;
}
if (CountVars(args) == CountVars(args_)) {
return args_;
}
return args;
}
bool MustBroadcast(const Expr &e) {
if (is_constant(e) || CountVars(e) == 0) {
return false;
}
size_t size = GetArgs(e).size();
return size > max_dim;
}
void Visit_(const Call *op) override {
if (op->name == "load3d_l1_ub") {
need_ = false;
if (op->call_type == Call::CallType::Halide) {
max_dim = max_dim < op->args.size() ? op->args.size() : max_dim;
for (Expr arg : op->args) {
size_t index = GetIndex(arg);
index_ = index_ > index ? index : index_;
}
} else {
CHECK(op->call_type == Call::CallType::PureIntrinsic);
for (Expr e : op->args) {
Array<Expr> args = GetArgs(e);
max_dim = max_dim < args.size() ? args.size() : max_dim;
for (Expr arg : args) {
size_t index = GetIndex(arg);
index_ = index_ > index ? index : index_;
}
}
}
IRVisitor::Visit_(op);
}
private:
bool need_{true};
size_t GetIndex(const Expr &arg) {
if (is_constant(arg)) {
return index_;
}
for (size_t i = 0; i < args_.size(); ++i) {
if (args_[i].same_as(arg)) {
return i;
}
}
return index_;
}
Array<Expr> args_;
size_t index_;
size_t max_dim{0};
};
class ScalarOperandFinder : public IRVisitor {
......@@ -188,50 +237,21 @@ std::unordered_set<Tensor> GetExprTensors(const Expr expr) {
return tensors;
}
// Replace all instances of a Tensor "from" in an Expr with a new one "to"
class ReplaceProvideTensors : public IRMutator {
public:
ReplaceProvideTensors(const Tensor &from, const Operation &to) : from_(from->op), to_(to) {}
~ReplaceProvideTensors() override = default;
Stmt Mutate_(const Provide *op, const Stmt &s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Provide>();
CHECK(op);
if (op->func == from_) {
stmt = Provide::make(to_, op->value_index, op->value, op->args);
}
return stmt;
}
Expr Mutate_(const Call *op, const Expr &e) override {
Expr expr = IRMutator::Mutate_(op, e);
const Call *n = expr.as<Call>();
CHECK(n);
if (n->func == from_) {
expr = Call::make(n->type, to_->name, n->args, n->call_type, to_, n->value_index);
}
return expr;
}
private:
const Operation from_;
const Operation to_;
};
// Mutate expression according to selection choices
class ThreeAddressExprMutator : public IRMutator {
public:
ThreeAddressExprMutator(const Tensor output, const Array<Expr> &args, const Array<Expr> &shape,
const std::unordered_set<const Call *> &broadcast, bool IsReductionOp,
bool cross_stmt_simplify)
ThreeAddressExprMutator(const Tensor output, const Array<Expr> &args, const Array<Expr> &out_args,
const Array<Expr> &shape, const std::unordered_set<const Call *> &broadcast,
bool IsReductionOp, bool cross_stmt_simplify, bool is_simple = false)
: output_(output),
args_(args),
out_args_(out_args),
shape_(shape),
broadcast_(broadcast),
IsReductionOp_(IsReductionOp),
cross_simplify_(cross_stmt_simplify),
hasher_(cross_stmt_simplify) {
hasher_(cross_stmt_simplify),
is_simple_(is_simple) {
CHECK_EQ(args_.size(), shape_.size());
if (shape_.empty()) { // scalar values should have at least one dimension and contains one element
shape_.push_back(1);
......@@ -246,7 +266,7 @@ class ThreeAddressExprMutator : public IRMutator {
common_exprs_.insert(global_common_expr.begin(), global_common_expr.end());
}
Expr AllocateTmp(Expr value) {
Expr AllocateTmp(Expr value, Array<Expr> args = {}) {
// detect common expression
size_t hash_value = hasher_(value);
auto x = common_exprs_[hash_value];
......@@ -263,13 +283,17 @@ class ThreeAddressExprMutator : public IRMutator {
// allocate new immediate tensor
Tensor imm;
imm = PlaceholderOpNode::make(output_->op->name + "_" + std::to_string(ct_++), shape_, value.type()).output(0);
if (args.empty()) {
args = args_;
}
std::string name = output_->op->name + "_" + std::to_string(ct_++);
imm = PlaceholderOpNode::make(name, GetShape(args), value.type()).output(0);
imm_tensors.push_back(imm);
imm_ops.insert(imm->op);
// update common expr
assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args_));
Expr ret = Call::make(value.type(), imm->op->name, args_, Call::CallType::Halide, imm->op, imm->value_index);
assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args));
Expr ret = Call::make(value.type(), imm->op->name, args, Call::CallType::Halide, imm->op, imm->value_index);
common_exprs_[hash_value] = std::make_pair(value, ret);
imm2hash_[imm->op] = hash_value;
return ret;
......@@ -283,9 +307,13 @@ class ThreeAddressExprMutator : public IRMutator {
common_exprs_.erase(old_hash);
// update new common expr
assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args_));
Array<Expr> args = args_;
if (is_simple_) {
args = ExprArgsFetcher(args_).GetArgs(value);
}
assign_stmt.push_back(Provide::make(imm->op, imm->value_index, value, args));
size_t hash_value = hasher_(value);
Expr ret = Call::make(value.type(), imm->op->name, args_, Call::CallType::Halide, imm->op, imm->value_index);
Expr ret = Call::make(value.type(), imm->op->name, args, Call::CallType::Halide, imm->op, imm->value_index);
common_exprs_[hash_value] = std::make_pair(value, ret);
imm2hash_[imm->op] = hash_value;
return ret;
......@@ -324,20 +352,25 @@ class ThreeAddressExprMutator : public IRMutator {
Expr r = Mutate(op->b);
in_call_--;
bool broadcast_l = !IsReductionOp_ && !is_constant(l) && CountVars(args_) > CountVars(l);
bool broadcast_r = !IsReductionOp_ && !is_constant(r) && CountVars(args_) > CountVars(r);
Array<Expr> args = args_;
if (is_simple_) {
args = ExprArgsFetcher(args_).GetArgs(T::make(l, r));
}
bool broadcast_l = !IsReductionOp_ && !is_constant(l) && CountVars(args) > CountVars(l);
bool broadcast_r = !IsReductionOp_ && !is_constant(r) && CountVars(args) > CountVars(r);
if (op->template IsInstance<Add>() || op->template IsInstance<Mul>()) {
if (broadcast_l && broadcast_r) {
l = AllocateTmp(l);
} else if (is_constant(r) && broadcast_l) {
l = AllocateTmp(l);
} else if (is_constant(l) && broadcast_r) {
r = AllocateTmp(r);
if (broadcast_l && (broadcast_r || is_constant(r))) {
l = AllocateTmp(l, args);
} else if ((broadcast_r && is_constant(l))) {
r = AllocateTmp(r, args);
}
if (CountVars(args) > CountVars(r) && ExprArgsFetcher(out_args_).MustBroadcast(r)) {
r = AllocateTmp(r, args);
}
}
return AllocateTmp(T::make(Mutate(l), Mutate(r)));
return AllocateTmp(T::make(Mutate(l), Mutate(r)), args);
}
Expr Mutate_(const Add *op, const Expr &e) final { return MutateBinaryOp<Add>(op, e); }
......@@ -443,8 +476,15 @@ class ThreeAddressExprMutator : public IRMutator {
// broadcast when need
if (broadcast_.count(op) && broadcast) {
if (expr_stack.size() >= 2 && expr_stack[expr_stack.size() - 2]->IsInstance<Div>()) {
Array<Expr> args = ExprArgsFetcher(args_).GetArgs(expr_stack[expr_stack.size() - 2]);
if (CountVars(e) < CountVars(args)) {
return AllocateTmp(e, args);
}
} else {
return AllocateTmp(e);
}
}
// this is going to generate a tensor of tensor expr, like A(B(i))
return e;
} else if (op->call_type == Call::CallType::PureIntrinsic && op->name == air::ir::intrinsic::tvm_if_then_else) {
......@@ -546,8 +586,25 @@ class ThreeAddressExprMutator : public IRMutator {
}
}
Array<Expr> GetShape(const Array<Expr> &args) {
if (CountVars(args) == CountVars(args_)) {
return shape_;
}
const size_t dim = args.size();
const size_t maxDim = output_->shape.size();
CHECK_LE(dim, maxDim);
Array<Expr> shape;
size_t index = maxDim - dim;
while (index < maxDim) {
shape.push_back(output_->shape[index]);
index++;
}
return shape;
}
Tensor output_;
Array<Expr> args_;
Array<Expr> out_args_;
Array<Expr> shape_;
std::unordered_map<size_t, std::pair<Expr, Expr>> common_exprs_; // hash value -> <match expr, replace expr>
......@@ -566,6 +623,7 @@ class ThreeAddressExprMutator : public IRMutator {
bool IsReductionOp_{false};
bool cross_simplify_;
ExprHasher hasher_;
bool is_simple_;
};
Expr ThreeAddressExprMutator::Mutate(Expr expr) {
......@@ -579,66 +637,78 @@ Expr ThreeAddressExprMutator::Mutate(Expr expr) {
int ThreeAddressExprMutator::ct_ = 0;
class InstructionSelector {
class InstructionMutator : IRMutator {
public:
InstructionSelector(ThreeAddressExprMutator &mutator, std::list<Expr> &exprs,
std::unordered_map<const Object *, std::string> &notation_map,
std::unordered_map<const Object *, bool> &sign_map)
: mutator_(mutator), exprs_(exprs), notation_map_(notation_map), sign_map_(sign_map) {}
~InstructionSelector() = default;
explicit InstructionMutator(ThreeAddressExprMutator &mutator, Array<Expr> &args) : mutator_(mutator), args_(args) {}
~InstructionMutator() = default;
Expr Mutate(Expr expr) {
if (const Mul *op = expr.as<Mul>()) {
return Mutate_(op, expr);
Expr Mutate(Expr value) { return IRMutator::Mutate(value); }
// VMADD.type {f16, f32} [Xd], [Xn], [Xm], Xt, MASK
// [Xd] = [Xn] * [Xd] + [Xm]
// VAXPY.type {f16, f32, fmix} [Xd], [Xn], Xm, Xt, MASK
// [Xd] = Xm * [Xn] + [Xd]
Expr Mutate_(const Add *op, const Expr &e) {
Expr l = Mutate(op->a);
Expr r = Mutate(op->b);
if (is_constant(l) && is_constant(r)) {
return ConstantFold<Add>(l, r);
}
if (const Cast *op = expr.as<Cast>()) {
return Mutate_(op, expr);
return Add::make(l, r);
}
if (const Select *op = expr.as<Select>()) {
return Mutate_(op, expr);
Expr Mutate_(const Sub *op, const Expr &e) {
Expr l = Mutate(op->a);
Expr r = Mutate(op->b);
if (is_constant(l) && is_constant(r)) {
return ConstantFold<Sub>(l, r);
}
return expr;
return Sub::make(l, r);
}
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vaxpy [Xd] = Xm * [Xn] + [Xd]
Expr Mutate_(const Mul *op, const Expr &e) {
std::string root = notation_map_.at(e.get());
if (root != Add::_type_key && root != Sub::_type_key) {
return e;
Expr l = Mutate(op->a);
Expr r = Mutate(op->b);
bool is_left_constant = is_constant(l);
bool is_right_constant = is_constant(r);
if (!is_left_constant && !is_right_constant) {
return Mul::make(l, r);
}
bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if (is_left_constant && is_right_constant) {
return e;
return ConstantFold<Mul>(l, r);
}
Expr expr = GetIndexOfPairExprForMul(e);
if (expr.same_as(e)) {
return e;
Expr constant = is_left_constant ? l : r;
Expr nonconstant = is_left_constant ? r : l;
if (const Add *add = nonconstant.as<Add>()) {
return MulExprMutator<Add>(constant, add);
} else if (const Sub *sub = nonconstant.as<Sub>()) {
return MulExprMutator<Sub>(constant, sub);
}
Array<Expr> args;
if (!is_left_constant) {
args.push_back(op->a);
} else {
args.push_back(op->b);
return Mul::make(l, r);
}
args.push_back(expr);
if (!is_right_constant) {
args.push_back(op->b);
} else {
args.push_back(op->a);
Expr Mutate_(const Div *op, const Expr &e) {
Expr l = Mutate(op->a);
Expr r = Mutate(op->b);
if (is_constant(l) && is_constant(r)) {
return ConstantFold<Div>(l, r);
} else if (is_constant(l)) {
l = mutator_.AllocateTmp(l, ExprArgsFetcher(args_).GetArgs(Div::make(l, r)));
}
return Call::make(op->type, !is_left_constant && !is_right_constant ? "vmadd" : "vaxpy", args,
Call::CallType::PureIntrinsic);
return Div::make(l, r);
}
// vrelu [Xd] = max([Xn], 0)
// vmaddrelu [Xd] = max(vmadd [Xd], 0)
Expr Mutate_(const Max *op, const Expr &e) {
// relu only support fp16
if (!op->type.is_float() || op->type.bits() != 16) {
return Max::make(Mutate(op->a), Mutate(op->b));
}
bool is_left_zero = isZero(op->a);
bool is_right_zero = IsZero(op->b);
if (!is_left_zero && !is_right_zero) {
return e;
return Max::make(Mutate(op->a), Mutate(op->b));
}
Expr expr = op->a;
if (is_left_zero) {
......@@ -656,91 +726,154 @@ class InstructionSelector {
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
// float(cc1) -> a[i] = cc1; cast(a[i])
Expr Mutate_(const Cast *op, const Expr &e) {
if (op->type.is_int() && op->value->IsInstance<Call>()) {
const Call *call = op->value.as<Call>();
Expr value = Mutate(op->value);
if (op->type.is_int() && value->IsInstance<Call>()) {
const Call *call = value.as<Call>();
if (call->name != "floor" && call->name != "ceil" && call->name != "round" && call->name != "trunc") {
return e;
return Cast::make(op->type, value);
}
if (op->type == call->type) {
return op->value;
return value;
} else {
return Call::make(op->type, call->name, call->args, call->call_type, call->func, call->value_index);
}
}
if (op->type.is_float() && op->value->IsInstance<Variable>()) {
return Cast::make(op->type, mutator_.AllocateTmp(op->value));
if (op->type.is_float() && value->IsInstance<Variable>()) {
return Cast::make(op->type, mutator_.AllocateTmp(value));
}
return e;
return Cast::make(op->type, value);
}
Expr Mutate_(const Select *op, const Expr &e) {
if (const Not *notCond = op->condition.as<Not>()) {
return Select::make(notCond->a, op->false_value, op->true_value);
Expr condition = Mutate(op->condition);
Expr true_value = Mutate(op->true_value);
Expr false_value = Mutate(op->false_value);
if (const Not *notCond = condition.as<Not>()) {
return Select::make(notCond->a, false_value, true_value);
}
if (const And *andCond = op->condition.as<And>()) {
Expr tmpExpr = Select::make(andCond->a, op->true_value, op->false_value);
return Select::make(andCond->b, tmpExpr, op->false_value);
if (const And *andCond = condition.as<And>()) {
Expr tmpExpr = Select::make(andCond->a, true_value, false_value);
return Select::make(andCond->b, tmpExpr, false_value);
}
if (const Or *orCond = op->condition.as<Or>()) {
Expr tmpExpr = Select::make(orCond->a, op->true_value, op->false_value);
return Select::make(orCond->b, op->true_value, tmpExpr);
if (const Or *orCond = condition.as<Or>()) {
Expr tmpExpr = Select::make(orCond->a, true_value, false_value);
return Select::make(orCond->b, true_value, tmpExpr);
}
return e;
return Select::make(condition, true_value, false_value);
}
private:
Expr GetIndexOfPairExprForMul(const Expr &expr) {
Expr ret_expr = expr;
bool pos = sign_map_.at(expr.get());
int dim = CountVars(expr);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
if ((sign_map_.at((*iter).get()) != pos) || is_constant(*iter) || (iter->same_as(expr))) {
continue;
template <typename T>
Expr MulExprMutator(Expr &imm, const T *op) {
Expr l = Mutate(op->a);
Expr r = Mutate(op->b);
if (is_constant(l)) {
return Mutate(T::make(ConstantFold<Mul>(imm, l), Mul::make(r, imm)));
} else if (is_constant(r)) {
return Mutate(T::make(ConstantFold<Mul>(imm, r), Mul::make(l, imm)));
}
if (CountVars(*iter) > dim) {
continue;
return Mul::make(T::make(l, r), imm);
}
ret_expr = *iter;
exprs_.remove_if([&ret_expr](Expr e) { return e.same_as(ret_expr); });
break;
template <typename T>
Expr ConstantFold(const Expr &a, const Expr &b) {
CHECK(a.type().is_int() || a.type().is_uint() || a.type().is_float());
if (a.type() != b.type()) {
CHECK(a.type() == b.type());
}
CHECK(a.type() == b.type());
if (const IntImm *int_a = a.as<IntImm>()) {
const IntImm *int_b = b.as<IntImm>();
return IntImm::make(a.type(), ComputeConstant<int64_t, T>(int_a->value, int_b->value));
}
if (const UIntImm *uint_a = a.as<UIntImm>()) {
const UIntImm *uint_b = b.as<UIntImm>();
return UIntImm::make(a.type(), ComputeConstant<uint64_t, T>(uint_a->value, uint_b->value));
}
const FloatImm *float_a = a.as<FloatImm>();
const FloatImm *float_b = b.as<FloatImm>();
return FloatImm::make(a.type(), ComputeConstant<double, T>(float_a->value, float_b->value));
}
return ret_expr;
template <typename Data, typename Op>
Data ComputeConstant(Data d1, Data d2) {
if (Op::_type_key == Mul::_type_key) {
return d1 * d2;
}
if (Op::_type_key == Div::_type_key) {
return d1 / d2;
}
if (Op::_type_key == Add::_type_key) {
return d1 + d2;
}
CHECK(Op::_type_key == Sub::_type_key);
return d1 - d2;
}
bool IsCandidate(const Expr &e) {
if (!e->IsInstance<Mul>()) {
return false;
}
const Mul *mul = e.as<Mul>();
bool is_left_constant = is_constant(mul->a);
bool is_right_constant = is_constant(mul->b);
if (is_left_constant && is_right_constant) {
return false;
}
return mul->a.type().is_float() && mul->a.type() == mul->b.type();
}
ThreeAddressExprMutator &mutator_;
std::list<Expr> &exprs_;
std::unordered_map<const Object *, std::string> &notation_map_;
std::unordered_map<const Object *, bool> &sign_map_;
};
Array<Expr> args_;
}; // namespace ir
class ExprOptMutator : public IRMutator {
public:
ExprOptMutator(ThreeAddressExprMutator &mutator) : mutator_(mutator) {}
explicit ExprOptMutator(ThreeAddressExprMutator &mutator, const Array<Expr> &args) : mutator_(mutator), args_(args) {}
~ExprOptMutator() override = default;
Expr Mutate(Expr expr) {
expr = IRMutator::Mutate(expr);
exprs_.sort([](Expr &e1, Expr &e2) -> bool {
int dim1 = CountVars(e1);
int dim2 = CountVars(e2);
if (dim1 == dim2) {
return !e1->IsInstance<Mul>();
}
return dim1 < dim2;
IRMutator::Mutate(expr);
std::sort(exprs_.begin(), exprs_.end(), [this](Expr &e1, Expr &e2) -> bool {
bool is_const = is_constant(e1);
if (is_const || is_constant(e2)) {
return !is_const;
}
Array<Expr> args1 = ExprArgsFetcher(args_).GetArgs(e1);
Array<Expr> args2 = ExprArgsFetcher(args_).GetArgs(e2);
if (args1.size() != args2.size()) {
return args1.size() > args2.size();
}
if (sign_map_[e1.get()] != sign_map_[e2.get()]) {
return !sign_map_[e1.get()];
}
return e1->IsInstance<Mul>();
});
InstructionSelector selector(mutator_, exprs_, notation_map_, sign_map_);
for (auto iter = exprs_.rbegin(); iter != exprs_.rend(); ++iter) {
*iter = selector.Mutate(*iter);
if (exprs_.size() < 3) {
return expr;
}
if (is_constant(exprs_[exprs_.size() - 2])) {
return RebuildExpr();
}
Expr e = exprs_.front();
Array<Expr> args = ExprArgsFetcher(args_).GetArgs(e);
e = exprs_[exprs_.size() - 3];
CHECK(sign_map_.find(e.get()) != sign_map_.end());
if (sign_map_[e.get()]) {
e = exprs_[exprs_.size() - 2];
}
if (args.size() > ExprArgsFetcher(args_).GetArgs(e).size()) {
expr = RebuildExpr();
}
return expr;
}
Expr Mutate_(const Select *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Select::make(op->condition, ExprOptMutator(mutator_).Mutate(op->true_value),
ExprOptMutator(mutator_).Mutate(op->false_value));
Expr expr = Select::make(op->condition, ExprOptMutator(mutator_, args_).Mutate(op->true_value),
ExprOptMutator(mutator_, args_).Mutate(op->false_value));
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
......@@ -748,53 +881,9 @@ class ExprOptMutator : public IRMutator {
Expr Mutate_(const Sub *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Mul *op, const Expr &e) {
bool is_left_constant = is_constant(op->a);
bool is_right_constant = is_constant(op->b);
if ((is_left_constant && is_left_constant) || (!is_left_constant && !is_right_constant)) {
return AnalyzeBinaryOpExpr(op, e);
}
Expr non_constant_expr = is_left_constant ? op->b : op->a;
Expr constant_expr = is_left_constant ? op->a : op->b;
Expr Mutate_(const Mul *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
if (non_constant_expr->IsInstance<Add>()) {
const Add *add = non_constant_expr.as<Add>();
if (is_constant(add->a) || is_constant(add->b)) {
Expr expr = Add::make(Mul::make(constant_expr, add->a), Mul::make(constant_expr, add->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
}
if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
}
return IRMutator::Mutate(expr);
}
}
if (non_constant_expr->IsInstance<Sub>()) {
const Sub *sub = non_constant_expr.as<Sub>();
if (is_constant(sub->a) || is_constant(sub->b)) {
Expr expr = Sub::make(Mul::make(constant_expr, sub->a), Mul::make(constant_expr, sub->b));
if (notation_map_.find(e.get()) == notation_map_.end()) {
notation_map_[expr.get()] = notation_map_[e.get()];
}
if (sign_map_.find(e.get()) != sign_map_.end()) {
sign_map_[expr.get()] = sign_map_[e.get()];
}
return IRMutator::Mutate(expr);
}
}
return AnalyzeBinaryOpExpr(op, e);
}
// Imm / x -> y = Imm; y/x
Expr Mutate_(const Div *op, const Expr &e) {
if (is_constant(op->a) && !is_constant(op->b)) {
Expr expr = Div::make(mutator_.AllocateTmp(op->a), op->b);
const Div *div = expr.as<Div>();
return AnalyzeBinaryOpExpr(div, expr);
}
return AnalyzeBinaryOpExpr(op, e);
}
Expr Mutate_(const Div *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
Expr Mutate_(const Mod *op, const Expr &e) { return AnalyzeBinaryOpExpr(op, e); }
......@@ -824,31 +913,35 @@ class ExprOptMutator : public IRMutator {
Expr Mutate_(const Let *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Let::make(op->var, ExprOptMutator(mutator_).Mutate(op->value), ExprOptMutator(mutator_).Mutate(op->body));
Expr expr = Let::make(op->var, ExprOptMutator(mutator_, args_).Mutate(op->value),
ExprOptMutator(mutator_, args_).Mutate(op->body));
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
Expr Mutate_(const Cast *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Cast::make(op->type, ExprOptMutator(mutator_).Mutate(op->value));
Expr expr = Cast::make(op->type, ExprOptMutator(mutator_, args_).Mutate(op->value));
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
Expr Mutate_(const Not *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Not::make(ExprOptMutator(mutator_).Mutate(op->a));
Expr expr = Not::make(ExprOptMutator(mutator_, args_).Mutate(op->a));
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
Expr Mutate_(const Load *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Load::make(op->type, op->buffer_var, ExprOptMutator(mutator_).Mutate(op->index),
ExprOptMutator(mutator_).Mutate(op->predicate));
Expr expr = Load::make(op->type, op->buffer_var, ExprOptMutator(mutator_, args_).Mutate(op->index),
ExprOptMutator(mutator_, args_).Mutate(op->predicate));
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
......@@ -856,11 +949,12 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed(e);
Array<Expr> source;
for (Expr src : op->source) {
source.push_back(ExprOptMutator(mutator_).Mutate(src));
source.push_back(ExprOptMutator(mutator_, args_).Mutate(src));
}
Expr expr =
Reduce::make(op->combiner, source, op->axis, ExprOptMutator(mutator_).Mutate(op->condition), op->value_index);
Expr expr = Reduce::make(op->combiner, source, op->axis, ExprOptMutator(mutator_, args_).Mutate(op->condition),
op->value_index);
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
......@@ -868,14 +962,15 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed(e);
Array<Expr> vectors;
for (Expr v : op->vectors) {
vectors.push_back(ExprOptMutator(mutator_).Mutate(v));
vectors.push_back(ExprOptMutator(mutator_, args_).Mutate(v));
}
Array<Expr> indices;
for (Expr indic : op->indices) {
indices.push_back(ExprOptMutator(mutator_).Mutate(indic));
indices.push_back(ExprOptMutator(mutator_, args_).Mutate(indic));
}
Expr expr = Shuffle::make(vectors, indices);
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
......@@ -883,26 +978,29 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed(e);
Array<Expr> args;
for (Expr arg : op->args) {
args.push_back(ExprOptMutator(mutator_).Mutate(arg));
args.push_back(ExprOptMutator(mutator_, args_).Mutate(arg));
}
Expr expr = Call::make(op->type, op->name, args, op->call_type, op->func, op->value_index);
mutator_.AddBroadCastCallIfNeed(op, expr);
exprs_.push_back(expr);
mutator_.AddBroadCastCallIfNeed(op, expr);
UpdateExprStatus(e, expr);
return expr;
}
Expr Mutate_(const Ramp *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr =
Ramp::make(ExprOptMutator(mutator_).Mutate(op->base), ExprOptMutator(mutator_).Mutate(op->stride), op->lanes);
Expr expr = Ramp::make(ExprOptMutator(mutator_, args_).Mutate(op->base),
ExprOptMutator(mutator_, args_).Mutate(op->stride), op->lanes);
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
Expr Mutate_(const Broadcast *op, const Expr &e) {
InitExprStatusIfNeed(e);
Expr expr = Broadcast::make(ExprOptMutator(mutator_).Mutate(op->value), op->lanes);
Expr expr = Broadcast::make(ExprOptMutator(mutator_, args_).Mutate(op->value), op->lanes);
exprs_.push_back(expr);
UpdateExprStatus(e, expr);
return expr;
}
......@@ -927,12 +1025,21 @@ class ExprOptMutator : public IRMutator {
}
}
void UpdateExprStatus(const Expr &before, const Expr &after) {
const Object *b = before.get();
const Object *a = after.get();
CHECK(notation_map_.find(b) != notation_map_.end());
notation_map_[a] = notation_map_[b];
CHECK(sign_map_.find(b) != sign_map_.end());
sign_map_[a] = sign_map_[b];
}
bool IsNewRoot(const Expr &e) {
CHECK(notation_map_.find(e.get()) != notation_map_.end());
std::string root = notation_map_[e.get()];
std::string type_key = e->GetTypeKey();
return !((root == Add::_type_key || root == Sub::_type_key) &&
(type_key == Add::_type_key || type_key == Sub::_type_key)) ||
(type_key == Add::_type_key || type_key == Sub::_type_key)) &&
!((root == Mul::_type_key || root == Div::_type_key) &&
(type_key == Mul::_type_key || type_key == Div::_type_key));
}
......@@ -946,7 +1053,7 @@ class ExprOptMutator : public IRMutator {
std::string type_key = e->GetTypeKey();
Expr expr = e;
if (IsNewRoot(e)) {
expr = T::make(ExprOptMutator(mutator_).Mutate(op->a), ExprOptMutator(mutator_).Mutate(op->b));
expr = T::make(ExprOptMutator(mutator_, args_).Mutate(op->a), ExprOptMutator(mutator_, args_).Mutate(op->b));
notation_map_[expr.get()] = root_of_e;
sign_map_[expr.get()] = pos_of_e;
exprs_.push_back(expr);
......@@ -957,6 +1064,7 @@ class ExprOptMutator : public IRMutator {
sign_map_[op->b.get()] = (type_key == Sub::_type_key || type_key == Div::_type_key) ? !pos_of_e : pos_of_e;
expr = T::make(IRMutator::Mutate(op->a), IRMutator::Mutate(op->b));
}
UpdateExprStatus(e, expr);
return expr;
}
......@@ -968,11 +1076,11 @@ class ExprOptMutator : public IRMutator {
Expr RebuildExpr() {
CHECK(!exprs_.empty());
Expr expr = exprs_.front();
exprs_.pop_front();
Expr expr = exprs_.back();
exprs_.pop_back();
while (!exprs_.empty()) {
expr = RebuildExpr(expr, exprs_.front());
exprs_.pop_front();
expr = RebuildExpr(expr, exprs_.back());
exprs_.pop_back();
}
return expr;
}
......@@ -1004,58 +1112,12 @@ class ExprOptMutator : public IRMutator {
}
ThreeAddressExprMutator &mutator_;
std::list<Expr> exprs_;
Array<Expr> args_;
std::vector<Expr> exprs_;
std::unordered_map<const Object *, std::string> notation_map_;
std::unordered_map<const Object *, bool> sign_map_;
};
class LoopMutator : public IRMutator {
public:
LoopMutator() : loop_level_(0) {}
~LoopMutator() override = default;
Stmt Mutate_(const For *op, const Stmt &s) final {
loop_level_++;
loop_vars_.push_front(op);
Stmt stmt = IRMutator::Mutate(op->body);
if (provides_.size() == 1 || provides_.front()->args.size() == provides_.front()->args.size()) {
return s;
}
if (!stmt->IsInstance<For>()) {
provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); });
const Provide *provide = provides_.back();
stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args);
provides_.pop_back();
}
stmt = RebuildForStmt(loop_vars_.back(), stmt);
loop_vars_.pop_back();
loop_level_--;
return stmt;
}
Stmt Mutate_(const Provide *op, const Stmt &s) final {
if (!loop_vars_.empty()) {
provides_.push_back(op);
}
return IRMutator::Mutate_(op, s);
}
private:
Stmt RebuildForStmt(const For *op, Stmt &body) {
Stmt stmt = body;
while (!provides_.empty() && op->loop_var.same_as(provides_.back()->args[0])) {
const Provide *second = provides_.back();
stmt = Block::make(Provide::make(second->func, second->value_index, second->value, second->args), stmt);
provides_.pop_back();
}
return For::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, stmt);
}
size_t loop_level_;
std::list<const Provide *> provides_;
std::list<const For *> loop_vars_;
};
class InferUpperBound {
private:
class Bound {
......@@ -1384,12 +1446,16 @@ class ThreeAddressStmtMutator : public IRMutator {
args_ = args;
static_cast<void>(this->Mutate(op->value));
// mutate according to the result of instruction selection
ThreeAddressExprMutator mutator(output, args, shape, broadcast_, is_reduction, cross_stmt_simplify_);
ThreeAddressExprMutator mutator(output, args, op->args, shape, broadcast_, is_reduction, cross_stmt_simplify_,
is_simple_);
if (cross_stmt_simplify_) {
// Bring over the common exprs from previous stage
mutator.SetCommonExpr(global_common_expr_);
}
value = ExprOptMutator(mutator).Mutate(value);
if (is_simple_) {
value = ExprOptMutator(mutator, args_).Mutate(value);
}
value = InstructionMutator(mutator, args_).Mutate(value);
value = mutator.Mutate(value);
if (cross_stmt_simplify_) {
// Take back the common exprs for next stages
......@@ -1489,11 +1555,46 @@ class ThreeAddressStmtMutator : public IRMutator {
}
Stmt Mutate_(const For *op, const Stmt &s) final {
if (loop_level == 0) {
is_simple_ = IsSimpleFor(op);
}
loop_level++;
dom_map[op->loop_var] = Range::make_by_min_extent(op->min, op->extent);
return IRMutator::Mutate_(op, s);
Stmt stmt = IRMutator::Mutate_(op, s);
loop_level--;
if (loop_level == 0) {
is_simple_ = true;
}
return stmt;
}
static bool IsSimpleFor(const For *op) {
if (const For *sub_for = op->body.as<For>()) {
return IsSimpleFor(sub_for);
}
if (const Block *block = op->body.as<Block>()) {
return IsSimpleBlock(block);
}
return op->body->IsInstance<Provide>();
}
private:
static bool IsSimpleBlock(const Block *op) {
if (op->first->IsInstance<Provide>() && op->rest->IsInstance<Provide>()) {
return true;
}
if (op->first->IsInstance<Block>() && op->rest->IsInstance<Block>()) {
return IsSimpleBlock(op->first.as<Block>()) && IsSimpleBlock(op->rest.as<Block>());
}
if (op->first->IsInstance<Provide>() && op->rest->IsInstance<Block>()) {
return IsSimpleBlock(op->rest.as<Block>());
}
if (op->first->IsInstance<Block>() && op->rest->IsInstance<Provide>()) {
return IsSimpleBlock(op->first.as<Block>());
}
return false;
}
std::unordered_map<Tensor, std::vector<Tensor>> split_to_;
std::unordered_map<FunctionRef, std::set<int>, air::NodeHash, air::NodeEqual> op_indices_;
......@@ -1504,6 +1605,9 @@ class ThreeAddressStmtMutator : public IRMutator {
std::unordered_map<size_t, std::pair<Expr, Expr>> global_common_expr_;
int loop_level{0};
bool is_simple_{true};
// mark broadcast
Tensor output_;
Array<Expr> args_;
......@@ -1513,8 +1617,83 @@ class ThreeAddressStmtMutator : public IRMutator {
bool cross_stmt_simplify_;
};
class LoopMutator : public IRMutator {
public:
Stmt Mutate_(const For *op, const Stmt &s) final {
if (loop_vars_.empty() && !ThreeAddressStmtMutator::IsSimpleFor(op)) {
return s;
}
loop_vars_.push_back(op);
Stmt stmt = IRMutator::Mutate(op->body);
if (!provides_.empty()) {
provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); });
while (!provides_.empty()) {
SplitProides();
}
}
for (size_t index = 0; index < stmts_.size(); ++index) {
if (IsContain(args_[index], loop_vars_.back()->loop_var)) {
stmts_[index] = For::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, stmts_[index]);
}
}
loop_vars_.pop_back();
if (loop_vars_.empty()) {
stmt = stmts_.back();
for (auto iter = ++stmts_.rbegin(); iter != stmts_.rend(); iter++) {
stmt = Block::make(*iter, stmt);
}
stmts_.clear();
args_.clear();
}
return stmt;
}
Stmt Mutate_(const Provide *op, const Stmt &s) final {
if (!loop_vars_.empty()) {
provides_.push_back(op);
}
return IRMutator::Mutate_(op, s);
}
private:
void SplitProides() {
const Provide *provide = provides_.back();
Stmt stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args);
provides_.pop_back();
while (!provides_.empty()) {
const Provide *next = provides_.back();
if (provide->args.size() != next->args.size()) {
break;
}
stmt = Block::make(Provide::make(next->func, next->value_index, next->value, next->args), stmt);
provides_.pop_back();
}
stmts_.insert(stmts_.begin(), stmt);
args_.insert(args_.begin(), provide->args);
}
bool IsContain(const Array<Expr> &args, const Var &var) {
VarSet all_vars;
for (Expr e : args) {
GatherVars(e, &all_vars);
}
for (auto v : all_vars) {
if (v.same_as(var)) {
return true;
}
}
return false;
}
std::list<const For *> loop_vars_{};
std::list<const Provide *> provides_{};
std::vector<Stmt> stmts_{};
std::vector<Array<Expr>> args_{};
};
Stmt ToThreeAddress(Stmt stmt, bool reuse_variable, int minimum_split, bool cross_stmt_simplify) {
stmt = ThreeAddressStmtMutator(reuse_variable, minimum_split, cross_stmt_simplify).Mutate(stmt);
stmt = LoopMutator().Mutate(stmt);
return Simplify_cce(stmt);
}
} // namespace ir
......
......@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper th({16, 32, 1024});
using Add = air::ir::Add;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
air::Expr expr =
Add::make(
Add::make(
Add::make(th.Elem("a", 2), th.Elem("b", 1)),
th.Elem("c", 3)),
th.Elem("d", 1));
air::Expr expr = Add::make(Add::make(Add::make(th.Elem("a", 2), th.Elem("b", 1)), th.Elem("c", 3)), th.Elem("d", 1));
std::string dump_expr = UTDumpHelper::Dump(expr);
EXPECT_EQ(dump_expr, "(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))");
}
......@@ -49,12 +44,12 @@ TEST(ToThreeAddressTest, BuildCase1) {
class ThreeAddressExprMutatorTest : public testing::Test {
public:
ThreeAddressExprMutatorTest()
: mutator_(air::TensorNode::make(
UTExprBuilder::CreateShape(shape_), // shape
: mutator_(air::TensorNode::make(UTExprBuilder::CreateShape(shape_), // shape
dtype_, // dtype
UTExprBuilder::PlaceholderOpNode("out", shape_), // op
0), // index
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateShape(shape_), // shape
std::unordered_set<const Call *>(), // broadcast
false, // IsReductionOp
......@@ -76,9 +71,7 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
class PassTestToThreeAddress1 : public ::testing::Test {
public:
PassTestToThreeAddress1() {
Construct();
}
PassTestToThreeAddress1() { Construct(); }
~PassTestToThreeAddress1() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
......@@ -88,17 +81,15 @@ class PassTestToThreeAddress1 : public ::testing::Test {
stmt = air::ir::AttrStmt::make(
out_, "", UTExprBuilder::IntImm(1),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
out_, air::ir::ProducerConsumer::make(
out_, true,
UTStmtBuilder::CreateFor(
"i", 0, 32,
UTStmtBuilder::CreateFor(
"j", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i", "j"},
air::ir::Add::make(
UTExprBuilder::ElementOf(a_, {"j"}),
UTExprBuilder::ElementOf(b_, {"i", "j"})),
air::ir::Add::make(UTExprBuilder::ElementOf(a_, {"j"}), UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(c_, {"j"})))))));
}
......@@ -110,7 +101,7 @@ class PassTestToThreeAddress1 : public ::testing::Test {
}; // class PassTestToThreeAddress1
TEST_F(PassTestToThreeAddress1, CaseCheck) {
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)");
......@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
*/
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info1 =
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info1 =
UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)");
ASSERT_EQ(info1.size(), 1);
std::string dump_b_target = std::get<0>(info1[0]);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info2 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info2 =
UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
ASSERT_EQ(info2.size(), 1);
std::string dump_sum1_target = std::get<0>(info2[0]);
EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info3 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info3 =
UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
ASSERT_EQ(info3.size(), 1);
EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册