提交 75410aed 编写于 作者: G groot

refine scheduler


Former-commit-id: 9b772adf62a9f7f2ae349f3a2420fcecb08af6ce
上级 87d0ed29
......@@ -11,10 +11,11 @@ namespace vecwise {
namespace server {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group)
: task_group_(task_group),
done_(false),
error_code_(SERVER_SUCCESS) {
BaseTask::BaseTask(const std::string& task_group, bool async)
: task_group_(task_group),
async_(async),
done_(false),
error_code_(SERVER_SUCCESS) {
}
......@@ -73,14 +74,6 @@ void VecServiceScheduler::Stop() {
stopped_ = true;
}
ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
return PutTaskToQueue(task_ptr);
}
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
......@@ -91,7 +84,11 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
return err;
}
return task_ptr->WaitToFinish();
if(task_ptr->IsAsync()) {
return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere
}
return task_ptr->WaitToFinish();//sync execution
}
namespace {
......
......@@ -17,7 +17,7 @@ namespace server {
class BaseTask {
protected:
BaseTask(const std::string& task_group);
BaseTask(const std::string& task_group, bool async = false);
virtual ~BaseTask();
public:
......@@ -27,6 +27,9 @@ public:
std::string TaskGroup() const { return task_group_; }
ServerError ErrorCode() const { return error_code_; }
bool IsAsync() const { return async_; }
protected:
virtual ServerError OnExecute() = 0;
......@@ -35,6 +38,7 @@ protected:
std::condition_variable finish_cond_;
std::string task_group_;
bool async_;
bool done_;
ServerError error_code_;
};
......@@ -54,9 +58,6 @@ public:
void Start();
void Stop();
//async
ServerError PushTask(const BaseTaskPtr& task_ptr);
//sync
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
protected:
......
......@@ -16,8 +16,8 @@ namespace zilliz {
namespace vecwise {
namespace server {
static const std::string NORMAL_TASK_GROUP = "normal";
static const std::string SEARCH_TASK_GROUP = "search";
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
namespace {
class DBWrapper {
......@@ -53,7 +53,7 @@ namespace {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddGroupTask::AddGroupTask(int32_t dimension,
const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
dimension_(dimension),
group_id_(group_id) {
......@@ -81,7 +81,7 @@ ServerError AddGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
dimension_(dimension) {
......@@ -111,7 +111,7 @@ ServerError GetGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id) {
}
......@@ -132,7 +132,7 @@ ServerError DeleteGroupTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id,
const VecTensor &tensor)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
tensor_(tensor) {
......@@ -169,7 +169,7 @@ ServerError AddSingleVectorTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
const VecTensorList &tensor_list)
: BaseTask(NORMAL_TASK_GROUP),
: BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id),
tensor_list_(tensor_list) {
......@@ -233,7 +233,7 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
const VecTensorList& tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result)
: BaseTask(SEARCH_TASK_GROUP),
: BaseTask(DQL_TASK_GROUP),
group_id_(group_id),
top_k_(top_k),
tensor_list_(tensor_list),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册