提交 f9d521e9 编写于 作者: D dabaiji

refactor build module to support gpu

上级 6a84977e
......@@ -80,9 +80,9 @@ def build_config(**kwargs):
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple), str,
(dict, type(None)), (dict, type(None)), bool, bool, bool, bool)
(dict, type(None)), (dict, type(None)), bool, bool, bool, str)
def lower(sch, args, shape_params=None, name="default_function", binds=None, attrs=None,
simple_mode=False, polyhedral=False, tuning=False, aicpu=False):
simple_mode=False, polyhedral=False, tuning=False, target="cce"):
"""Lowering function."""
tmp_binds = None
if binds is not None:
......@@ -96,7 +96,7 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att
cfg = _api_internal._GetCurrentBuildConfig()
ret = _api_internal._Lower(sch, args, shape_params, name,
tmp_binds, tmp_attrs, simple_mode,
polyhedral, tuning, aicpu, cfg)
polyhedral, tuning, target, cfg)
level = tmp_attrs.get("help_tiling")
if tuning or (level is not None and level > help_tiling_level['None']):
......@@ -116,9 +116,9 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple, type(None)), str,
(dict, type(None)), (dict, type(None)), bool, bool)
(dict, type(None)), (dict, type(None)), bool, str)
def build_to_func(inputs, args, shape_params=None, name="default_function",
binds=None, attrs=None, polyhedral=False, aicpu=False):
binds=None, attrs=None, polyhedral=False, target="cce"):
"""Build module."""
tmp_binds = None
if binds is not None:
......@@ -132,14 +132,13 @@ def build_to_func(inputs, args, shape_params=None, name="default_function",
shape_params = []
cfg = _api_internal._GetCurrentBuildConfig()
return _api_internal._BuildToFunc(inputs, args, shape_params, name, tmp_binds, tmp_attrs,
polyhedral, aicpu, cfg)
polyhedral, target, cfg)
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (str, type(None)), (list, tuple), str,
(dict, type(None)), (dict, type(None)), bool, bool)
def build(inputs, args, target=None, shape_params=None, name="default_function",
binds=None, attrs=None, polyhedral=False, aicpu=False):
@vc_util.check_input_type(schedule.Schedule, (list, tuple), str, (list, tuple), str,
(dict, type(None)), (dict, type(None)), bool)
def build(inputs, args, target='cce', shape_params=None, name="default_function",
binds=None, attrs=None, polyhedral=False):
tmp_rst = build_to_func(inputs, args, shape_params=shape_params, name=name, binds=binds,
attrs=attrs, polyhedral=polyhedral, aicpu=aicpu)
attrs=attrs, polyhedral=polyhedral, target=target)
tmp_target = target if target is not None else 'cce'
return _api_internal._BuildToModule(tmp_rst, tmp_target)
return _api_internal._BuildToModule(tmp_rst, target)
......@@ -42,7 +42,6 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na
logging.error("Device %s is not in [aicore, aicpu].", device)
return None
aicpu = device == "aicpu"
polyhedral = True
dump_ir = os.getenv(MS_AKG_DUMP_IR) == "on"
......@@ -57,9 +56,9 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na
if attrs:
binds = attrs.pop(BINDS, None)
rst = akg.build_to_func(s, args, name=kernel_name, attrs=attrs, polyhedral=polyhedral,
binds=binds, aicpu=aicpu)
binds=binds, target=device)
else:
rst = akg.build_to_func(s, args, name=kernel_name, polyhedral=polyhedral, aicpu=aicpu)
rst = akg.build_to_func(s, args, name=kernel_name, polyhedral=polyhedral, target=device)
except Exception:
logging.error(traceback.format_exc())
......
......@@ -724,13 +724,14 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
if TensorUtils.is_output_value(output):
op_var = op_var + [output]
if sch_tmpl != None:
assert(sch_tmpl['target'] == 'cuda')
if sch_tmpl is not None:
if sch_tmpl['target'] != 'cuda':
raise ValueError("Only support cuda as target when using schedule template.")
kernel_name = kernel_name if kernel_name != "" else sch_tmpl['op_name']
with akg.tvm.target.cuda() as target:
s = sch_tmpl['schedule'](sch_tmpl['output'])
with akg.tvm.build_config(dump_pass_ir = True):
mod = akg.tvm.build(s, op_var, target, target_host = 'stackvm', name = kernel_name)
with akg.build_config(dump_pass_ir=True):
mod = akg.build(s, op_var, "cuda", shape_var, name=kernel_name, attrs=attrs, polyhedral=polyhedral, binds=binds)
dump_cuda_meta.dump(mod, kernel_name, s, op_var)
return mod
......
......@@ -436,7 +436,7 @@ void FixParametricBinds(const Map<Tensor, Buffer> &binds, const Array<NodeRef> &
NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name,
const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode,
bool polyhedral, bool tuning, bool aicpu, const BuildConfig &config) {
bool polyhedral, bool tuning, const std::string &target, const BuildConfig &config) {
ir::TestExprCompuationSimplify();
CHECK(sch.defined()) << "sch is not defined.";
CHECK(!name.empty()) << "name is empty.";
......@@ -486,6 +486,41 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
auto new_sch = sch.normalize();
auto bounds = air::schedule::InferBound(new_sch);
Stmt stmt = make_pass("schedule.ScheduleOps", new_sch, bounds, false);
if (target == "cuda") {
// Phase 1
stmt = NEXT_PASS(RewriteForTensorCore, stmt, new_sch, binds_0);
stmt = NEXT_PASS(StorageFlatten, stmt, binds_0, 64, config->instrument_bound_checkers);
stmt = NEXT_PASS(CanonicalSimplify, stmt);
// Phase 2
if (!simple_mode) {
stmt = NEXT_PASS(LoopPartition, stmt, config->partition_const_loop);
}
if (config->disable_vectorize) {
stmt = NEXT_PASS(SkipVectorize, stmt);
} else {
stmt = NEXT_PASS(VectorizeLoop, stmt);
}
stmt = NEXT_PASS(InjectVirtualThread, stmt);
stmt = NEXT_PASS(InjectDoubleBuffer, stmt, config->double_buffer_split_loop);
stmt = NEXT_PASS(StorageRewrite, stmt);
stmt = NEXT_PASS(UnrollLoop, stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config->auto_unroll_max_extent, config->unroll_explicit);
// Phase 3
stmt = NEXT_PASS(Simplify, stmt);
stmt = NEXT_PASS(RemoveNoOp, stmt);
if (config->instrument_bound_checkers) {
stmt = NEXT_PASS(InstrumentBoundCheckers, stmt);
}
if (simple_mode) {
return stmt;
}
LoweredFunc lowered_func = NEXT_PASS(MakeAPI, stmt, name, arg_list_0, 0, config->restricted_func);
return lowered_func;
}
if (!polyhedral) {
// for conv-matmul manual schedule
stmt = NEXT_PASS(AutoMadPragmaAttr, stmt, true);
......@@ -518,7 +553,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
PassMgr::SetArgs(arg_list_0);
if (!aicpu) {
if (target != "aicpu") {
stmt = NEXT_PASS(MathIntrinRewrite, stmt);
}
......@@ -527,7 +562,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
}
// Phase 1
if (!aicpu && polyhedral) {
if (target != "aicpu" && polyhedral) {
stmt = NEXT_PASS(UnifyLoopVars, stmt, binds_0, arg_list_0);
stmt = NEXT_PASS(CheckShapeParams, stmt, binds_0);
stmt = NEXT_PASS(AlignPartitionCCE, stmt);
......@@ -597,12 +632,13 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
}
// micro-tuning configs: current strategy is to retry autopoly up to 3 times when storage flatten/rewrite fails
bool need_micro_tuning = !aicpu && polyhedral && !is_dynamic && global_attrs.GetStringAttr("dim", "").empty();
bool need_micro_tuning =
target != "aicpu" && polyhedral && !is_dynamic && global_attrs.GetStringAttr("dim", "").empty();
const int max_enter_poly_times = global_attrs.GetIntAttr(kMaxNumRetryPoly, need_micro_tuning ? 4 : 1);
int enter_count = 0;
Stmt stmt_before_poly = stmt;
while (enter_count < max_enter_poly_times) {
if (!aicpu && polyhedral) {
if (target != "aicpu" && polyhedral) {
Array<NodeRef> poly_res = NEXT_PASS(AutoPoly, stmt_before_poly, binds_0, global_attrs, false, is_dynamic);
enter_count++;
CHECK_EQ(poly_res.size(), 2);
......@@ -704,7 +740,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
// Loop Partition args : 2 : split_const_loop, 3 : remove Div / Mod ops by partitioning,
// 4 : whether to partition convolution or not
if (!aicpu && global_attrs.GetBoolAttr(kEnablePostPolyLoopPartition, true)) {
if (target != "aicpu" && global_attrs.GetBoolAttr(kEnablePostPolyLoopPartition, true)) {
stmt = NEXT_PASS(LoopPartitionCCE, stmt, true, false, !polyhedral);
}
......@@ -731,7 +767,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(FixLoopExtent, stmt);
}
if (!aicpu) {
if (target != "aicpu") {
stmt = NEXT_PASS(AutoPragma, stmt);
}
stmt = NEXT_PASS(EliminateAtomicDma, stmt);
......@@ -815,7 +851,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(AutoDoubleBuffer, stmt);
}
stmt = NEXT_PASS(InjectAccessPtrMSG, stmt);
if (!aicpu) {
if (target != "aicpu") {
stmt = NEXT_PASS(InjectPipe, stmt);
}
stmt = NEXT_PASS(ModDivEliminate, stmt);
......@@ -853,7 +889,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(SpecialValueReplacer, stmt);
stmt = NEXT_PASS(Simplify, stmt);
if (!aicpu) {
if (target != "aicpu") {
stmt = NEXT_PASS(InjectSync, stmt);
}
......@@ -925,52 +961,65 @@ void BuildForDevice(const Array<LoweredFunc> &flist, const std::string &target_n
TVMContext context{kDLCce, 0};
DLDeviceType device_type = context.device_type;
Array<LoweredFunc> out_flist_0;
Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;
for (const auto &func : flist) {
for (auto func : flist) {
if (func->func_type == air::LoweredFuncType::kMixedFunc) {
if (target_name == "cuda") {
if (BuildConfig::Current()->detect_global_barrier) {
func = NEXT_PASS(ThreadSync, func, "global");
}
func = NEXT_PASS(ThreadSync, func, "shared");
func = NEXT_PASS(ThreadSync, func, "warp");
func = NEXT_PASS(InferFragment, func);
func = NEXT_PASS(LowerThreadAllreduce, func, target->thread_warp_size);
}
Array<LoweredFunc> fsplits = NEXT_PASS(SplitHostDevice, func);
out_flist_0.push_back(fsplits[0]);
fhost.push_back(fsplits[0]);
for (size_t idx = 1; idx < fsplits.size(); idx++) {
fdevice.push_back(fsplits[idx]);
}
} else if (func->func_type == air::LoweredFuncType::kHostFunc) {
out_flist_0.push_back(func);
fhost.push_back(func);
} else if (func->func_type == air::LoweredFuncType::kDeviceFunc) {
out_flist_0.push_back(func);
fdevice.push_back(func);
} else {
LOG(FATAL) << "unknown function type " << func->func_type;
}
}
Array<LoweredFunc> out_flist_1;
for (const auto &func : out_flist_0) {
LoweredFunc lowered_func = NEXT_PASS(BindDeviceType, func, static_cast<int>(device_type));
out_flist_1.push_back(lowered_func);
if (target_name == "cuda") {
for (size_t i = 0; i < fdevice.size(); ++i) {
fdevice.Set(i, NEXT_PASS(LowerWarpMemory, fdevice[i], target->thread_warp_size));
}
Array<LoweredFunc> out_flist_2;
for (const auto &func : out_flist_1) {
LoweredFunc lowered_func = NEXT_PASS(LowerTVMBuiltin, func);
out_flist_2.push_back(lowered_func);
}
for (size_t i = 0; i < fhost.size(); ++i) {
fhost.Set(i, NEXT_PASS(BindDeviceType, fhost[i], static_cast<int>(device_type)));
fhost.Set(i, NEXT_PASS(LowerTVMBuiltin, fhost[i]));
}
Target target_host = Target::Create(target_host_name);
Array<LoweredFunc> fdevice_0;
for (const auto &func : fdevice) {
LoweredFunc lowered_func = NEXT_PASS(LowerIntrin, func, target->target_name);
fdevice_0.push_back(lowered_func);
for (size_t i = 0; i < fdevice.size(); ++i) {
if (target_name == "cuda") {
fdevice.Set(i, NEXT_PASS(LowerDeviceStorageAccessInfo, fdevice[i]));
}
fdevice.Set(i, NEXT_PASS(LowerIntrin, fdevice[i], target->target_name));
}
Array<LoweredFunc> out_flist_3;
for (const auto &func : out_flist_2) {
LoweredFunc lowered_func = NEXT_PASS(LowerIntrin, func, target_host->target_name);
out_flist_3.push_back(lowered_func);
for (size_t i = 0; i < fhost.size(); ++i) {
if (target_name == "cuda") {
fhost.Set(i, NEXT_PASS(LowerDeviceStorageAccessInfo, fhost[i]));
}
for (const auto &func : out_flist_3) {
LoweredFunc lowered_func = NEXT_PASS(CombineContextCall, func);
out_flist->push_back(lowered_func);
fhost.Set(i, NEXT_PASS(LowerIntrin, fhost[i], target_host->target_name));
fhost.Set(i, NEXT_PASS(CombineContextCall, fhost[i]));
}
for (const auto &func : fhost) {
out_flist->push_back(func);
}
*out_mdev = air::codegen::Build(fdevice_0, target_name, g_external_call_name);
*out_mdev = air::codegen::Build(fdevice, target_name, g_external_call_name);
return;
}
......@@ -987,7 +1036,7 @@ TVM_REGISTER_NODE_TYPE(BuildRstNode);
BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars,
const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config) {
CHECK(inputs.defined()) << "inputs is not defined.";
CHECK(!name.empty()) << "name is empty.";
......@@ -1005,7 +1054,7 @@ BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, cons
attrs = in_attrs;
}
auto rst = Lower(inputs, args, shape_vars, name, binds, attrs, false, polyhedral, false, aicpu, config);
auto rst = Lower(inputs, args, shape_vars, name, binds, attrs, false, polyhedral, false, target, config);
return BuildRstNode::make(rst, name);
}
......@@ -1075,9 +1124,9 @@ air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target
air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args,
const Array<NodeRef> &shape_vars, const std::string &target_name,
const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config) {
auto func = BuildToFunc(inputs, in_args, shape_vars, name, in_binds, in_attrs, polyhedral, aicpu, config);
auto func = BuildToFunc(inputs, in_args, shape_vars, name, in_binds, in_attrs, polyhedral, target, config);
return BuildToModule(func, target_name);
}
......
......@@ -454,7 +454,7 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string
CHECK(config.defined());
config->dump_pass_ir = akg_dump_pass_ir != nullptr;
attrs.Set("pragma_reschedule", make_const(Int(32), 1));
auto build_rst = akg::BuildToFunc(sch, args, shape_vars, kernel_name, in_binds, attrs, true, false, config);
auto build_rst = akg::BuildToFunc(sch, args, shape_vars, kernel_name, in_binds, attrs, true, "cce", config);
CHECK(build_rst.defined());
return build_rst;
}
......@@ -519,7 +519,7 @@ NodeRef composite_lower(const std::string &json_str, Map<std::string, NodeRef> a
akg::BuildConfig config = akg::BuildConfig::Current();
CHECK(config.defined());
bool tuning = attrs.find("tuning") != attrs.end();
return akg::Lower(sch, args, shape_vars, kernel_name, in_binds, attrs, false, true, tuning, false, config);
return akg::Lower(sch, args, shape_vars, kernel_name, in_binds, attrs, false, true, tuning, "cce", config);
}
TVM_REGISTER_GLOBAL("composite_with_json_to_func").set_body_typed(composite_with_json_to_func);
......
......@@ -47,19 +47,19 @@ class MemoryAllocationException : public std::exception {
NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name,
const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode,
bool polyhedral, bool tuning, bool aicpu, const BuildConfig &config);
bool polyhedral, bool tuning, const std::string &target, const BuildConfig &config);
air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args,
const Array<NodeRef> &shape_vars, const std::string &target_name,
const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config);
class BuildRst;
BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars,
const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu, const BuildConfig &config);
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target, const BuildConfig &config);
air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target_name = "cce");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册