提交 54141f8a 编写于 作者: L lyfne

Improve EquivFormat op

上级 6a84977e
......@@ -651,4 +651,21 @@ TensorAdd12.InplaceAssign1516.InplaceAssign11220.1721': {
},
},
},
# Mul_TensorAdd
'2.EquivFormat.Mul13.EquivFormat3.TensorAdd12.5': {
'16_16_32_32_16_16.16_1_1_512.16_16_32_32_16_16': {
'float16--': {
'metadata': {
'attrs': {
'enable_mark_multi_core': True,
'multicore_loop_switch_hoist': False,
'multicore_scalar_rearrange': True,
'enable_post_poly_loop_partition': False,
},
},
'dim': '0 0 1 1 0 1 2 1 0 2 16 1 0 3 1 1 0 4 32 1 0 5 16 1'
},
},
},
}
......@@ -787,9 +787,9 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
// must be after EmitInsn
stmt = NEXT_PASS(TileCoverCorrect, stmt);
if (global_attrs.GetBoolAttr(kEnableCoverProtectOptimize, true) && !is_dynamic) {
// simulated blocks > 2 400 000 => simulated case takes too much time (> 100 sec)
// number of protections > 512 => too many brackets in the if statement throw an error
stmt = NEXT_PASS(CoverProtection, stmt, 2400000, 512);
// simulated blocks > 240 000 => simulated case takes too much time (> 10 sec)
// number of protections > 128 => too many brackets in the if statement throw an error
stmt = NEXT_PASS(CoverProtection, stmt, 240000, 128);
}
stmt = NEXT_PASS(ConvertDivModToShift, stmt);
if (!polyhedral || global_attrs.GetBoolAttr(kCoarsenImg2Col, false)) {
......
......@@ -636,9 +636,22 @@ TVM_REGISTER_GLOBAL("InplaceAssign").set_body([](TVMArgs args, TVMRetValue *rv)
TVM_REGISTER_GLOBAL("EquivFormat").set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_GE(args.size(), 1);
auto inputs = args[0].operator Array<NodeRef>();
CHECK(inputs[0]->IsInstance<TensorNode>());
auto ref = Downcast<Tensor>(inputs[0]);
*rv = ref;
if (inputs[0]->IsInstance<TensorNode>()) {
auto ref = [](NodeRef attr) -> Array<Expr> {
auto shape = Downcast<Array<Integer>>(attr);
CHECK(!shape.empty());
Array<Expr> newshape;
for (auto s : shape) {
newshape.push_back(s);
}
return newshape;
};
TOPI_ONE_INPUT_ONE_ATTR_CALL(args, rv, topi::reshape, ref);
} else {
Array<Expr> shape = {Expr(1)};
*rv = compute(shape, [&](const Array<Var> &indices) { return Downcast<Expr>(inputs[0]); });
}
});
TVM_REGISTER_GLOBAL("AddMinValue").set_body([](TVMArgs args, TVMRetValue *rv) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册