diff --git a/python/akg/composite/repository.py b/python/akg/composite/repository.py index 4785b332dd748aa7c9ae50c1b8dd22e34d2e3de0..26613873f59a57c7a5321522669df2e665e86b06 100644 --- a/python/akg/composite/repository.py +++ b/python/akg/composite/repository.py @@ -651,4 +651,21 @@ TensorAdd12.InplaceAssign1516.InplaceAssign11220.1721': { }, }, }, + + # Mul_TensorAdd + '2.EquivFormat.Mul13.EquivFormat3.TensorAdd12.5': { + '16_16_32_32_16_16.16_1_1_512.16_16_32_32_16_16': { + 'float16--': { + 'metadata': { + 'attrs': { + 'enable_mark_multi_core': True, + 'multicore_loop_switch_hoist': False, + 'multicore_scalar_rearrange': True, + 'enable_post_poly_loop_partition': False, + }, + }, + 'dim': '0 0 1 1 0 1 2 1 0 2 16 1 0 3 1 1 0 4 32 1 0 5 16 1' + }, + }, + }, } diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 15030ab426a94c898c2b3ffcbd4920cc7309e624..6a5bbbafcf158aff2050461e3f01ac3e71110e1c 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -787,9 +787,9 @@ NodeRef Lower(Schedule sch, const Array &in_args, const Array // must be after EmitInsn stmt = NEXT_PASS(TileCoverCorrect, stmt); if (global_attrs.GetBoolAttr(kEnableCoverProtectOptimize, true) && !is_dynamic) { - // simulated blocks > 2 400 000 => simulated case takes too much time (> 100 sec) - // number of protections > 512 => too many brackets in the if statement throw an error - stmt = NEXT_PASS(CoverProtection, stmt, 2400000, 512); + // simulated blocks > 240 000 => simulated case takes too much time (> 10 sec) + // number of protections > 128 => too many brackets in the if statement throw an error + stmt = NEXT_PASS(CoverProtection, stmt, 240000, 128); } stmt = NEXT_PASS(ConvertDivModToShift, stmt); if (!polyhedral || global_attrs.GetBoolAttr(kCoarsenImg2Col, false)) { diff --git a/src/composite/composite_topi.cc b/src/composite/composite_topi.cc index 84fa808fd34f836afda58c0abfffdc2d0f582923..02be146009c93a39b1437919e77b0209224af1e3 100644 --- a/src/composite/composite_topi.cc +++ b/src/composite/composite_topi.cc @@ -636,9 +636,22 @@ TVM_REGISTER_GLOBAL("InplaceAssign").set_body([](TVMArgs args, TVMRetValue *rv) TVM_REGISTER_GLOBAL("EquivFormat").set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_GE(args.size(), 1); auto inputs = args[0].operator Array(); - CHECK(inputs[0]->IsInstance()); - auto ref = Downcast(inputs[0]); - *rv = ref; + if (inputs[0]->IsInstance()) { + auto ref = [](NodeRef attr) -> Array { + auto shape = Downcast>(attr); + CHECK(!shape.empty()); + Array newshape; + for (auto s : shape) { + newshape.push_back(s); + } + return newshape; + }; + + TOPI_ONE_INPUT_ONE_ATTR_CALL(args, rv, topi::reshape, ref); + } else { + Array shape = {Expr(1)}; + *rv = compute(shape, [&](const Array &indices) { return Downcast(inputs[0]); }); + } }); TVM_REGISTER_GLOBAL("AddMinValue").set_body([](TVMArgs args, TVMRetValue *rv) {