From 11cb43ac2ed0fb55ed81e1b7fdf58a6818921296 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 25 Apr 2019 09:38:27 +0800 Subject: [PATCH] add scheduler Former-commit-id: aead7396cc627fc680408188d3182fd098b5271d --- cpp/src/server/VecServiceHandler.cpp | 106 ++++-------- cpp/src/server/VecServiceScheduler.cpp | 107 +++++++++++- cpp/src/server/VecServiceScheduler.h | 56 ++++++ cpp/src/server/VecServiceTask.cpp | 229 +++++++++++++++++++++++++ cpp/src/server/VecServiceTask.h | 116 +++++++++++++ 5 files changed, 537 insertions(+), 77 deletions(-) create mode 100644 cpp/src/server/VecServiceTask.cpp create mode 100644 cpp/src/server/VecServiceTask.h diff --git a/cpp/src/server/VecServiceHandler.cpp b/cpp/src/server/VecServiceHandler.cpp index c7892976..798f0de0 100644 --- a/cpp/src/server/VecServiceHandler.cpp +++ b/cpp/src/server/VecServiceHandler.cpp @@ -4,6 +4,7 @@ * Proprietary and confidential. ******************************************************************************/ #include "VecServiceHandler.h" +#include "VecServiceTask.h" #include "ServerConfig.h" #include "VecIdMapper.h" #include "utils/Log.h" @@ -34,19 +35,11 @@ VecServiceHandler::add_group(const VecGroup &group) { SERVER_LOG_TRACE << "group.id = " << group.id << ", group.dimension = " << group.dimension << ", group.index_type = " << group.index_type; - try { - engine::meta::GroupSchema group_info; - group_info.dimension = (size_t)group.dimension; - group_info.group_id = group.id; - engine::Status stat = db_->add_group(group_info); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } + BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "add_group() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_group() finished"; } void @@ -54,21 +47,12 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) { SERVER_LOG_INFO << "get_group() called"; SERVER_LOG_TRACE << "group_id = " << group_id; - try { - engine::meta::GroupSchema group_info; - group_info.group_id = group_id; - engine::Status stat = db_->get_group(group_info); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - _return.id = group_info.group_id; - _return.dimension = (int32_t)group_info.dimension; - } + _return.id = group_id; + BaseTaskPtr task_ptr = GetGroupTask::Create(group_id, _return.dimension); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "get_group() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "get_group() finished"; } void @@ -76,12 +60,11 @@ VecServiceHandler::del_group(const std::string &group_id) { SERVER_LOG_INFO << "del_group() called"; SERVER_LOG_TRACE << "group_id = " << group_id; - try { + BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "del_group() not implemented"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "del_group() not implemented"; } @@ -90,25 +73,13 @@ VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tens SERVER_LOG_INFO << "add_vector() called"; SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size(); - try { - engine::IDNumbers vector_ids; - std::vector vec_f(tensor.tensor.begin(), tensor.tensor.end()); - engine::Status stat = db_->add_vectors(group_id, 1, vec_f.data(), vector_ids); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - if(vector_ids.size() != 1) { - SERVER_LOG_ERROR << "Vector ID not returned"; - } else { - std::string nid = group_id + "_" + std::to_string(vector_ids[0]); - IVecIdMapper::GetInstance()->Put(nid, tensor.uid); - } - } + VecTensorList tensor_list; + tensor_list.tensor_list.push_back(tensor); + BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "add_vector() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_vector() finished"; } void @@ -118,32 +89,11 @@ VecServiceHandler::add_vector_batch(const std::string &group_id, SERVER_LOG_TRACE << "group_id = " << group_id << ", vector list size = " << tensor_list.tensor_list.size(); - try { - std::vector vec_f; - for(const VecTensor& tensor : tensor_list.tensor_list) { - vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); - } - - engine::IDNumbers vector_ids; - engine::Status stat = db_->add_vectors(group_id, tensor_list.tensor_list.size(), vec_f.data(), vector_ids); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - if(vector_ids.size() != tensor_list.tensor_list.size()) { - SERVER_LOG_ERROR << "Vector ID not returned"; - } else { - std::string nid_prefix = group_id + "_"; - for(size_t i = 0; i < vector_ids.size(); i++) { - std::string nid = nid_prefix + std::to_string(vector_ids[i]); - IVecIdMapper::GetInstance()->Put(nid, tensor_list.tensor_list[i].uid); - } - } - } + BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "add_vector_batch() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_vector_batch() finished"; } @@ -177,10 +127,12 @@ VecServiceHandler::search_vector(VecSearchResult &_return, } } - SERVER_LOG_INFO << "search_vector() finished"; + } catch (std::exception& ex) { SERVER_LOG_ERROR << ex.what(); } + + SERVER_LOG_INFO << "search_vector() finished"; } void @@ -220,10 +172,12 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return, } } - SERVER_LOG_INFO << "search_vector_batch() finished"; + } catch (std::exception& ex) { SERVER_LOG_ERROR << ex.what(); } + + SERVER_LOG_INFO << "search_vector_batch() finished"; } VecServiceHandler::~VecServiceHandler() { diff --git a/cpp/src/server/VecServiceScheduler.cpp b/cpp/src/server/VecServiceScheduler.cpp index b4dd5063..5acadc08 100644 --- a/cpp/src/server/VecServiceScheduler.cpp +++ b/cpp/src/server/VecServiceScheduler.cpp @@ -4,17 +4,122 @@ * Proprietary and confidential. ******************************************************************************/ #include "VecServiceScheduler.h" +#include "utils/Log.h" namespace zilliz { namespace vecwise { namespace server { -VecServiceScheduler::VecServiceScheduler() { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +BaseTask::BaseTask(const std::string& task_group) + : task_group_(task_group), + done_(false), + error_code_(SERVER_SUCCESS) { } +BaseTask::~BaseTask() { + WaitToFinish(); +} + +ServerError BaseTask::Execute() { + error_code_ = OnExecute(); + done_ = true; + return error_code_; +} + +ServerError BaseTask::WaitToFinish() { + std::unique_lock lock(finish_mtx_); + finish_cond_.wait(lock, [this] { return done_; }); + + return error_code_; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +VecServiceScheduler::VecServiceScheduler() +: stopped_(false) { + Start(); +} + VecServiceScheduler::~VecServiceScheduler() { + Stop(); +} + +void VecServiceScheduler::Start() { + if(!stopped_) { + return; + } + + stopped_ = false; +} + +void VecServiceScheduler::Stop() { + { + std::lock_guard lock(queue_mtx_); + for(auto iter : task_groups_) { + if(iter.second != nullptr) { + iter.second->Put(nullptr); + } + } + } + + for(auto iter : execute_threads_) { + if(iter == nullptr) + continue; + + iter->join(); + } + stopped_ = true; +} + +ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { + if(task_ptr == nullptr) { + return SERVER_NULL_POINTER; + } + + return PutTaskToQueue(task_ptr); +} + +namespace { + void TakeTaskToExecute(TaskQueuePtr task_queue) { + if(task_queue == nullptr) { + return; + } + + while(true) { + BaseTaskPtr task = task_queue->Take(); + if (task == nullptr) { + break;//stop the thread + } + + try { + ServerError err = task->Execute(); + if(err != SERVER_SUCCESS) { + SERVER_LOG_ERROR << "Task failed with code: " << err; + } + } catch (std::exception& ex) { + SERVER_LOG_ERROR << "Task failed to execute: " << ex.what(); + } + } + } +} + +ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { + std::lock_guard lock(queue_mtx_); + + std::string group_name = task_ptr->TaskGroup(); + if(task_groups_.count(group_name) > 0) { + task_groups_[group_name]->Put(task_ptr); + } else { + TaskQueuePtr queue = std::make_shared(); + queue->Put(task_ptr); + task_groups_.insert(std::make_pair(group_name, queue)); + //start a thread + ThreadPtr thread = std::make_shared(&TakeTaskToExecute, queue); + execute_threads_.push_back(thread); + SERVER_LOG_INFO << "Create new thread for task group: " << group_name; + } } } diff --git a/cpp/src/server/VecServiceScheduler.h b/cpp/src/server/VecServiceScheduler.h index e5819aab..bbbcd151 100644 --- a/cpp/src/server/VecServiceScheduler.h +++ b/cpp/src/server/VecServiceScheduler.h @@ -5,15 +5,71 @@ ******************************************************************************/ #pragma once +#include "utils/BlockingQueue.h" + +#include +#include +#include + namespace zilliz { namespace vecwise { namespace server { +class BaseTask { +protected: + BaseTask(const std::string& task_group); + virtual ~BaseTask(); + +public: + ServerError Execute(); + ServerError WaitToFinish(); + + std::string TaskGroup() const { return task_group_; } + + ServerError ErrorCode() const { return error_code_; } +protected: + virtual ServerError OnExecute() = 0; + +protected: + mutable std::mutex finish_mtx_; + std::condition_variable finish_cond_; + + std::string task_group_; + bool done_; + ServerError error_code_; +}; + +using BaseTaskPtr = std::shared_ptr; +using TaskQueue = BlockingQueue; +using TaskQueuePtr = std::shared_ptr; +using ThreadPtr = std::shared_ptr; + class VecServiceScheduler { public: + static VecServiceScheduler& GetInstance() { + static VecServiceScheduler scheduler; + return scheduler; + } + + void Start(); + void Stop(); + + ServerError ExecuteTask(const BaseTaskPtr& task_ptr); + +protected: VecServiceScheduler(); virtual ~VecServiceScheduler(); + ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr); + +private: + mutable std::mutex queue_mtx_; + + std::map task_groups_; + + std::vector execute_threads_; + + bool stopped_; }; diff --git a/cpp/src/server/VecServiceTask.cpp b/cpp/src/server/VecServiceTask.cpp new file mode 100644 index 00000000..5e80132f --- /dev/null +++ b/cpp/src/server/VecServiceTask.cpp @@ -0,0 +1,229 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#include "VecServiceTask.h" +#include "ServerConfig.h" +#include "VecIdMapper.h" +#include "utils/CommonUtil.h" +#include "utils/Log.h" +#include "db/DB.h" +#include "db/Env.h" + +namespace zilliz { +namespace vecwise { +namespace server { + +static const std::string NORMAL_TASK_GROUP = "normal"; + +namespace { + class DBWrapper { + public: + DBWrapper() { + zilliz::vecwise::engine::Options opt; + ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_SERVER); + opt.meta.backend_uri = config.GetValue(CONFIG_SERVER_DB_URL); + std::string db_path = config.GetValue(CONFIG_SERVER_DB_PATH); + opt.meta.path = db_path + "/db"; + + CommonUtil::CreateDirectory(opt.meta.path); + + zilliz::vecwise::engine::DB::Open(opt, &db_); + if(db_ == nullptr) { + SERVER_LOG_ERROR << "Failed to open db"; + throw ServerException(SERVER_NULL_POINTER, "Failed to open db"); + } + } + + zilliz::vecwise::engine::DB* DB() { return db_; } + + private: + zilliz::vecwise::engine::DB* db_ = nullptr; + }; + + zilliz::vecwise::engine::DB* DB() { + static DBWrapper db_wrapper; + return db_wrapper.DB(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +AddGroupTask::AddGroupTask(int32_t dimension, + const std::string& group_id) +: BaseTask(NORMAL_TASK_GROUP), + dimension_(dimension), + group_id_(group_id) { + +} + +BaseTaskPtr AddGroupTask::Create(int32_t dimension, + const std::string& group_id) { + return std::shared_ptr(new AddGroupTask(dimension,group_id)); +} + +ServerError AddGroupTask::OnExecute() { + try { + engine::meta::GroupSchema group_info; + group_info.dimension = (size_t)dimension_; + group_info.group_id = group_id_; + engine::Status stat = DB()->add_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id), + dimension_(dimension) { + +} + +BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) { + return std::shared_ptr(new GetGroupTask(group_id, dimension)); +} + +ServerError GetGroupTask::OnExecute() { + try { + dimension_ = 0; + + engine::meta::GroupSchema group_info; + group_info.group_id = group_id_; + engine::Status stat = DB()->get_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + dimension_ = (int32_t)group_info.dimension; + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +DeleteGroupTask::DeleteGroupTask(const std::string& group_id) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id) { + +} + +BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) { + return std::shared_ptr(new DeleteGroupTask(group_id)); +} + +ServerError DeleteGroupTask::OnExecute() { + try { + + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +AddVectorTask::AddVectorTask(const std::string& group_id, + const VecTensorList &tensor_list) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id), + tensor_list_(tensor_list) { + +} + +BaseTaskPtr AddVectorTask::Create(const std::string& group_id, + const VecTensorList &tensor_list) { + return std::shared_ptr(new AddVectorTask(group_id, tensor_list)); +} + +ServerError AddVectorTask::OnExecute() { + try { + std::vector vec_f; + for(const VecTensor& tensor : tensor_list_.tensor_list) { + vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + } + + engine::IDNumbers vector_ids; + engine::Status stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + if(vector_ids.size() != tensor_list_.tensor_list.size()) { + SERVER_LOG_ERROR << "Vector ID not returned"; + } else { + std::string nid_prefix = group_id_ + "_"; + for(size_t i = 0; i < vector_ids.size(); i++) { + std::string nid = nid_prefix + std::to_string(vector_ids[i]); + IVecIdMapper::GetInstance()->Put(nid, tensor_list_.tensor_list[i].uid); + } + } + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +SearchVectorTask::SearchVectorTask(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list) + : BaseTask(NORMAL_TASK_GROUP), + result_(result), + group_id_(group_id), + top_k_(top_k), + tensor_list_(tensor_list), + time_range_list_(time_range_list) { + +} + +BaseTaskPtr SearchVectorTask::Create(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list) { + return std::shared_ptr(new SearchVectorTask(result, group_id, top_k, tensor_list, time_range_list)); +} + +ServerError SearchVectorTask::OnExecute() { + try { + std::vector vec_f; + for(const VecTensor& tensor : tensor_list_.tensor_list) { + vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + } + + engine::QueryResults results; + engine::Status stat = DB()->search(group_id_, (size_t)top_k_, tensor_list_.tensor_list.size(), vec_f.data(), results); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + for(engine::QueryResult& res : results){ + VecSearchResult v_res; + std::string nid_prefix = group_id_ + "_"; + for(auto id : results[0]) { + std::string sid; + std::string nid = nid_prefix + std::to_string(id); + IVecIdMapper::GetInstance()->Get(nid, sid); + v_res.id_list.push_back(sid); + v_res.distance_list.push_back(0.0);//TODO: return distance + } + + result_.result_list.push_back(v_res); + } + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +} +} +} diff --git a/cpp/src/server/VecServiceTask.h b/cpp/src/server/VecServiceTask.h new file mode 100644 index 00000000..cf1b60be --- /dev/null +++ b/cpp/src/server/VecServiceTask.h @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#pragma once + +#include "VecServiceScheduler.h" +#include "utils/Error.h" +#include "db/Types.h" + +#include "thrift/gen-cpp/VectorService_types.h" + +#include +#include + +namespace zilliz { +namespace vecwise { +namespace server { + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class AddGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(int32_t dimension, + const std::string& group_id); + +protected: + AddGroupTask(int32_t dimension, + const std::string& group_id); + + ServerError OnExecute() override; + +private: + int32_t dimension_; + std::string group_id_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class GetGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id, int32_t& dimension); + +protected: + GetGroupTask(const std::string& group_id, int32_t& dimension); + + ServerError OnExecute() override; + + +private: + std::string group_id_; + int32_t& dimension_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class DeleteGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id); + +protected: + DeleteGroupTask(const std::string& group_id); + + ServerError OnExecute() override; + + +private: + std::string group_id_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class AddVectorTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id, + const VecTensorList &tensor_list); + +protected: + AddVectorTask(const std::string& group_id, + const VecTensorList &tensor_list); + + ServerError OnExecute() override; + + +private: + std::string group_id_; + const VecTensorList& tensor_list_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class SearchVectorTask : public BaseTask { +public: + static BaseTaskPtr Create(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list); + +protected: + SearchVectorTask(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list); + + ServerError OnExecute() override; + + +private: + VecSearchResultList& result_; + std::string group_id_; + int64_t top_k_; + const VecTensorList& tensor_list_; + const VecTimeRangeList& time_range_list_; +}; + +} +} +} \ No newline at end of file -- GitLab