Commit b03aafdc, authored by groot

implement scheduler


Former-commit-id: 1be5a738138a626ddb4a7412e798c74debbc4c3a
Parent: 11cb43ac
@@ -40,8 +40,11 @@ rm -rf ./cmake_build
 mkdir cmake_build
 cd cmake_build
+CUDA_COMPILER=/usr/local/cuda/bin/nvcc
 CMAKE_CMD="cmake -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \
 -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
+-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \
 $@ ../"
 echo ${CMAKE_CMD}
...
@@ -9,6 +9,7 @@
 #include "VecIdMapper.h"
 #include "utils/Log.h"
 #include "utils/CommonUtil.h"
+#include "utils/TimeRecorder.h"
 #include "db/DB.h"
 #include "db/Env.h"
@@ -73,9 +74,7 @@ VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tens
     SERVER_LOG_INFO << "add_vector() called";
     SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size();
-    VecTensorList tensor_list;
-    tensor_list.tensor_list.push_back(tensor);
-    BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list);
+    BaseTaskPtr task_ptr = AddSingleVectorTask::Create(group_id, tensor);
     VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
     scheduler.ExecuteTask(task_ptr);
@@ -88,10 +87,11 @@ VecServiceHandler::add_vector_batch(const std::string &group_id,
     SERVER_LOG_INFO << "add_vector_batch() called";
     SERVER_LOG_TRACE << "group_id = " << group_id << ", vector list size = "
                      << tensor_list.tensor_list.size();
-    BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list);
+    TimeRecorder rc("Add VECTOR BATCH");
+    BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, tensor_list);
     VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
     scheduler.ExecuteTask(task_ptr);
+    rc.Elapse("DONE!");
     SERVER_LOG_INFO << "add_vector_batch() finished";
 }
@@ -108,28 +108,17 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
                      << ", vector size = " << tensor.tensor.size()
                      << ", time range list size = " << time_range_list.range_list.size();
-    try {
-        engine::QueryResults results;
-        std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
-        engine::Status stat = db_->search(group_id, (size_t)top_k, 1, vec_f.data(), results);
-        if(!stat.ok()) {
-            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
-        } else {
-            if(!results.empty()) {
-                std::string nid_prefix = group_id + "_";
-                for(auto id : results[0]) {
-                    std::string sid;
-                    std::string nid = nid_prefix + std::to_string(id);
-                    IVecIdMapper::GetInstance()->Get(nid, sid);
-                    _return.id_list.push_back(sid);
-                    _return.distance_list.push_back(0.0);//TODO: return distance
-                }
-            }
-        }
-    } catch (std::exception& ex) {
-        SERVER_LOG_ERROR << ex.what();
+    VecTensorList tensor_list;
+    tensor_list.tensor_list.push_back(tensor);
+    VecSearchResultList result;
+    BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result);
+    VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
+    scheduler.ExecuteTask(task_ptr);
+    if(!result.result_list.empty()) {
+        _return = result.result_list[0];
+    } else {
+        SERVER_LOG_ERROR << "No search result returned";
     }
     SERVER_LOG_INFO << "search_vector() finished";
@@ -146,36 +135,9 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
                      << ", vector list size = " << tensor_list.tensor_list.size()
                      << ", time range list size = " << time_range_list.range_list.size();
-    try {
-        std::vector<float> vec_f;
-        for(const VecTensor& tensor : tensor_list.tensor_list) {
-            vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
-        }
-        engine::QueryResults results;
-        engine::Status stat = db_->search(group_id, (size_t)top_k, tensor_list.tensor_list.size(), vec_f.data(), results);
-        if(!stat.ok()) {
-            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
-        } else {
-            for(engine::QueryResult& res : results){
-                VecSearchResult v_res;
-                std::string nid_prefix = group_id + "_";
-                for(auto id : results[0]) {
-                    std::string sid;
-                    std::string nid = nid_prefix + std::to_string(id);
-                    IVecIdMapper::GetInstance()->Get(nid, sid);
-                    v_res.id_list.push_back(sid);
-                    v_res.distance_list.push_back(0.0);//TODO: return distance
-                }
-                _return.result_list.push_back(v_res);
-            }
-        }
-    } catch (std::exception& ex) {
-        SERVER_LOG_ERROR << ex.what();
-    }
+    BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return);
+    VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
+    scheduler.ExecuteTask(task_ptr);
     SERVER_LOG_INFO << "search_vector_batch() finished";
 }
...
@@ -25,6 +25,7 @@ BaseTask::~BaseTask() {
 ServerError BaseTask::Execute() {
     error_code_ = OnExecute();
     done_ = true;
+    finish_cond_.notify_all();
     return error_code_;
 }
@@ -72,7 +73,7 @@ void VecServiceScheduler::Stop() {
     stopped_ = true;
 }
-ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
+ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) {
     if(task_ptr == nullptr) {
         return SERVER_NULL_POINTER;
     }
@@ -80,6 +81,19 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
     return PutTaskToQueue(task_ptr);
 }
+ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
+    if(task_ptr == nullptr) {
+        return SERVER_NULL_POINTER;
+    }
+
+    ServerError err = PutTaskToQueue(task_ptr);
+    if(err != SERVER_SUCCESS) {
+        return err;
+    }
+
+    return task_ptr->WaitToFinish();
+}
 namespace {
     void TakeTaskToExecute(TaskQueuePtr task_queue) {
         if(task_queue == nullptr) {
@@ -120,6 +134,8 @@ ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
         execute_threads_.push_back(thread);
         SERVER_LOG_INFO << "Create new thread for task group: " << group_name;
     }
+
+    return SERVER_SUCCESS;
 }
...
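Note: the scheduler diff above makes ExecuteTask a blocking call. The task is queued via PutTaskToQueue and the caller waits for a worker thread to run it; BaseTask::Execute sets done_ and fires finish_cond_.notify_all(), which is what WaitToFinish is expected to block on. The WaitToFinish body itself is not part of this commit, so the following is only a minimal standalone sketch of that wait/notify pattern; the mutex name finish_mtx_ and the SketchTask class are assumptions for illustration, not the project's code.

```cpp
// Minimal sketch of a "wait until the worker finishes" pattern built from
// done_ + finish_cond_, as the diff implies. Names marked below are assumed.
#include <condition_variable>
#include <mutex>

class SketchTask {
public:
    // Called by the worker thread after the task body runs,
    // mirroring what BaseTask::Execute() does in the diff.
    void MarkDone() {
        {
            std::lock_guard<std::mutex> lock(finish_mtx_);  // finish_mtx_ is an assumed name
            done_ = true;
        }
        finish_cond_.notify_all();  // wake every caller blocked in WaitToFinish()
    }

    // Called on the synchronous ExecuteTask() path; returns once the task ran.
    void WaitToFinish() {
        std::unique_lock<std::mutex> lock(finish_mtx_);
        finish_cond_.wait(lock, [this] { return done_; });  // guards against spurious wakeups
    }

private:
    bool done_ = false;
    std::mutex finish_mtx_;
    std::condition_variable finish_cond_;
};
```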
@@ -54,6 +54,9 @@ public:
     void Start();
     void Stop();
+    //async
+    ServerError PushTask(const BaseTaskPtr& task_ptr);
+    //sync
     ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
 protected:
...
@@ -8,6 +8,7 @@
 #include "VecIdMapper.h"
 #include "utils/CommonUtil.h"
 #include "utils/Log.h"
+#include "utils/TimeRecorder.h"
 #include "db/DB.h"
 #include "db/Env.h"
@@ -16,6 +17,7 @@ namespace vecwise {
 namespace server {
 static const std::string NORMAL_TASK_GROUP = "normal";
+static const std::string SEARCH_TASK_GROUP = "search";
 namespace {
     class DBWrapper {
@@ -128,7 +130,44 @@ ServerError DeleteGroupTask::OnExecute() {
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-AddVectorTask::AddVectorTask(const std::string& group_id,
+AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id,
+                                         const VecTensor &tensor)
+    : BaseTask(NORMAL_TASK_GROUP),
+      group_id_(group_id),
+      tensor_(tensor) {
+}
+
+BaseTaskPtr AddSingleVectorTask::Create(const std::string& group_id,
+                                        const VecTensor &tensor) {
+    return std::shared_ptr<BaseTask>(new AddSingleVectorTask(group_id, tensor));
+}
+
+ServerError AddSingleVectorTask::OnExecute() {
+    try {
+        engine::IDNumbers vector_ids;
+        std::vector<float> vec_f(tensor_.tensor.begin(), tensor_.tensor.end());
+        engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
+        if(!stat.ok()) {
+            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
+        } else {
+            if(vector_ids.empty()) {
+                SERVER_LOG_ERROR << "Vector ID not returned";
+            } else {
+                std::string nid = group_id_ + "_" + std::to_string(vector_ids[0]);
+                IVecIdMapper::GetInstance()->Put(nid, tensor_.uid);
+                SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", sid = " << tensor_.uid;
+            }
+        }
+    } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << ex.what();
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                        const VecTensorList &tensor_list)
     : BaseTask(NORMAL_TASK_GROUP),
       group_id_(group_id),
@@ -136,31 +175,50 @@ AddVectorTask::AddVectorTask(const std::string& group_id,
 }
-BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
+BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                        const VecTensorList &tensor_list) {
-    return std::shared_ptr<BaseTask>(new AddVectorTask(group_id, tensor_list));
+    return std::shared_ptr<BaseTask>(new AddBatchVectorTask(group_id, tensor_list));
 }
-ServerError AddVectorTask::OnExecute() {
+ServerError AddBatchVectorTask::OnExecute() {
     try {
+        TimeRecorder rc("Add vector batch");
+
+        engine::meta::GroupSchema group_info;
+        group_info.group_id = group_id_;
+        engine::Status stat = DB()->get_group(group_info);
+        if(!stat.ok()) {
+            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
+            return SERVER_UNEXPECTED_ERROR;
+        }
+
         std::vector<float> vec_f;
+        vec_f.reserve(tensor_list_.tensor_list.size()*group_info.dimension*4);
         for(const VecTensor& tensor : tensor_list_.tensor_list) {
+            if(tensor.tensor.size() != group_info.dimension) {
+                SERVER_LOG_ERROR << "Invalid vector data size: " << tensor.tensor.size()
+                                 << " vs. group dimension:" << group_info.dimension;
+                return SERVER_UNEXPECTED_ERROR;
+            }
             vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
         }
+        rc.Record("prepare vectors data");
         engine::IDNumbers vector_ids;
-        engine::Status stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids);
+        stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids);
+        rc.Record("add vectors to engine");
         if(!stat.ok()) {
             SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
         } else {
-            if(vector_ids.size() != tensor_list_.tensor_list.size()) {
+            if(vector_ids.size() < tensor_list_.tensor_list.size()) {
                SERVER_LOG_ERROR << "Vector ID not returned";
             } else {
                 std::string nid_prefix = group_id_ + "_";
-                for(size_t i = 0; i < vector_ids.size(); i++) {
+                for(size_t i = 0; i < tensor_list_.tensor_list.size(); i++) {
                     std::string nid = nid_prefix + std::to_string(vector_ids[i]);
                     IVecIdMapper::GetInstance()->Put(nid, tensor_list_.tensor_list[i].uid);
                 }
+                rc.Record("build id mapping");
             }
         }
@@ -170,26 +228,26 @@ ServerError AddVectorTask::OnExecute() {
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-SearchVectorTask::SearchVectorTask(VecSearchResultList& result,
-                                   const std::string& group_id,
+SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                    const int64_t top_k,
                                    const VecTensorList& tensor_list,
-                                   const VecTimeRangeList& time_range_list)
-    : BaseTask(NORMAL_TASK_GROUP),
-      result_(result),
+                                   const VecTimeRangeList& time_range_list,
+                                   VecSearchResultList& result)
+    : BaseTask(SEARCH_TASK_GROUP),
       group_id_(group_id),
       top_k_(top_k),
       tensor_list_(tensor_list),
-      time_range_list_(time_range_list) {
+      time_range_list_(time_range_list),
+      result_(result) {
 }
-BaseTaskPtr SearchVectorTask::Create(VecSearchResultList& result,
-                                     const std::string& group_id,
+BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                      const int64_t top_k,
                                      const VecTensorList& tensor_list,
-                                     const VecTimeRangeList& time_range_list) {
-    return std::shared_ptr<BaseTask>(new SearchVectorTask(result, group_id, top_k, tensor_list, time_range_list));
+                                     const VecTimeRangeList& time_range_list,
+                                     VecSearchResultList& result) {
+    return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result));
 }
 ServerError SearchVectorTask::OnExecute() {
@@ -213,6 +271,9 @@ ServerError SearchVectorTask::OnExecute() {
                 IVecIdMapper::GetInstance()->Get(nid, sid);
                 v_res.id_list.push_back(sid);
                 v_res.distance_list.push_back(0.0);//TODO: return distance
+
+                SERVER_LOG_TRACE << "nid = " << nid << ", string id = " << sid;
             }
             result_.result_list.push_back(v_res);
...
@@ -67,13 +67,31 @@ private:
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-class AddVectorTask : public BaseTask {
+class AddSingleVectorTask : public BaseTask {
+public:
+    static BaseTaskPtr Create(const std::string& group_id,
+                              const VecTensor &tensor);
+
+protected:
+    AddSingleVectorTask(const std::string& group_id,
+                        const VecTensor &tensor);
+
+    ServerError OnExecute() override;
+
+private:
+    std::string group_id_;
+    const VecTensor& tensor_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+class AddBatchVectorTask : public BaseTask {
 public:
     static BaseTaskPtr Create(const std::string& group_id,
                               const VecTensorList &tensor_list);
 protected:
-    AddVectorTask(const std::string& group_id,
+    AddBatchVectorTask(const std::string& group_id,
                   const VecTensorList &tensor_list);
     ServerError OnExecute() override;
@@ -87,28 +105,28 @@ private:
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 class SearchVectorTask : public BaseTask {
 public:
-    static BaseTaskPtr Create(VecSearchResultList& result,
-                              const std::string& group_id,
+    static BaseTaskPtr Create(const std::string& group_id,
                               const int64_t top_k,
                               const VecTensorList& tensor_list,
-                              const VecTimeRangeList& time_range_list);
+                              const VecTimeRangeList& time_range_list,
+                              VecSearchResultList& result);
 protected:
-    SearchVectorTask(VecSearchResultList& result,
-                     const std::string& group_id,
+    SearchVectorTask(const std::string& group_id,
                      const int64_t top_k,
                      const VecTensorList& tensor_list,
-                     const VecTimeRangeList& time_range_list);
+                     const VecTimeRangeList& time_range_list,
+                     VecSearchResultList& result);
     ServerError OnExecute() override;
 private:
-    VecSearchResultList& result_;
     std::string group_id_;
     int64_t top_k_;
     const VecTensorList& tensor_list_;
     const VecTimeRangeList& time_range_list_;
+    VecSearchResultList& result_;
 };
 }
...
@@ -3,6 +3,7 @@
  * Unauthorized copying of this file, via any medium is strictly prohibited.
  * Proprietary and confidential.
  ******************************************************************************/
+#include <utils/TimeRecorder.h>
 #include "ClientApp.h"
 #include "ClientSession.h"
 #include "server/ServerConfig.h"
@@ -37,21 +38,44 @@ void ClientApp::Run(const std::string &config_file) {
         group.index_type = 0;
         session.interface()->add_group(group);
-        //add vectors
-        for(int64_t k = 0; k < 10000; k++) {
-            VecTensor tensor;
-            for(int32_t i = 0; i < dim; i++) {
-                tensor.tensor.push_back((double)(i + k));
-            }
-            tensor.uid = "vec_" + std::to_string(k);
-            session.interface()->add_vector(group.id, tensor);
-            CLIENT_LOG_INFO << "add vector no." << k;
+        const int64_t count = 500;
+
+        //add vectors one by one
+        {
+            server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
+            for (int64_t k = 0; k < count; k++) {
+                VecTensor tensor;
+                for (int32_t i = 0; i < dim; i++) {
+                    tensor.tensor.push_back((double) (i + k));
+                }
+                tensor.uid = "vec_" + std::to_string(k);
+
+                session.interface()->add_vector(group.id, tensor);
+                CLIENT_LOG_INFO << "add vector no." << k;
+            }
+            rc.Elapse("done!");
+        }
+
+        //add vectors in one batch
+        {
+            server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
+            VecTensorList vec_list;
+            for (int64_t k = 0; k < count; k++) {
+                VecTensor tensor;
+                for (int32_t i = 0; i < dim; i++) {
+                    tensor.tensor.push_back((double) (i + k));
+                }
+                tensor.uid = "vec_" + std::to_string(k);
+                vec_list.tensor_list.push_back(tensor);
+            }
+            session.interface()->add_vector_batch(group.id, vec_list);
+            rc.Elapse("done!");
         }
         //search vector
         {
+            server::TimeRecorder rc("Search top_k");
             VecTensor tensor;
             for (int32_t i = 0; i < dim; i++) {
                 tensor.tensor.push_back((double) (i + 100));
@@ -65,6 +89,7 @@ void ClientApp::Run(const std::string &config_file) {
             for(auto id : res.id_list) {
                 std::cout << id << std::endl;
             }
+            rc.Elapse("done!");
         }
     } catch (std::exception& ex) {
...