diff --git a/core/src/scheduler/SchedInst.h b/core/src/scheduler/SchedInst.h index e758f3785164147ee8d8df9736e76715bdc44802..a3048069f9febe9ec7ea9b5112b3b7b6fab537f9 100644 --- a/core/src/scheduler/SchedInst.h +++ b/core/src/scheduler/SchedInst.h @@ -26,9 +26,11 @@ #include "optimizer/OnlyCPUPass.h" #include "optimizer/OnlyGPUPass.h" #include "optimizer/Optimizer.h" +#include "server/Config.h" #include #include +#include #include namespace milvus { @@ -95,11 +97,21 @@ class OptimizerInst { if (instance == nullptr) { std::lock_guard lock(mutex_); if (instance == nullptr) { + server::Config& config = server::Config::GetInstance(); + std::vector search_resources; + bool has_cpu = false; + config.GetResourceConfigSearchResources(search_resources); + for (auto& resource : search_resources) { + if (resource == "cpu") { + has_cpu = true; + } + } + std::vector pass_list; pass_list.push_back(std::make_shared()); pass_list.push_back(std::make_shared()); pass_list.push_back(std::make_shared()); - pass_list.push_back(std::make_shared()); + pass_list.push_back(std::make_shared(has_cpu)); instance = std::make_shared(pass_list); } } diff --git a/core/src/scheduler/optimizer/OnlyCPUPass.cpp b/core/src/scheduler/optimizer/OnlyCPUPass.cpp index 2651a6e1a54e5542c34de568a4f1ce25029395b8..238a91a82c77c05fe74576e2fedf33123214dbfc 100644 --- a/core/src/scheduler/optimizer/OnlyCPUPass.cpp +++ b/core/src/scheduler/optimizer/OnlyCPUPass.cpp @@ -35,13 +35,13 @@ OnlyCPUPass::Run(const TaskPtr& task) { } auto gpu_id = get_gpu_pool(); - if (gpu_id.empty()) { - ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); - auto label = std::make_shared(std::weak_ptr(res_ptr)); - task->label() = label; - return true; - } - return false; + if (not gpu_id.empty()) + return false; + + ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); + auto label = std::make_shared(std::weak_ptr(res_ptr)); + task->label() = label; + return true; } } // namespace scheduler diff --git a/core/src/scheduler/optimizer/OnlyGPUPass.cpp b/core/src/scheduler/optimizer/OnlyGPUPass.cpp index f39ca1a04236cb14229f6fa94106dbb26f6cc838..3fcda0e8a347bc6ca62890c78b394314f426c9e1 100644 --- a/core/src/scheduler/optimizer/OnlyGPUPass.cpp +++ b/core/src/scheduler/optimizer/OnlyGPUPass.cpp @@ -20,14 +20,16 @@ #include "scheduler/Utils.h" #include "scheduler/task/SearchTask.h" #include "scheduler/tasklabel/SpecResLabel.h" -#include "server/Config.h" namespace milvus { namespace scheduler { +OnlyGPUPass::OnlyGPUPass(bool has_cpu) : has_cpu_(has_cpu) { +} + bool OnlyGPUPass::Run(const TaskPtr& task) { - if (task->Type() != TaskType::SearchTask) + if (task->Type() != TaskType::SearchTask || has_cpu_) return false; auto search_task = std::static_pointer_cast(task); @@ -36,29 +38,15 @@ OnlyGPUPass::Run(const TaskPtr& task) { return false; } - server::Config& config = server::Config::GetInstance(); - std::vector search_resources; - config.GetResourceConfigSearchResources(search_resources); - for (auto& resource : search_resources) { - if (resource == "cpu") { - return false; - } - } - auto gpu_id = get_gpu_pool(); - if (!gpu_id.empty()) { - ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource(ResourceType::GPU, gpu_id[specified_gpu_id_]); - auto label = std::make_shared(std::weak_ptr(res_ptr)); - task->label() = label; - } else { + if (gpu_id.empty()) return false; - } - if (specified_gpu_id_ < gpu_id.size() - 1) { - ++specified_gpu_id_; - } else { - specified_gpu_id_ = 0; - } + ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource(ResourceType::GPU, gpu_id[specified_gpu_id_]); + auto label = std::make_shared(std::weak_ptr(res_ptr)); + task->label() = label; + + specified_gpu_id_ = specified_gpu_id_++ % gpu_id.size(); return true; } diff --git a/core/src/scheduler/optimizer/OnlyGPUPass.h b/core/src/scheduler/optimizer/OnlyGPUPass.h index 75a5f9e4f1a6532231b2479618df8f4eb1bc3629..10d909d30e1888d5025df8a80ee899bc348eafc9 100644 --- a/core/src/scheduler/optimizer/OnlyGPUPass.h +++ b/core/src/scheduler/optimizer/OnlyGPUPass.h @@ -34,7 +34,7 @@ namespace scheduler { class OnlyGPUPass : public Pass { public: - OnlyGPUPass() = default; + explicit OnlyGPUPass(bool has_cpu); public: bool @@ -42,6 +42,7 @@ class OnlyGPUPass : public Pass { private: uint64_t specified_gpu_id_ = 0; + bool has_cpu_ = false; }; using OnlyGPUPassPtr = std::shared_ptr;