提交 4d1be48d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!53 support dynamical memory allocation ratio adjustment in micro-tuning for...

!53 support dynamical memory allocation ratio adjustment in micro-tuning for allocation exceed problem
Merge pull request !53 from yangsijia/feature/micro-tuning
......@@ -116,8 +116,8 @@ def four2five_tiling_strategy_dynamic(tensor, input_format):
strategy.append(ct_util.create_constraint_on_tensor(tensor, 16, ct_util.TileConstraint.FACTOR, 4)[0])
return strategy
@vc_util.check_input_type(akg.tvm.tensor.Tensor, str, str)
def four2five(data, format_, dst_dtype='float16'):
@vc_util.check_input_type(akg.tvm.tensor.Tensor, str, str, bool)
def four2five(data, format_, dst_dtype='float16', need_custom_tiling=True):
"""
Convert 4-dims "data" to 5-dims,the format of "data" is defined in "format_"
......@@ -294,8 +294,9 @@ def four2five(data, format_, dst_dtype='float16'):
dim_info, _ = four2five_set_dim_func(data, format_, dst_dtype)
if dim_info != "":
attrs["dim"] = dim_info
if need_custom_tiling:
attrs["custom_tiling"] = four2five_tiling_strategy(output, format_, expansion)
else:
elif need_custom_tiling:
attrs["custom_tiling"] = four2five_tiling_strategy_dynamic(output, format_)
if is_dynamic:
......
......@@ -458,7 +458,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
PassTimer *pass_timer = PassTimer::GetInstance();
global_attrs.Set(kKernelName, StringImm::make(name));
global_attrs.Set(kDumpPassIr, ktvm::make_const(Int(1), config->dump_pass_ir));
global_attrs.Set(kDumpPassIr, ktvm::make_const(Int(32), config->dump_pass_ir));
if (config->dump_pass_ir) {
std::string dump_ir_dir;
if (global_attrs.GetStringAttr(kDumpIrDir, &dump_ir_dir)) {
......@@ -498,7 +498,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(RenameRealize, stmt, binds_0, replace);
bool is_dynamic = !shape_vars.empty();
global_attrs.Set(kIsDynamic, ktvm::make_const(Int(1), is_dynamic));
global_attrs.Set(kIsDynamic, ktvm::make_const(Int(32), is_dynamic));
Array<NodeRef> arg_list_1;
Map<Tensor, Buffer> binds_1;
......@@ -594,7 +594,17 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
NodeRef tuning_spaces = NEXT_PASS(GenTuningSpace, stmt, binds_0, attrs_1, false);
return tuning_spaces;
}
Array<NodeRef> poly_res = NEXT_PASS(AutoPoly, stmt, binds_0, global_attrs, false, is_dynamic);
}
// micro-tuning configs: current strategy is to retry autopoly up to 3 times when storage flatten/rewrite fails
bool need_micro_tuning = !aicpu && polyhedral && !is_dynamic && global_attrs.GetStringAttr("dim", "").empty();
const int max_enter_poly_times = global_attrs.GetIntAttr(kMaxNumRetryPoly, need_micro_tuning ? 4 : 1);
int enter_count = 0;
Stmt stmt_before_poly = stmt;
while (enter_count < max_enter_poly_times) {
if (!aicpu && polyhedral) {
Array<NodeRef> poly_res = NEXT_PASS(AutoPoly, stmt_before_poly, binds_0, global_attrs, false, is_dynamic);
enter_count++;
CHECK_EQ(poly_res.size(), 2);
stmt = ktvm::Downcast<Stmt>(poly_res[0]);
Array<ktvm::Var> tiling_params = ktvm::Downcast<Array<ktvm::Var>>(poly_res[1]);
......@@ -665,8 +675,15 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(ConvertIfToSelect, stmt);
}
}
try {
stmt = NEXT_PASS(StorageFlatten, stmt, binds_0, 64);
} catch (const std::runtime_error &e) {
if (enter_count >= max_enter_poly_times) {
CHECK(false) << e.what();
}
global_attrs.Set(kErrorInfo, StringImm::make(e.what()));
continue;
}
stmt = NEXT_PASS(DmaFlatten, stmt, global_attrs.GetBoolAttr(kTileSizeIsVar, false));
if (global_attrs.GetBoolAttr(kAlgebraSimplify, false) && is_dynamic) {
stmt = NEXT_PASS(AlgebraSimplify, stmt);
......@@ -814,7 +831,18 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
bool bc_no_limits = false;
// timeout for MaxSAT solver in seconds (int)
int maxsat_timeout = 4;
try {
stmt = NEXT_PASS(StorageRewriteCCE, stmt, maxsat_filename, use_bc_opt, bc_no_limits, maxsat_timeout);
} catch (MemoryAllocationException &e) {
if (enter_count >= max_enter_poly_times) {
CHECK(false) << e.what();
}
global_attrs.Set(kAllocBits, ktvm::make_const(Int(32), e.alloc_bits_ + e.need_bits_));
global_attrs.Set(kErrorScope, StringImm::make(e.scope_));
continue;
}
break;
}
if (!is_dynamic)
stmt = NEXT_PASS(UnrollLoop, stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
......
......@@ -98,7 +98,13 @@ int AttrMap::GetIntAttr(const std::string &attr_name, int dft_value) {
const NodeRef &e = this->at(attr_name);
return ir::GetInt32Const(Downcast<Expr>(e));
}
double AttrMap::GetFloatAttr(const std::string &attr_name, double dft_value) {
if (this->count(attr_name) == 0) {
return dft_value;
}
const NodeRef &e = this->at(attr_name);
return ir::GetFloatConst(Downcast<Expr>(e));
}
bool AttrMap::GetBoolAttr(const std::string &attr_name, bool dft_value) {
int result = GetIntAttr(attr_name, static_cast<int>(dft_value));
CHECK(result == 0 || result == 1) << "Bool attribute " << attr_name << " must be 0 or 1, but found "
......
......@@ -91,6 +91,11 @@ constexpr auto kEnableRemoveBroadcastCopy = "enable_remove_broadcast_copy";
constexpr auto kEnableSubstituteDivVar = "enable_divide_var";
constexpr auto kEnableComputeInPlace = "enable_compute_in_place";
constexpr auto kEnableRewriteScalarCompute = "enable_rewrite_scalar_compute";
constexpr auto kMaxNumRetryPoly = "max_num_retry_poly";
constexpr auto kUBRatio = "ub_ratio";
constexpr auto kErrorInfo = "";
constexpr auto kErrorScope = "";
constexpr auto kAllocBits = "alloc_bits";
static std::unordered_map<std::string, int> help_tiling_level = {
{"None", 0},
......@@ -109,7 +114,7 @@ class AttrMap : public Map<std::string, NodeRef> {
bool GetBoolAttr(const std::string &attr_name, bool dft_value);
int GetIntAttr(const std::string &attr_name, int dft_value);
double GetFloatAttr(const std::string &attr_name, double dft_value);
bool GetStringAttr(const std::string &attr_name, std::string *attr_to_set);
std::string GetStringAttr(const std::string &attr_name, const std::string &dft_value);
};
......
......@@ -18,11 +18,33 @@
#define INCLUDE_AKG_BUILD_MODULE_H_
#include <string>
#include <exception>
#include "codegen/util.h"
namespace akg {
extern AttrMap global_attrs;
/*
* Custom exception used when memory allocation fails and triggers micro-tuning to try to recover from failure.
*/
class MemoryAllocationException : public std::exception {
public:
MemoryAllocationException(const std::string &scope, uint64_t need_bits, uint64_t alloc_bits)
: scope_(scope), need_bits_(need_bits), alloc_bits_(alloc_bits){};
const char *what() const throw() {
std::runtime_error re(("Allocation exceed bound of memory tag " + scope_ + ": need " + std::to_string(need_bits_) +
" bits, total alloc " + std::to_string(alloc_bits_) + " bits.")
.c_str());
return re.what();
}
std::string scope_{""};
uint64_t need_bits_{0};
uint64_t alloc_bits_{0};
};
NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name,
const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode,
bool polyhedral, bool tuning, bool aicpu, const BuildConfig &config);
......
......@@ -26,6 +26,7 @@
#include <regex>
#include "ir_pass.h"
#include "build_module.h"
#include "pass/ir_util.h"
#include "emit_insn/insn_info.h"
#include "pass/storage_rewrite_cce.h"
......@@ -1146,8 +1147,7 @@ bool StoragePlanRewriterCCE::DoRewrite(const std::string scope, std::vector<std:
}
if (spec_level <= 0 || child_idx < 0) {
if (!is_dynamic_) {
LOG(FATAL) << "Allocation exceed bound of memory tag " << scope << ": need " << need_nbits
<< " bits, total alloc " << total_alloc_bits << " bits";
throw MemoryAllocationException(scope, need_nbits, total_alloc_bits);
} else {
LOG(WARNING) << "Dynamic shape static allocation exceed bound of memory tag " << scope << ": need "
<< need_nbits << " bits, will use dynamic allocation instead";
......
......@@ -16,11 +16,63 @@
*/
#include "poly/tiling_solver.h"
#include "build_module.h"
namespace akg {
namespace ir {
namespace poly {
/*
* This function parse StorageFlatten error info into a ratio that guides the auto tiling to reduce
* memory allocation.
* e.g.
* error info : Check failed: const_size * op->type.bits() <= info->max_num_bits (5242880 vs. 2097152) :
* Allocation exceed bound of memory tag local.UB.
* ratio : memory_size / alloc_size = (2097152 / 5242880) = 0.4, which means the total allocation
* size used in auto tiling shoulde reduce 0.4 times.
*/
double TilingSolver::GetNewAllocRatioWhenFlattenFail(const std::string &error_info) {
std::vector<std::string> sub_strs;
sub_strs = akg::common::Split(error_info, "(");
CHECK_GE(sub_strs.size(), 2U);
std::string tmp_str = sub_strs[2];
sub_strs = akg::common::Split(tmp_str, " ");
CHECK(!sub_strs.empty());
auto alloc_bits = static_cast<double>(std::strtod(sub_strs[0].c_str(), nullptr));
sub_strs = akg::common::Split(error_info, ")");
CHECK_GE(sub_strs.size(), 1U);
tmp_str = sub_strs[1];
sub_strs = akg::common::Split(tmp_str, " ");
CHECK(!sub_strs.empty());
auto memory_bits = static_cast<double>(std::strtod(sub_strs.back().c_str(), nullptr));
CHECK_NE(alloc_bits, 0);
return memory_bits / alloc_bits;
}
/*
* This function returns an adjust ratio that further reduces the memory allocation limit apart from
* the default percentage reserved for auto double buffer and try to generate smaller tile sizes that
* helps to recover from memory allocation failure such as the one in storage rewrite cce pass.
*/
double TilingSolver::GetNewAllocRatioWhenRewriteFail(int64_t memory_bits) {
auto actual_allocs = global_attrs.GetFloatAttr(kAllocBits, 0.0);
auto last_adjust_ratio = global_attrs.GetFloatAttr(kUBRatio, 1.0);
auto adjust_ratio = 1.0;
if (actual_allocs != 0) {
std::stringstream ss;
auto expect_allocs = memory_bits * last_adjust_ratio;
adjust_ratio = (expect_allocs / actual_allocs);
ss << "Adjust memory allocation ratio to " << adjust_ratio << " times and retry tiling.";
global_attrs.Set(kUBRatio, ktvm::make_const(Float(32), adjust_ratio));
analyzer_.logger_.AppendLog(MICRO_TUNING, ss);
}
return adjust_ratio;
}
void TilingSolver::CollectMemoryLimit() {
// Init memory allocation percentage.
percentage_ = ALLOCATION_PERCENTAGE;
for (auto attr : analyzer_.RootAxis()->attrs) {
if (attr.attr_key != "MEM_RATIO") continue;
......@@ -29,9 +81,27 @@ void TilingSolver::CollectMemoryLimit() {
break;
}
// Handle previous error info if storage flatten fails and adjust allocation percentage.
auto error_info = global_attrs.GetStringAttr(kErrorInfo, "");
if (!error_info.empty() && error_info.find("storage_flatten") != std::string::npos) {
std::stringstream ss;
ss << "Get Error Info! -> " << global_attrs.GetStringAttr(kErrorInfo, "");
percentage_ = percentage_ * GetNewAllocRatioWhenFlattenFail(error_info);
ss << "Adjust memory allocation to " << percentage_ << " of memory size and retry tiling.";
global_attrs.Set(kErrorInfo, StringImm::make(""));
analyzer_.logger_.AppendLog(MICRO_TUNING, ss);
}
// Init memory limit for each scope and reduce ratio of local.UB if storage rewrite fails previously.
DavinciInfo &d_info = DavinciInfo::GetInstance();
auto error_scope = global_attrs.GetStringAttr(kErrorScope, "");
for (auto i = 0; i < MEM_SCOPE_BULK; ++i) {
this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage_;
if (i == DavinciMemScope::MEM_SCOPE_UB && error_scope == "local.UB") {
this->mem_limit_[i] =
std::max(static_cast<int>(this->mem_limit_[i] * GetNewAllocRatioWhenRewriteFail(this->mem_limit_[i])), 1);
global_attrs.Set(kErrorScope, StringImm::make(""));
}
}
}
......
......@@ -30,6 +30,8 @@ class TilingSolver {
~TilingSolver() {}
void CollectMemoryLimit();
void CollectTileAxisTopDown();
double GetNewAllocRatioWhenFlattenFail(const std::string &error_info);
double GetNewAllocRatioWhenRewriteFail(int64_t memory_bits);
TileCandidate *Solve();
TilingAnalyzer &analyzer_;
......
......@@ -29,6 +29,8 @@ void TileLogger::AppendLine(LogStage stage, const std::string &line) {
analyze_tiling_space_stage_.emplace_back(line);
} else if (stage == DO_TILING) {
do_tiling_stage_.emplace_back(line);
} else if (stage == MICRO_TUNING) {
micro_tuning_strage_.emplace_back(line);
} else {
do_tuning_stage_.emplace_back(line);
}
......@@ -70,6 +72,11 @@ bool TileLogger::DumpLogFile() {
of << line << std::endl;
}
of << "=========================" << std::endl;
of << ">>>>>>>>>> Micro tuning stage <<<<<<<<<<<<" << std::endl;
for (const auto &line : micro_tuning_strage_) {
of << line << std::endl;
}
of << "=========================" << std::endl;
of.close();
return true;
}
......
......@@ -32,7 +32,7 @@ enum DavinciMemScope {
MEM_SCOPE_L0C,
MEM_SCOPE_BULK,
};
enum LogStage { ANA_SCHETREE, ANA_BUF_LIVE_EXTENT, ANA_TILING_SPACE, DO_TILING, DO_TUNING };
enum LogStage { ANA_SCHETREE, ANA_BUF_LIVE_EXTENT, ANA_TILING_SPACE, DO_TILING, DO_TUNING, MICRO_TUNING };
class DavinciInfo {
public:
......@@ -89,6 +89,7 @@ class TileLogger {
LogFile analyze_tiling_space_stage_;
LogFile do_tiling_stage_;
LogFile do_tuning_stage_;
LogFile micro_tuning_strage_;
};
} // namespace poly
} // namespace ir
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""unittest for micro-tuning"""
from akg.utils import kernel_exec
from akg.ops.array import four2five
def test_four2five_without_custom_tiling(build_shape, dtype, op_attrs):
"""This test case will fail without cunstom tiling and micro-tuning will automatically adjust tile sizes."""
build_attr = op_attrs + [False]
return kernel_exec.op_build_test(four2five.four2five, [build_shape], [dtype], build_attr, kernel_name="four2five", attrs={}, tuning=False)
if __name__ == "__main__":
test_four2five_without_custom_tiling(
[32, 1001, 1, 1], "float16", ['NCHW', 'float16'])
......@@ -22,6 +22,7 @@ casefiles=(
"pass/test_promote_if.py"
"pass/test_sink_if.py"
"pass/test_ir_parser.py"
"pass/test_micro_tuning.py"
"pass/test_elim_vector_mask.py"
"pass/test_copy_propagation.py"
"pass/test_utils_detect_non_linear_index.py"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册