提交 bdb331eb 编写于 作者: L looop5

Reorder device args to match the real args order, and fix incorrect args...

Reorder device function args to match the order of the real args, and fix incorrect args when there are multiple device functions
上级 5237a9a0
...@@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1) ...@@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
if(USE_AKG_LOG) if(USE_AKG_LOG)
add_definitions(-DUSE_AKG_LOG=1) add_definitions(-DUSE_AKG_LOG=1)
endif() endif()
if(NOT USE_CUDA
OR ENABLE_AKG)
add_definitions("-DFIX_INPUT_ORDER_TVM")
endif()
# Generic compilation options # Generic compilation options
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
......
...@@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator { ...@@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator {
} }
} }
#ifdef FIX_INPUT_ORDER_TVM // Reorder args to match args_real
std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>(); Array<Var> ordered_args;
for (unsigned i = 0; i < (unsigned)args_real.size(); i++) { std::unordered_set<Var, NodeHash, NodeEqual> args_set;
bool match = false; std::unordered_set<Var, NodeHash, NodeEqual> args_real_set;
for (unsigned j = 0; j < (unsigned)n->args.size(); j++) { for (size_t i = 0; i < n->args.size(); ++i) {
if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) { args_set.insert(n->args[i]);
na->args.push_back(n->args[j]);
match = true;
break;
} else {
continue;
}
} }
for (size_t i = 0; i < args_real.size(); ++i) {
if (!match) { args_real_set.insert(args_real[i]);
na->args.push_back(args_real[i]); if (args_set.find(args_real[i]) != args_set.end()) {
// mark handle data type. ordered_args.push_back(args_real[i]);
for (auto kv : handle_data_type_) {
if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
n->handle_data_type.Set(args_real[i], kv.second);
} }
} }
for (size_t i = 0; i < n->args.size(); ++i) {
if (args_real_set.find(n->args[i]) == args_real_set.end()) {
ordered_args.push_back(n->args[i]);
} }
} }
n->args = na->args; n->args = ordered_args;
#endif
LoweredFunc f_device(n); LoweredFunc f_device(n);
Array<Expr> call_args; Array<Expr> call_args;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册