/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "MegasearchScheduler.h"

#include <exception>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

#include "utils/Log.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"

namespace zilliz {
namespace vecwise {
namespace server {

using namespace megasearch;

namespace {

// Maps server-side error codes to the error codes reported to thrift clients.
// Codes not listed here are reported as if they were SERVER_UNEXPECTED_ERROR (see ExecTask).
const std::map<ServerError, thrift::ErrorCode::type> &ErrorMap() {
    static const std::map<ServerError, thrift::ErrorCode::type> code_map = {
        {SERVER_UNEXPECTED_ERROR, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_NULL_POINTER, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_INVALID_ARGUMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_FILE_NOT_FOUND, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_NOT_IMPLEMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_BLOCKING_QUEUE_EMPTY, thrift::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_GROUP_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
        {SERVER_INVALID_TIME_RANGE, thrift::ErrorCode::ILLEGAL_RANGE},
        {SERVER_INVALID_VECTOR_DIMENSION, thrift::ErrorCode::ILLEGAL_DIMENSION},
    };

    return code_map;
}

// Default human-readable reason for each server error code, used when a task
// does not supply its own error message.
const std::map<ServerError, std::string> &ErrorMessage() {
    static const std::map<ServerError, std::string> msg_map = {
        {SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
        {SERVER_NULL_POINTER, "null pointer error"},
        {SERVER_INVALID_ARGUMENT, "invalid argument"},
        {SERVER_FILE_NOT_FOUND, "file not found"},
        {SERVER_NOT_IMPLEMENT, "not implemented"},
        {SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
        {SERVER_GROUP_NOT_EXIST, "group not exist"},
        {SERVER_INVALID_TIME_RANGE, "invalid time range"},
        {SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
    };

    return msg_map;
}

} // namespace

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group, bool async)
    : task_group_(task_group),
      async_(async),
      done_(false),
      error_code_(SERVER_SUCCESS) {
}

// Blocks until the task has finished executing.
// NOTE(review): if the task object is destroyed without ever having been executed,
// this waits forever - confirm every created task is always handed to the scheduler.
BaseTask::~BaseTask() {
    WaitToFinish();
}

// Runs OnExecute() and publishes its result to any thread blocked in WaitToFinish().
// The task is marked done even if OnExecute() throws (error code SERVER_UNEXPECTED_ERROR);
// the exception is then rethrown so the caller can still log it.
ServerError BaseTask::Execute() {
    ServerError code = SERVER_UNEXPECTED_ERROR;
    std::exception_ptr failure = nullptr;
    try {
        code = OnExecute();
    } catch (...) {
        failure = std::current_exception();
    }

    {
        // The shared state must be modified under the mutex; otherwise a waiter that has
        // just evaluated its predicate could miss the notification and block forever.
        std::lock_guard<std::mutex> lock(finish_mtx_);
        error_code_ = code;
        done_ = true;
    }
    finish_cond_.notify_all();

    if (failure != nullptr) {
        std::rethrow_exception(failure);
    }
    return code;
}

// Blocks the calling thread until Execute() has completed, then returns the task's error code.
ServerError BaseTask::WaitToFinish() {
    std::unique_lock<std::mutex> lock(finish_mtx_);
    finish_cond_.wait(lock, [this] { return done_; });

    return error_code_;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
MegasearchScheduler::MegasearchScheduler()
    : stopped_(false) {
    Start();
}

MegasearchScheduler::~MegasearchScheduler() {
    Stop();
}

// Submits a task to the singleton scheduler.
// Synchronous tasks are waited for; if one fails, a thrift::Exception carrying the mapped
// error code and the task's message (or a default message) is thrown.
// Asynchronous tasks return immediately; the caller must WaitToFinish() on them later.
void MegasearchScheduler::ExecTask(BaseTaskPtr& task_ptr) {
    if (task_ptr == nullptr) {
        return;
    }

    MegasearchScheduler& scheduler = MegasearchScheduler::GetInstance();
    // For a synchronous task ExecuteTask() already blocks until completion and returns the
    // task's error code; for an asynchronous task it only reports the enqueue result.
    const ServerError err = scheduler.ExecuteTask(task_ptr);
    if (task_ptr->IsAsync() || err == SERVER_SUCCESS) {
        return;
    }

    // Unknown codes fall back to the SERVER_UNEXPECTED_ERROR mapping instead of letting
    // std::map::at throw std::out_of_range out of a thrift handler.
    const auto& code_map = ErrorMap();
    const auto code_iter = code_map.find(err);
    const bool known_code = (code_iter != code_map.end());

    thrift::Exception ex;
    ex.__set_code(known_code ? code_iter->second : code_map.at(SERVER_UNEXPECTED_ERROR));

    std::string msg = task_ptr->ErrorMsg();
    if (msg.empty()) {
        const auto& msg_map = ErrorMessage();
        const auto msg_iter = msg_map.find(err);
        msg = (msg_iter != msg_map.end()) ? msg_iter->second : msg_map.at(SERVER_UNEXPECTED_ERROR);
    }
    ex.__set_reason(msg);
    throw ex;
}

// Marks the scheduler as running. Worker threads are created lazily by PutTaskToQueue().
void MegasearchScheduler::Start() {
    std::lock_guard<std::mutex> lock(queue_mtx_);
    stopped_ = false;
}

// Signals every task-group worker to exit after draining the tasks queued before the
// signal, and joins the workers. Safe to call more than once.
void MegasearchScheduler::Stop() {
    // Take ownership of the queues and threads under the lock, so a concurrent
    // PutTaskToQueue() cannot modify the containers while we iterate and join them.
    // Clearing them also means a later submission gets a fresh queue with a live worker
    // instead of an orphaned queue whose thread has already exited.
    decltype(task_groups_) task_groups;
    decltype(execute_threads_) execute_threads;
    {
        std::lock_guard<std::mutex> lock(queue_mtx_);
        if (stopped_ && task_groups_.empty() && execute_threads_.empty()) {
            return;
        }
        task_groups.swap(task_groups_);
        execute_threads.swap(execute_threads_);
    }

    SERVER_LOG_INFO << "Scheduler gonna stop...";

    // A null task is the stop signal recognised by TakeTaskToExecute().
    for (const auto& group : task_groups) {
        if (group.second != nullptr) {
            group.second->Put(nullptr);
        }
    }

    // Join outside the lock so producers are not blocked while workers drain their queues.
    for (const auto& thread : execute_threads) {
        if (thread != nullptr && thread->joinable()) {
            thread->join();
        }
    }

    {
        std::lock_guard<std::mutex> lock(queue_mtx_);
        stopped_ = true;
    }
    SERVER_LOG_INFO << "Scheduler stopped";
}

// Queues a task on its group's worker. Returns the enqueue error if queueing failed;
// for synchronous tasks otherwise blocks and returns the task's own error code.
ServerError MegasearchScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
    if (task_ptr == nullptr) {
        return SERVER_NULL_POINTER;
    }

    ServerError err = PutTaskToQueue(task_ptr);
    if (err != SERVER_SUCCESS) {
        return err;
    }

    if (task_ptr->IsAsync()) {
        return SERVER_SUCCESS; // async execution, caller need to call WaitToFinish at somewhere
    }

    return task_ptr->WaitToFinish(); // sync execution
}

namespace {

// Worker loop for one task group: executes queued tasks in order until a null task arrives.
// Exceptions from a task are logged and never allowed to escape the thread.
void TakeTaskToExecute(TaskQueuePtr task_queue) {
    if (task_queue == nullptr) {
        return;
    }

    while (true) {
        BaseTaskPtr task = task_queue->Take();
        if (task == nullptr) {
            break; // stop the thread
        }

        try {
            ServerError err = task->Execute();
            if (err != SERVER_SUCCESS) {
                SERVER_LOG_ERROR << "Task failed with code: " << err;
            }
        } catch (const std::exception& ex) {
            SERVER_LOG_ERROR << "Task failed to execute: " << ex.what();
        } catch (...) {
            // An exception escaping a thread entry point would call std::terminate.
            SERVER_LOG_ERROR << "Task failed to execute: unknown exception";
        }
    }
}

} // namespace

// Appends the task to its group's queue, creating the queue and its dedicated worker
// thread on first use of a group name.
ServerError MegasearchScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
    std::lock_guard<std::mutex> lock(queue_mtx_);

    const std::string group_name = task_ptr->TaskGroup();
    auto iter = task_groups_.find(group_name);
    if (iter != task_groups_.end()) {
        iter->second->Put(task_ptr);
        return SERVER_SUCCESS;
    }

    // Start the consumer before registering the queue, so a failed thread creation never
    // leaves a registered queue without a worker.
    TaskQueuePtr queue = std::make_shared<TaskQueuePtr::element_type>();
    ThreadPtr thread = std::make_shared<ThreadPtr::element_type>(&TakeTaskToExecute, queue);

    queue->Put(task_ptr);
    task_groups_.insert(std::make_pair(group_name, queue));
    execute_threads_.push_back(thread);
    SERVER_LOG_INFO << "Create new thread for task group: " << group_name;

    return SERVER_SUCCESS;
}

} // namespace server
} // namespace vecwise
} // namespace zilliz