提交 75410aed 编写于 作者: G groot

refine scheduler


Former-commit-id: 9b772adf62a9f7f2ae349f3a2420fcecb08af6ce
上级 87d0ed29
...@@ -11,8 +11,9 @@ namespace vecwise { ...@@ -11,8 +11,9 @@ namespace vecwise {
namespace server { namespace server {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group) BaseTask::BaseTask(const std::string& task_group, bool async)
: task_group_(task_group), : task_group_(task_group),
async_(async),
done_(false), done_(false),
error_code_(SERVER_SUCCESS) { error_code_(SERVER_SUCCESS) {
...@@ -73,14 +74,6 @@ void VecServiceScheduler::Stop() { ...@@ -73,14 +74,6 @@ void VecServiceScheduler::Stop() {
stopped_ = true; stopped_ = true;
} }
ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
return PutTaskToQueue(task_ptr);
}
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) { if(task_ptr == nullptr) {
return SERVER_NULL_POINTER; return SERVER_NULL_POINTER;
...@@ -91,7 +84,11 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { ...@@ -91,7 +84,11 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
return err; return err;
} }
return task_ptr->WaitToFinish(); if(task_ptr->IsAsync()) {
return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere
}
return task_ptr->WaitToFinish();//sync execution
} }
namespace { namespace {
......
...@@ -17,7 +17,7 @@ namespace server { ...@@ -17,7 +17,7 @@ namespace server {
class BaseTask { class BaseTask {
protected: protected:
BaseTask(const std::string& task_group); BaseTask(const std::string& task_group, bool async = false);
virtual ~BaseTask(); virtual ~BaseTask();
public: public:
...@@ -27,6 +27,9 @@ public: ...@@ -27,6 +27,9 @@ public:
std::string TaskGroup() const { return task_group_; } std::string TaskGroup() const { return task_group_; }
ServerError ErrorCode() const { return error_code_; } ServerError ErrorCode() const { return error_code_; }
bool IsAsync() const { return async_; }
protected: protected:
virtual ServerError OnExecute() = 0; virtual ServerError OnExecute() = 0;
...@@ -35,6 +38,7 @@ protected: ...@@ -35,6 +38,7 @@ protected:
std::condition_variable finish_cond_; std::condition_variable finish_cond_;
std::string task_group_; std::string task_group_;
bool async_;
bool done_; bool done_;
ServerError error_code_; ServerError error_code_;
}; };
...@@ -54,9 +58,6 @@ public: ...@@ -54,9 +58,6 @@ public:
void Start(); void Start();
void Stop(); void Stop();
//async
ServerError PushTask(const BaseTaskPtr& task_ptr);
//sync
ServerError ExecuteTask(const BaseTaskPtr& task_ptr); ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
protected: protected:
......
...@@ -16,8 +16,8 @@ namespace zilliz { ...@@ -16,8 +16,8 @@ namespace zilliz {
namespace vecwise { namespace vecwise {
namespace server { namespace server {
static const std::string NORMAL_TASK_GROUP = "normal"; static const std::string DQL_TASK_GROUP = "dql";
static const std::string SEARCH_TASK_GROUP = "search"; static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
namespace { namespace {
class DBWrapper { class DBWrapper {
...@@ -53,7 +53,7 @@ namespace { ...@@ -53,7 +53,7 @@ namespace {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddGroupTask::AddGroupTask(int32_t dimension, AddGroupTask::AddGroupTask(int32_t dimension,
const std::string& group_id) const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP), : BaseTask(DDL_DML_TASK_GROUP),
dimension_(dimension), dimension_(dimension),
group_id_(group_id) { group_id_(group_id) {
...@@ -81,7 +81,7 @@ ServerError AddGroupTask::OnExecute() { ...@@ -81,7 +81,7 @@ ServerError AddGroupTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension) GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
: BaseTask(NORMAL_TASK_GROUP), : BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id), group_id_(group_id),
dimension_(dimension) { dimension_(dimension) {
...@@ -111,7 +111,7 @@ ServerError GetGroupTask::OnExecute() { ...@@ -111,7 +111,7 @@ ServerError GetGroupTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteGroupTask::DeleteGroupTask(const std::string& group_id) DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
: BaseTask(NORMAL_TASK_GROUP), : BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id) { group_id_(group_id) {
} }
...@@ -132,7 +132,7 @@ ServerError DeleteGroupTask::OnExecute() { ...@@ -132,7 +132,7 @@ ServerError DeleteGroupTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id, AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id,
const VecTensor &tensor) const VecTensor &tensor)
: BaseTask(NORMAL_TASK_GROUP), : BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id), group_id_(group_id),
tensor_(tensor) { tensor_(tensor) {
...@@ -169,7 +169,7 @@ ServerError AddSingleVectorTask::OnExecute() { ...@@ -169,7 +169,7 @@ ServerError AddSingleVectorTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id, AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
const VecTensorList &tensor_list) const VecTensorList &tensor_list)
: BaseTask(NORMAL_TASK_GROUP), : BaseTask(DDL_DML_TASK_GROUP),
group_id_(group_id), group_id_(group_id),
tensor_list_(tensor_list) { tensor_list_(tensor_list) {
...@@ -233,7 +233,7 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id, ...@@ -233,7 +233,7 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
const VecTensorList& tensor_list, const VecTensorList& tensor_list,
const VecTimeRangeList& time_range_list, const VecTimeRangeList& time_range_list,
VecSearchResultList& result) VecSearchResultList& result)
: BaseTask(SEARCH_TASK_GROUP), : BaseTask(DQL_TASK_GROUP),
group_id_(group_id), group_id_(group_id),
top_k_(top_k), top_k_(top_k),
tensor_list_(tensor_list), tensor_list_(tensor_list),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册