提交 f9d521e9 作者: dabaiji

Refactor build module to support GPU

上级 6a84977e
...@@ -80,9 +80,9 @@ def build_config(**kwargs): ...@@ -80,9 +80,9 @@ def build_config(**kwargs):
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple), str, @vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple), str,
(dict, type(None)), (dict, type(None)), bool, bool, bool, bool) (dict, type(None)), (dict, type(None)), bool, bool, bool, str)
def lower(sch, args, shape_params=None, name="default_function", binds=None, attrs=None, def lower(sch, args, shape_params=None, name="default_function", binds=None, attrs=None,
simple_mode=False, polyhedral=False, tuning=False, aicpu=False): simple_mode=False, polyhedral=False, tuning=False, target="cce"):
"""Lowering function.""" """Lowering function."""
tmp_binds = None tmp_binds = None
if binds is not None: if binds is not None:
...@@ -96,7 +96,7 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att ...@@ -96,7 +96,7 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att
cfg = _api_internal._GetCurrentBuildConfig() cfg = _api_internal._GetCurrentBuildConfig()
ret = _api_internal._Lower(sch, args, shape_params, name, ret = _api_internal._Lower(sch, args, shape_params, name,
tmp_binds, tmp_attrs, simple_mode, tmp_binds, tmp_attrs, simple_mode,
polyhedral, tuning, aicpu, cfg) polyhedral, tuning, target, cfg)
level = tmp_attrs.get("help_tiling") level = tmp_attrs.get("help_tiling")
if tuning or (level is not None and level > help_tiling_level['None']): if tuning or (level is not None and level > help_tiling_level['None']):
...@@ -116,9 +116,9 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att ...@@ -116,9 +116,9 @@ def lower(sch, args, shape_params=None, name="default_function", binds=None, att
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple, type(None)), str, @vc_util.check_input_type(schedule.Schedule, (list, tuple), (list, tuple, type(None)), str,
(dict, type(None)), (dict, type(None)), bool, bool) (dict, type(None)), (dict, type(None)), bool, str)
def build_to_func(inputs, args, shape_params=None, name="default_function", def build_to_func(inputs, args, shape_params=None, name="default_function",
binds=None, attrs=None, polyhedral=False, aicpu=False): binds=None, attrs=None, polyhedral=False, target="cce"):
"""Build module.""" """Build module."""
tmp_binds = None tmp_binds = None
if binds is not None: if binds is not None:
...@@ -132,14 +132,13 @@ def build_to_func(inputs, args, shape_params=None, name="default_function", ...@@ -132,14 +132,13 @@ def build_to_func(inputs, args, shape_params=None, name="default_function",
shape_params = [] shape_params = []
cfg = _api_internal._GetCurrentBuildConfig() cfg = _api_internal._GetCurrentBuildConfig()
return _api_internal._BuildToFunc(inputs, args, shape_params, name, tmp_binds, tmp_attrs, return _api_internal._BuildToFunc(inputs, args, shape_params, name, tmp_binds, tmp_attrs,
polyhedral, aicpu, cfg) polyhedral, target, cfg)
@vc_util.check_input_type(schedule.Schedule, (list, tuple), (str, type(None)), (list, tuple), str, @vc_util.check_input_type(schedule.Schedule, (list, tuple), str, (list, tuple), str,
(dict, type(None)), (dict, type(None)), bool, bool) (dict, type(None)), (dict, type(None)), bool)
def build(inputs, args, target=None, shape_params=None, name="default_function", def build(inputs, args, target='cce', shape_params=None, name="default_function",
binds=None, attrs=None, polyhedral=False, aicpu=False): binds=None, attrs=None, polyhedral=False):
tmp_rst = build_to_func(inputs, args, shape_params=shape_params, name=name, binds=binds, tmp_rst = build_to_func(inputs, args, shape_params=shape_params, name=name, binds=binds,
attrs=attrs, polyhedral=polyhedral, aicpu=aicpu) attrs=attrs, polyhedral=polyhedral, target=target)
tmp_target = target if target is not None else 'cce' return _api_internal._BuildToModule(tmp_rst, target)
return _api_internal._BuildToModule(tmp_rst, tmp_target)
...@@ -42,7 +42,6 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na ...@@ -42,7 +42,6 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na
logging.error("Device %s is not in [aicore, aicpu].", device) logging.error("Device %s is not in [aicore, aicpu].", device)
return None return None
aicpu = device == "aicpu"
polyhedral = True polyhedral = True
dump_ir = os.getenv(MS_AKG_DUMP_IR) == "on" dump_ir = os.getenv(MS_AKG_DUMP_IR) == "on"
...@@ -57,9 +56,9 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na ...@@ -57,9 +56,9 @@ def op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_na
if attrs: if attrs:
binds = attrs.pop(BINDS, None) binds = attrs.pop(BINDS, None)
rst = akg.build_to_func(s, args, name=kernel_name, attrs=attrs, polyhedral=polyhedral, rst = akg.build_to_func(s, args, name=kernel_name, attrs=attrs, polyhedral=polyhedral,
binds=binds, aicpu=aicpu) binds=binds, target=device)
else: else:
rst = akg.build_to_func(s, args, name=kernel_name, polyhedral=polyhedral, aicpu=aicpu) rst = akg.build_to_func(s, args, name=kernel_name, polyhedral=polyhedral, target=device)
except Exception: except Exception:
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
......
...@@ -724,13 +724,14 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="", ...@@ -724,13 +724,14 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
if TensorUtils.is_output_value(output): if TensorUtils.is_output_value(output):
op_var = op_var + [output] op_var = op_var + [output]
if sch_tmpl != None: if sch_tmpl is not None:
assert(sch_tmpl['target'] == 'cuda') if sch_tmpl['target'] != 'cuda':
raise ValueError("Only support cuda as target when using schedule template.")
kernel_name = kernel_name if kernel_name != "" else sch_tmpl['op_name'] kernel_name = kernel_name if kernel_name != "" else sch_tmpl['op_name']
with akg.tvm.target.cuda() as target: with akg.tvm.target.cuda() as target:
s = sch_tmpl['schedule'](sch_tmpl['output']) s = sch_tmpl['schedule'](sch_tmpl['output'])
with akg.tvm.build_config(dump_pass_ir = True): with akg.build_config(dump_pass_ir=True):
mod = akg.tvm.build(s, op_var, target, target_host = 'stackvm', name = kernel_name) mod = akg.build(s, op_var, "cuda", shape_var, name=kernel_name, attrs=attrs, polyhedral=polyhedral, binds=binds)
dump_cuda_meta.dump(mod, kernel_name, s, op_var) dump_cuda_meta.dump(mod, kernel_name, s, op_var)
return mod return mod
......
...@@ -436,7 +436,7 @@ void FixParametricBinds(const Map<Tensor, Buffer> &binds, const Array<NodeRef> & ...@@ -436,7 +436,7 @@ void FixParametricBinds(const Map<Tensor, Buffer> &binds, const Array<NodeRef> &
NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name, NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name,
const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode, const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode,
bool polyhedral, bool tuning, bool aicpu, const BuildConfig &config) { bool polyhedral, bool tuning, const std::string &target, const BuildConfig &config) {
ir::TestExprCompuationSimplify(); ir::TestExprCompuationSimplify();
CHECK(sch.defined()) << "sch is not defined."; CHECK(sch.defined()) << "sch is not defined.";
CHECK(!name.empty()) << "name is empty."; CHECK(!name.empty()) << "name is empty.";
...@@ -486,6 +486,41 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -486,6 +486,41 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
auto new_sch = sch.normalize(); auto new_sch = sch.normalize();
auto bounds = air::schedule::InferBound(new_sch); auto bounds = air::schedule::InferBound(new_sch);
Stmt stmt = make_pass("schedule.ScheduleOps", new_sch, bounds, false); Stmt stmt = make_pass("schedule.ScheduleOps", new_sch, bounds, false);
if (target == "cuda") {
// Phase 1
stmt = NEXT_PASS(RewriteForTensorCore, stmt, new_sch, binds_0);
stmt = NEXT_PASS(StorageFlatten, stmt, binds_0, 64, config->instrument_bound_checkers);
stmt = NEXT_PASS(CanonicalSimplify, stmt);
// Phase 2
if (!simple_mode) {
stmt = NEXT_PASS(LoopPartition, stmt, config->partition_const_loop);
}
if (config->disable_vectorize) {
stmt = NEXT_PASS(SkipVectorize, stmt);
} else {
stmt = NEXT_PASS(VectorizeLoop, stmt);
}
stmt = NEXT_PASS(InjectVirtualThread, stmt);
stmt = NEXT_PASS(InjectDoubleBuffer, stmt, config->double_buffer_split_loop);
stmt = NEXT_PASS(StorageRewrite, stmt);
stmt = NEXT_PASS(UnrollLoop, stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config->auto_unroll_max_extent, config->unroll_explicit);
// Phase 3
stmt = NEXT_PASS(Simplify, stmt);
stmt = NEXT_PASS(RemoveNoOp, stmt);
if (config->instrument_bound_checkers) {
stmt = NEXT_PASS(InstrumentBoundCheckers, stmt);
}
if (simple_mode) {
return stmt;
}
LoweredFunc lowered_func = NEXT_PASS(MakeAPI, stmt, name, arg_list_0, 0, config->restricted_func);
return lowered_func;
}
if (!polyhedral) { if (!polyhedral) {
// for conv-matmul manual schedule // for conv-matmul manual schedule
stmt = NEXT_PASS(AutoMadPragmaAttr, stmt, true); stmt = NEXT_PASS(AutoMadPragmaAttr, stmt, true);
...@@ -518,7 +553,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -518,7 +553,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
PassMgr::SetArgs(arg_list_0); PassMgr::SetArgs(arg_list_0);
if (!aicpu) { if (target != "aicpu") {
stmt = NEXT_PASS(MathIntrinRewrite, stmt); stmt = NEXT_PASS(MathIntrinRewrite, stmt);
} }
...@@ -527,7 +562,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -527,7 +562,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
} }
// Phase 1 // Phase 1
if (!aicpu && polyhedral) { if (target != "aicpu" && polyhedral) {
stmt = NEXT_PASS(UnifyLoopVars, stmt, binds_0, arg_list_0); stmt = NEXT_PASS(UnifyLoopVars, stmt, binds_0, arg_list_0);
stmt = NEXT_PASS(CheckShapeParams, stmt, binds_0); stmt = NEXT_PASS(CheckShapeParams, stmt, binds_0);
stmt = NEXT_PASS(AlignPartitionCCE, stmt); stmt = NEXT_PASS(AlignPartitionCCE, stmt);
...@@ -597,12 +632,13 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -597,12 +632,13 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
} }
// micro-tuning configs: current strategy is to retry autopoly up to 3 times when storage flatten/rewrite fails // micro-tuning configs: current strategy is to retry autopoly up to 3 times when storage flatten/rewrite fails
bool need_micro_tuning = !aicpu && polyhedral && !is_dynamic && global_attrs.GetStringAttr("dim", "").empty(); bool need_micro_tuning =
target != "aicpu" && polyhedral && !is_dynamic && global_attrs.GetStringAttr("dim", "").empty();
const int max_enter_poly_times = global_attrs.GetIntAttr(kMaxNumRetryPoly, need_micro_tuning ? 4 : 1); const int max_enter_poly_times = global_attrs.GetIntAttr(kMaxNumRetryPoly, need_micro_tuning ? 4 : 1);
int enter_count = 0; int enter_count = 0;
Stmt stmt_before_poly = stmt; Stmt stmt_before_poly = stmt;
while (enter_count < max_enter_poly_times) { while (enter_count < max_enter_poly_times) {
if (!aicpu && polyhedral) { if (target != "aicpu" && polyhedral) {
Array<NodeRef> poly_res = NEXT_PASS(AutoPoly, stmt_before_poly, binds_0, global_attrs, false, is_dynamic); Array<NodeRef> poly_res = NEXT_PASS(AutoPoly, stmt_before_poly, binds_0, global_attrs, false, is_dynamic);
enter_count++; enter_count++;
CHECK_EQ(poly_res.size(), 2); CHECK_EQ(poly_res.size(), 2);
...@@ -704,7 +740,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -704,7 +740,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
// Loop Partition args : 2 : split_const_loop, 3 : remove Div / Mod ops by partitioning, // Loop Partition args : 2 : split_const_loop, 3 : remove Div / Mod ops by partitioning,
// 4 : whether to partition convolution or not // 4 : whether to partition convolution or not
if (!aicpu && global_attrs.GetBoolAttr(kEnablePostPolyLoopPartition, true)) { if (target != "aicpu" && global_attrs.GetBoolAttr(kEnablePostPolyLoopPartition, true)) {
stmt = NEXT_PASS(LoopPartitionCCE, stmt, true, false, !polyhedral); stmt = NEXT_PASS(LoopPartitionCCE, stmt, true, false, !polyhedral);
} }
...@@ -731,7 +767,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -731,7 +767,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(FixLoopExtent, stmt); stmt = NEXT_PASS(FixLoopExtent, stmt);
} }
if (!aicpu) { if (target != "aicpu") {
stmt = NEXT_PASS(AutoPragma, stmt); stmt = NEXT_PASS(AutoPragma, stmt);
} }
stmt = NEXT_PASS(EliminateAtomicDma, stmt); stmt = NEXT_PASS(EliminateAtomicDma, stmt);
...@@ -741,7 +777,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -741,7 +777,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
if (is_dynamic) { if (is_dynamic) {
stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true), stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true),
global_attrs.GetIntAttr(kEnableScalarAlign, false)); global_attrs.GetIntAttr(kEnableScalarAlign, false));
} else { } else {
stmt = NEXT_PASS(RewriteBroadcastVector, stmt); stmt = NEXT_PASS(RewriteBroadcastVector, stmt);
stmt = NEXT_PASS(OptimizePragma, stmt); stmt = NEXT_PASS(OptimizePragma, stmt);
...@@ -815,7 +851,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -815,7 +851,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(AutoDoubleBuffer, stmt); stmt = NEXT_PASS(AutoDoubleBuffer, stmt);
} }
stmt = NEXT_PASS(InjectAccessPtrMSG, stmt); stmt = NEXT_PASS(InjectAccessPtrMSG, stmt);
if (!aicpu) { if (target != "aicpu") {
stmt = NEXT_PASS(InjectPipe, stmt); stmt = NEXT_PASS(InjectPipe, stmt);
} }
stmt = NEXT_PASS(ModDivEliminate, stmt); stmt = NEXT_PASS(ModDivEliminate, stmt);
...@@ -853,7 +889,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -853,7 +889,7 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt = NEXT_PASS(SpecialValueReplacer, stmt); stmt = NEXT_PASS(SpecialValueReplacer, stmt);
stmt = NEXT_PASS(Simplify, stmt); stmt = NEXT_PASS(Simplify, stmt);
if (!aicpu) { if (target != "aicpu") {
stmt = NEXT_PASS(InjectSync, stmt); stmt = NEXT_PASS(InjectSync, stmt);
} }
...@@ -925,52 +961,65 @@ void BuildForDevice(const Array<LoweredFunc> &flist, const std::string &target_n ...@@ -925,52 +961,65 @@ void BuildForDevice(const Array<LoweredFunc> &flist, const std::string &target_n
TVMContext context{kDLCce, 0}; TVMContext context{kDLCce, 0};
DLDeviceType device_type = context.device_type; DLDeviceType device_type = context.device_type;
Array<LoweredFunc> out_flist_0; Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice; Array<LoweredFunc> fdevice;
for (const auto &func : flist) { for (auto func : flist) {
if (func->func_type == air::LoweredFuncType::kMixedFunc) { if (func->func_type == air::LoweredFuncType::kMixedFunc) {
if (target_name == "cuda") {
if (BuildConfig::Current()->detect_global_barrier) {
func = NEXT_PASS(ThreadSync, func, "global");
}
func = NEXT_PASS(ThreadSync, func, "shared");
func = NEXT_PASS(ThreadSync, func, "warp");
func = NEXT_PASS(InferFragment, func);
func = NEXT_PASS(LowerThreadAllreduce, func, target->thread_warp_size);
}
Array<LoweredFunc> fsplits = NEXT_PASS(SplitHostDevice, func); Array<LoweredFunc> fsplits = NEXT_PASS(SplitHostDevice, func);
out_flist_0.push_back(fsplits[0]); fhost.push_back(fsplits[0]);
for (size_t idx = 1; idx < fsplits.size(); idx++) { for (size_t idx = 1; idx < fsplits.size(); idx++) {
fdevice.push_back(fsplits[idx]); fdevice.push_back(fsplits[idx]);
} }
} else if (func->func_type == air::LoweredFuncType::kHostFunc) { } else if (func->func_type == air::LoweredFuncType::kHostFunc) {
out_flist_0.push_back(func); fhost.push_back(func);
} else if (func->func_type == air::LoweredFuncType::kDeviceFunc) { } else if (func->func_type == air::LoweredFuncType::kDeviceFunc) {
out_flist_0.push_back(func); fdevice.push_back(func);
} else { } else {
LOG(FATAL) << "unknown function type " << func->func_type; LOG(FATAL) << "unknown function type " << func->func_type;
} }
} }
Array<LoweredFunc> out_flist_1; if (target_name == "cuda") {
for (const auto &func : out_flist_0) { for (size_t i = 0; i < fdevice.size(); ++i) {
LoweredFunc lowered_func = NEXT_PASS(BindDeviceType, func, static_cast<int>(device_type)); fdevice.Set(i, NEXT_PASS(LowerWarpMemory, fdevice[i], target->thread_warp_size));
out_flist_1.push_back(lowered_func); }
} }
Array<LoweredFunc> out_flist_2;
for (const auto &func : out_flist_1) { for (size_t i = 0; i < fhost.size(); ++i) {
LoweredFunc lowered_func = NEXT_PASS(LowerTVMBuiltin, func); fhost.Set(i, NEXT_PASS(BindDeviceType, fhost[i], static_cast<int>(device_type)));
out_flist_2.push_back(lowered_func); fhost.Set(i, NEXT_PASS(LowerTVMBuiltin, fhost[i]));
} }
Target target_host = Target::Create(target_host_name); Target target_host = Target::Create(target_host_name);
Array<LoweredFunc> fdevice_0;
for (const auto &func : fdevice) { for (size_t i = 0; i < fdevice.size(); ++i) {
LoweredFunc lowered_func = NEXT_PASS(LowerIntrin, func, target->target_name); if (target_name == "cuda") {
fdevice_0.push_back(lowered_func); fdevice.Set(i, NEXT_PASS(LowerDeviceStorageAccessInfo, fdevice[i]));
}
fdevice.Set(i, NEXT_PASS(LowerIntrin, fdevice[i], target->target_name));
} }
Array<LoweredFunc> out_flist_3; for (size_t i = 0; i < fhost.size(); ++i) {
for (const auto &func : out_flist_2) { if (target_name == "cuda") {
LoweredFunc lowered_func = NEXT_PASS(LowerIntrin, func, target_host->target_name); fhost.Set(i, NEXT_PASS(LowerDeviceStorageAccessInfo, fhost[i]));
out_flist_3.push_back(lowered_func); }
fhost.Set(i, NEXT_PASS(LowerIntrin, fhost[i], target_host->target_name));
fhost.Set(i, NEXT_PASS(CombineContextCall, fhost[i]));
} }
for (const auto &func : out_flist_3) {
LoweredFunc lowered_func = NEXT_PASS(CombineContextCall, func); for (const auto &func : fhost) {
out_flist->push_back(lowered_func); out_flist->push_back(func);
} }
*out_mdev = air::codegen::Build(fdevice_0, target_name, g_external_call_name); *out_mdev = air::codegen::Build(fdevice, target_name, g_external_call_name);
return; return;
} }
...@@ -987,7 +1036,7 @@ TVM_REGISTER_NODE_TYPE(BuildRstNode); ...@@ -987,7 +1036,7 @@ TVM_REGISTER_NODE_TYPE(BuildRstNode);
BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars,
const std::string &name, const Map<Tensor, Buffer> &in_binds, const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu, const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config) { const BuildConfig &config) {
CHECK(inputs.defined()) << "inputs is not defined."; CHECK(inputs.defined()) << "inputs is not defined.";
CHECK(!name.empty()) << "name is empty."; CHECK(!name.empty()) << "name is empty.";
...@@ -1005,7 +1054,7 @@ BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, cons ...@@ -1005,7 +1054,7 @@ BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, cons
attrs = in_attrs; attrs = in_attrs;
} }
auto rst = Lower(inputs, args, shape_vars, name, binds, attrs, false, polyhedral, false, aicpu, config); auto rst = Lower(inputs, args, shape_vars, name, binds, attrs, false, polyhedral, false, target, config);
return BuildRstNode::make(rst, name); return BuildRstNode::make(rst, name);
} }
...@@ -1073,11 +1122,11 @@ air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target ...@@ -1073,11 +1122,11 @@ air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target
} }
air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args, air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args,
const Array<NodeRef> &shape_vars, const std::string &target_name, const Array<NodeRef> &shape_vars, const std::string &target_name,
const std::string &name, const Map<Tensor, Buffer> &in_binds, const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu, const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config) { const BuildConfig &config) {
auto func = BuildToFunc(inputs, in_args, shape_vars, name, in_binds, in_attrs, polyhedral, aicpu, config); auto func = BuildToFunc(inputs, in_args, shape_vars, name, in_binds, in_attrs, polyhedral, target, config);
return BuildToModule(func, target_name); return BuildToModule(func, target_name);
} }
......
...@@ -454,7 +454,7 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string ...@@ -454,7 +454,7 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string
CHECK(config.defined()); CHECK(config.defined());
config->dump_pass_ir = akg_dump_pass_ir != nullptr; config->dump_pass_ir = akg_dump_pass_ir != nullptr;
attrs.Set("pragma_reschedule", make_const(Int(32), 1)); attrs.Set("pragma_reschedule", make_const(Int(32), 1));
auto build_rst = akg::BuildToFunc(sch, args, shape_vars, kernel_name, in_binds, attrs, true, false, config); auto build_rst = akg::BuildToFunc(sch, args, shape_vars, kernel_name, in_binds, attrs, true, "cce", config);
CHECK(build_rst.defined()); CHECK(build_rst.defined());
return build_rst; return build_rst;
} }
...@@ -519,7 +519,7 @@ NodeRef composite_lower(const std::string &json_str, Map<std::string, NodeRef> a ...@@ -519,7 +519,7 @@ NodeRef composite_lower(const std::string &json_str, Map<std::string, NodeRef> a
akg::BuildConfig config = akg::BuildConfig::Current(); akg::BuildConfig config = akg::BuildConfig::Current();
CHECK(config.defined()); CHECK(config.defined());
bool tuning = attrs.find("tuning") != attrs.end(); bool tuning = attrs.find("tuning") != attrs.end();
return akg::Lower(sch, args, shape_vars, kernel_name, in_binds, attrs, false, true, tuning, false, config); return akg::Lower(sch, args, shape_vars, kernel_name, in_binds, attrs, false, true, tuning, "cce", config);
} }
TVM_REGISTER_GLOBAL("composite_with_json_to_func").set_body_typed(composite_with_json_to_func); TVM_REGISTER_GLOBAL("composite_with_json_to_func").set_body_typed(composite_with_json_to_func);
......
...@@ -47,19 +47,19 @@ class MemoryAllocationException : public std::exception { ...@@ -47,19 +47,19 @@ class MemoryAllocationException : public std::exception {
NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name, NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, const std::string &name,
const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode, const Map<Tensor, Buffer> &in_binds, const Map<std::string, NodeRef> &in_attrs, bool simple_mode,
bool polyhedral, bool tuning, bool aicpu, const BuildConfig &config); bool polyhedral, bool tuning, const std::string &target, const BuildConfig &config);
air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args, air::runtime::Module BuildModule(const Schedule &inputs, const Array<NodeRef> &in_args,
const Array<NodeRef> &shape_vars, const std::string &target_name, const Array<NodeRef> &shape_vars, const std::string &target_name,
const std::string &name, const Map<Tensor, Buffer> &in_binds, const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu, const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target,
const BuildConfig &config); const BuildConfig &config);
class BuildRst; class BuildRst;
BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars, BuildRst BuildToFunc(const Schedule &inputs, const Array<NodeRef> &in_args, const Array<NodeRef> &shape_vars,
const std::string &name, const Map<Tensor, Buffer> &in_binds, const std::string &name, const Map<Tensor, Buffer> &in_binds,
const Map<std::string, NodeRef> &in_attrs, bool polyhedral, bool aicpu, const BuildConfig &config); const Map<std::string, NodeRef> &in_attrs, bool polyhedral, const std::string &target, const BuildConfig &config);
air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target_name = "cce"); air::runtime::Module BuildToModule(const NodeRef &ref, const std::string &target_name = "cce");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册