提交 dc8ff8e5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!90 Enable args reorder in split_host_device pass when run AKG operators...

!90 Enable args reorder in split_host_device pass when run AKG operators directly or run MindSpore operators. 
Merge pull request !90 from looop5/split_host_device
...@@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1) ...@@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
if(USE_AKG_LOG) if(USE_AKG_LOG)
add_definitions(-DUSE_AKG_LOG=1) add_definitions(-DUSE_AKG_LOG=1)
endif() endif()
if(NOT USE_CUDA
OR ENABLE_AKG)
add_definitions("-DFIX_INPUT_ORDER_TVM")
endif()
# Generic compilation options # Generic compilation options
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
......
...@@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator { ...@@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator {
} }
} }
#ifdef FIX_INPUT_ORDER_TVM // Reorder args to match args_real
std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>(); Array<Var> ordered_args;
for (unsigned i = 0; i < (unsigned)args_real.size(); i++) { std::unordered_set<Var, NodeHash, NodeEqual> args_set;
bool match = false; std::unordered_set<Var, NodeHash, NodeEqual> args_real_set;
for (unsigned j = 0; j < (unsigned)n->args.size(); j++) { for (size_t i = 0; i < n->args.size(); ++i) {
if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) { args_set.insert(n->args[i]);
na->args.push_back(n->args[j]); }
match = true; for (size_t i = 0; i < args_real.size(); ++i) {
break; args_real_set.insert(args_real[i]);
} else { if (args_set.find(args_real[i]) != args_set.end()) {
continue; ordered_args.push_back(args_real[i]);
}
} }
}
if (!match) { for (size_t i = 0; i < n->args.size(); ++i) {
na->args.push_back(args_real[i]); if (args_real_set.find(n->args[i]) == args_real_set.end()) {
// mark handle data type. ordered_args.push_back(n->args[i]);
for (auto kv : handle_data_type_) {
if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
n->handle_data_type.Set(args_real[i], kv.second);
}
}
} }
} }
n->args = na->args; n->args = ordered_args;
#endif
LoweredFunc f_device(n); LoweredFunc f_device(n);
Array<Expr> call_args; Array<Expr> call_args;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册