diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f4736d7e9b25bfd00e1f07a522b3a1feaa2ed7a..136cd43fcf713877db9dc2a5e5bbd772055e082a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1) if(USE_AKG_LOG) add_definitions(-DUSE_AKG_LOG=1) endif() -if(NOT USE_CUDA - OR ENABLE_AKG) - add_definitions("-DFIX_INPUT_ORDER_TVM") -endif() # Generic compilation options include(CheckCXXCompilerFlag) diff --git a/third_party/incubator-tvm/src/pass/split_host_device.cc b/third_party/incubator-tvm/src/pass/split_host_device.cc index 37e9cb4e2838630ee5a87fc7b7b2078f4e475f8c..046707615f48ccc82436313ad8ec35dda61b5b8f 100644 --- a/third_party/incubator-tvm/src/pass/split_host_device.cc +++ b/third_party/incubator-tvm/src/pass/split_host_device.cc @@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator { } } -#ifdef FIX_INPUT_ORDER_TVM - std::shared_ptr na = std::make_shared(); - for (unsigned i = 0; i < (unsigned)args_real.size(); i++) { - bool match = false; - for (unsigned j = 0; j < (unsigned)n->args.size(); j++) { - if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) { - na->args.push_back(n->args[j]); - match = true; - break; - } else { - continue; - } + // Reorder args to match args_real + Array ordered_args; + std::unordered_set args_set; + std::unordered_set args_real_set; + for (size_t i = 0; i < n->args.size(); ++i) { + args_set.insert(n->args[i]); + } + for (size_t i = 0; i < args_real.size(); ++i) { + args_real_set.insert(args_real[i]); + if (args_set.find(args_real[i]) != args_set.end()) { + ordered_args.push_back(args_real[i]); } - - if (!match) { - na->args.push_back(args_real[i]); - // mark handle data type. - for (auto kv : handle_data_type_) { - if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) { - n->handle_data_type.Set(args_real[i], kv.second); - } - } + } + for (size_t i = 0; i < n->args.size(); ++i) { + if (args_real_set.find(n->args[i]) == args_real_set.end()) { + ordered_args.push_back(n->args[i]); } } - n->args = na->args; -#endif + n->args = ordered_args; LoweredFunc f_device(n); Array call_args;