提交 03b28a9b 编写于 作者: Y Yu Kun

Merge remote-tracking branch 'upstream/branch-0.4.0' into branch-0.4.0


Former-commit-id: 2afda3c33fc5c4970005991d77937600904c312b
......@@ -15,7 +15,7 @@ container('milvus-build-env') {
dir ("cpp") {
sh "git config --global user.email \"test@zilliz.com\""
sh "git config --global user.name \"test\""
sh "./build.sh -t ${params.BUILD_TYPE} -k ${knowhere_build_dir}"
sh "./build.sh -t ${params.BUILD_TYPE} -k ${knowhere_build_dir} -j"
}
} catch (exc) {
updateGitlabCommitStatus name: 'Build Engine', state: 'failed'
......
......@@ -14,6 +14,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-346 - Add some implementation of scheduler to solve compile error
- MS-348 - Add ResourceFactory Test
- MS-350 - Remove knowhere submodule
- MS-354 - Add task class and interface in scheduler
- MS-355 - Add copy interface in ExcutionEngine
## New Feature
- MS-343 - Implement ResourceMgr
......
......@@ -2556,6 +2556,8 @@ macro(build_gperftools)
${GPERFTOOLS_STATIC_LIB})
endif()
ExternalProject_Add_StepDependencies(gperftools_ep build libunwind_ep)
file(MAKE_DIRECTORY "${GPERFTOOLS_INCLUDE_DIR}")
add_library(gperftools STATIC IMPORTED)
......
......@@ -42,6 +42,10 @@ public:
virtual Status Load(bool to_cache = true) = 0;
virtual Status CopyToGpu(uint64_t device_id) = 0;
virtual Status CopyToCpu() = 0;
virtual Status Merge(const std::string& location) = 0;
virtual Status Search(long n,
......
......@@ -143,6 +143,32 @@ Status ExecutionEngineImpl::Load(bool to_cache) {
return Status::OK();
}
Status ExecutionEngineImpl::CopyToGpu(uint64_t device_id) {
try {
index_ = index_->CopyToGpu(device_id);
ENGINE_LOG_DEBUG << "CPU to GPU" << device_id;
} catch (knowhere::KnowhereException &e) {
ENGINE_LOG_ERROR << e.what();
return Status::Error(e.what());
} catch (std::exception &e) {
return Status::Error(e.what());
}
return Status::OK();
}
Status ExecutionEngineImpl::CopyToCpu() {
try {
index_ = index_->CopyToCpu();
ENGINE_LOG_DEBUG << "GPU to CPU";
} catch (knowhere::KnowhereException &e) {
ENGINE_LOG_ERROR << e.what();
return Status::Error(e.what());
} catch (std::exception &e) {
return Status::Error(e.what());
}
return Status::OK();
}
Status ExecutionEngineImpl::Merge(const std::string &location) {
if (location == location_) {
return Status::Error("Cannot Merge Self");
......
......@@ -18,7 +18,7 @@ namespace engine {
class ExecutionEngineImpl : public ExecutionEngine {
public:
public:
ExecutionEngineImpl(uint16_t dimension,
const std::string &location,
......@@ -42,6 +42,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status Load(bool to_cache) override;
Status CopyToGpu(uint64_t device_id) override;
Status CopyToCpu() override;
Status Merge(const std::string &location) override;
Status Search(long n,
......@@ -56,12 +60,12 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status Init() override;
private:
private:
VecIndexPtr CreatetVecIndex(EngineType type);
VecIndexPtr Load(const std::string &location);
protected:
protected:
VecIndexPtr index_ = nullptr;
EngineType build_type;
EngineType current_type;
......
......@@ -9,6 +9,7 @@
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "db/engine/EngineFactory.h"
#include "scheduler/task/TaskConvert.h"
namespace zilliz {
namespace milvus {
......@@ -85,6 +86,9 @@ TaskScheduler::TaskDispatchWorker() {
return true;
}
// TODO: Put task into Disk-TaskTable
// auto task = TaskConvert(task_ptr);
// DiskResourcePtr->task_table().Put(task)
//execute task
ScheduleTaskPtr next_task = task_ptr->Execute();
if(next_task != nullptr) {
......@@ -98,6 +102,7 @@ TaskScheduler::TaskDispatchWorker() {
bool
TaskScheduler::TaskWorker() {
while(true) {
// TODO: expected blocking forever
ScheduleTaskPtr task_ptr = task_queue_.Take();
if(task_ptr == nullptr) {
SERVER_LOG_INFO << "Stop db task worker thread";
......
......@@ -6,7 +6,7 @@
#pragma once
#include <vector>
#include "Task.h"
#include "task/Task.h"
#include "TaskTable.h"
#include "CacheMgr.h"
......
......@@ -9,7 +9,7 @@
#include <deque>
#include <mutex>
#include "Task.h"
#include "task/SearchTask.h"
namespace zilliz {
......
......@@ -13,7 +13,7 @@
#include <condition_variable>
#include "../TaskTable.h"
#include "../Task.h"
#include "../task/Task.h"
#include "../Cost.h"
#include "Node.h"
#include "Connection.h"
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "DeleteTask.h"
namespace zilliz {
namespace milvus {
namespace engine {
void
XDeleteTask::Load(LoadType type, uint8_t device_id) {
}
void
XDeleteTask::Execute() {
}
}
}
}
......@@ -5,25 +5,22 @@
******************************************************************************/
#pragma once
#include <string>
#include <memory>
#include "Task.h"
namespace zilliz {
namespace milvus {
namespace engine {
// dummy task
class Task {
class XDeleteTask : public Task {
public:
Task(const std::string &name) {}
void
Load(LoadType type, uint8_t device_id) override;
void
Execute() {}
Execute() override;
};
using TaskPtr = std::shared_ptr<Task>;
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "src/metrics/Metrics.h"
#include "src/utils/TimeRecorder.h"
#include "src/db/engine/EngineFactory.h"
#include "src/db/Log.h"
#include "SearchTask.h"
#include <thread>
namespace zilliz {
namespace milvus {
namespace engine {
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
bool
NeedParallelReduce(uint64_t nq, uint64_t topk) {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode &db_config = config.GetConfig(server::CONFIG_DB);
bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, false);
if (!need_parallel) {
return false;
}
return nq * topk >= PARALLEL_REDUCE_THRESHOLD;
}
void
ParallelReduce(std::function<void(size_t, size_t)> &reduce_function, size_t max_index) {
size_t reduce_batch = PARALLEL_REDUCE_BATCH;
auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
if (thread_count > 0) {
reduce_batch = max_index / thread_count + 1;
}
ENGINE_LOG_DEBUG << "use " << thread_count <<
" thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
std::vector<std::shared_ptr<std::thread> > thread_array;
size_t from_index = 0;
while (from_index < max_index) {
size_t to_index = from_index + reduce_batch;
if (to_index > max_index) {
to_index = max_index;
}
auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
thread_array.push_back(reduce_thread);
from_index = to_index;
}
for (auto &thread_ptr : thread_array) {
thread_ptr->join();
}
}
void
CollectFileMetrics(int file_type, size_t file_size) {
switch (file_type) {
case meta::TableFileSchema::RAW:
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
break;
}
default: {
server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
break;
}
}
}
void
CollectDurationMetrics(int index_type, double total_time) {
switch (index_type) {
case meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
default: {
server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
break;
}
}
}
void
XSearchTask::Load(LoadType type, uint8_t device_id) {
server::TimeRecorder rc("");
//step 1: load index
ExecutionEnginePtr index_ptr = EngineFactory::Build(file_->dimension_,
file_->location_,
(EngineType) file_->engine_type_);
try {
index_ptr->Load();
} catch (std::exception &ex) {
//typical error: out of disk space or permition denied
std::string msg = "Failed to load index file: " + std::string(ex.what());
ENGINE_LOG_ERROR << msg;
for (auto &context : search_contexts_) {
context->IndexSearchDone(file_->id_);//mark as done avoid dead lock, even failed
}
return;
}
size_t file_size = index_ptr->PhysicalSize();
std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + std::to_string(file_->file_type_)
+ " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost";
double span = rc.ElapseFromBegin(info);
for (auto &context : search_contexts_) {
context->AccumLoadCost(span);
}
CollectFileMetrics(file_->file_type_, file_size);
//step 2: return search task for later execution
index_id_ = file_->id_;
index_type_ = file_->file_type_;
index_engine_ = index_ptr;
search_contexts_.swap(search_contexts_);
}
void
XSearchTask::Execute() {
if (index_engine_ == nullptr) {
return;
}
ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_ << " with "
<< search_contexts_.size() << " tasks";
server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
auto start_time = METRICS_NOW_TIME;
std::vector<long> output_ids;
std::vector<float> output_distence;
for (auto &context : search_contexts_) {
//step 1: allocate memory
auto inner_k = context->topk();
output_ids.resize(inner_k * context->nq());
output_distence.resize(inner_k * context->nq());
try {
//step 2: search
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
output_ids.data());
double span = rc.RecordSection("do search for context:" + context->Identity());
context->AccumSearchCost(span);
//step 3: cluster result
SearchContext::ResultSet result_set;
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
XSearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
span = rc.RecordSection("cluster result for context:" + context->Identity());
context->AccumReduceCost(span);
//step 4: pick up topk result
XSearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
span = rc.RecordSection("reduce topk for context:" + context->Identity());
context->AccumReduceCost(span);
} catch (std::exception &ex) {
ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
continue;
}
//step 5: notify to send result to client
context->IndexSearchDone(index_id_);
}
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
CollectDurationMetrics(index_type_, total_time);
rc.ElapseFromBegin("totally cost");
}
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
uint64_t topk,
SearchContext::ResultSet &result_set) {
if (output_ids.size() < nq * topk || output_distence.size() < nq * topk) {
std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
" distance array size: " + std::to_string(output_distence.size());
ENGINE_LOG_ERROR << msg;
return Status::Error(msg);
}
result_set.clear();
result_set.resize(nq);
std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
for (auto i = from_index; i < to_index; i++) {
SearchContext::Id2DistanceMap id_distance;
id_distance.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if (output_ids[index] < 0) {
continue;
}
id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
result_set[i] = id_distance;
}
};
if (NeedParallelReduce(nq, topk)) {
ParallelReduce(reduce_worker, nq);
} else {
reduce_worker(0, nq);
}
return Status::OK();
}
Status XSearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk,
bool ascending) {
//Note: the score_src and score_target are already arranged by score in ascending order
if (distance_src.empty()) {
ENGINE_LOG_WARNING << "Empty distance source array";
return Status::OK();
}
if (distance_target.empty()) {
distance_target.swap(distance_src);
return Status::OK();
}
size_t src_count = distance_src.size();
size_t target_count = distance_target.size();
SearchContext::Id2DistanceMap distance_merged;
distance_merged.reserve(topk);
size_t src_index = 0, target_index = 0;
while (true) {
//all score_src items are merged, if score_merged.size() still less than topk
//move items from score_target to score_merged until score_merged.size() equal topk
if (src_index >= src_count) {
for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_target[i]);
}
break;
}
//all score_target items are merged, if score_merged.size() still less than topk
//move items from score_src to score_merged until score_merged.size() equal topk
if (target_index >= target_count) {
for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
distance_merged.push_back(distance_src[i]);
}
break;
}
//compare score,
// if ascending = true, put smallest score to score_merged one by one
// else, put largest score to score_merged one by one
auto &src_pair = distance_src[src_index];
auto &target_pair = distance_target[target_index];
if (ascending) {
if (src_pair.second > target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
} else {
if (src_pair.second < target_pair.second) {
distance_merged.push_back(target_pair);
target_index++;
} else {
distance_merged.push_back(src_pair);
src_index++;
}
}
//score_merged.size() already equal topk
if (distance_merged.size() >= topk) {
break;
}
}
distance_target.swap(distance_merged);
return Status::OK();
}
Status XSearchTask::TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
bool ascending,
SearchContext::ResultSet &result_target) {
if (result_target.empty()) {
result_target.swap(result_src);
return Status::OK();
}
if (result_src.size() != result_target.size()) {
std::string msg = "Invalid result set size";
ENGINE_LOG_ERROR << msg;
return Status::Error(msg);
}
std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
for (size_t i = from_index; i < to_index; i++) {
SearchContext::Id2DistanceMap &score_src = result_src[i];
SearchContext::Id2DistanceMap &score_target = result_target[i];
XSearchTask::MergeResult(score_src, score_target, topk, ascending);
}
};
if (NeedParallelReduce(result_src.size(), topk)) {
ParallelReduce(ReduceWorker, result_src.size());
} else {
ReduceWorker(0, result_src.size());
}
return Status::OK();
}
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "Task.h"
namespace zilliz {
namespace milvus {
namespace engine {
class XSearchTask : public Task {
public:
void
Load(LoadType type, uint8_t device_id) override;
void
Execute() override;
public:
static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
uint64_t topk,
SearchContext::ResultSet &result_set);
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
SearchContext::Id2DistanceMap &distance_target,
uint64_t topk,
bool ascending);
static Status TopkResult(SearchContext::ResultSet &result_src,
uint64_t topk,
bool ascending,
SearchContext::ResultSet &result_target);
public:
TableFileSchemaPtr file_;
size_t index_id_ = 0;
int index_type_ = 0;
ExecutionEnginePtr index_engine_ = nullptr;
bool metric_l2 = true;
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <string>
#include <memory>
#include <src/db/scheduler/context/SearchContext.h>
#include "src/db/scheduler/task/IScheduleTask.h"
namespace zilliz {
namespace milvus {
namespace engine {
enum class LoadType {
DISK2CPU,
CPU2GPU,
GPU2CPU,
};
class Task;
using TaskPtr = std::shared_ptr<Task>;
class Task {
public:
Task() = default;
virtual void
Load(LoadType type, uint8_t device_id) = 0;
virtual void
Execute() = 0;
public:
std::vector<SearchContextPtr> search_contexts_;
ScheduleTaskPtr task_;
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "TaskConvert.h"
namespace zilliz {
namespace milvus {
namespace engine {
TaskPtr
TaskConvert(const ScheduleTaskPtr &schedule_task) {
switch (schedule_task->type()) {
case ScheduleTaskType::kIndexLoad: {
auto load_task = std::static_pointer_cast<IndexLoadTask>(schedule_task);
auto task = std::make_shared<XSearchTask>();
task->file_ = load_task->file_;
task->search_contexts_ = load_task->search_contexts_;
task->task_ = schedule_task;
return task;
}
case ScheduleTaskType::kDelete: {
// TODO: convert to delete task
return nullptr;
}
default: {
// TODO: unexpected !!!
return nullptr;
}
}
}
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "src/db/scheduler/task/IndexLoadTask.h"
#include "Task.h"
#include "SearchTask.h"
namespace zilliz {
namespace milvus {
namespace engine {
TaskPtr
TaskConvert(const ScheduleTaskPtr &schedule_task);
}
}
}
......@@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() {
return type;
}
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
//if (auto new_type = GetGpuIndexType(type)) {
// auto device_index = index_->CopyToGpu(device_id);
// return std::make_shared<VecIndexImpl>(device_index, new_type);
//}
//return nullptr;
// TODO(linxj): update type
auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg);
return std::make_shared<VecIndexImpl>(gpu_index, type);
}
// TODO(linxj): rename copytocpu => copygputocpu
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg);
return std::make_shared<VecIndexImpl>(cpu_index, type);
}
float *BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); }
......
......@@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex {
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr CopyToCpu(const Config &cfg) override;
IndexType GetType() override;
int64_t Dimension() override;
int64_t Count() override;
......
......@@ -35,6 +35,9 @@ enum class IndexType {
NSG_MIX,
};
class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex {
public:
virtual server::KnowhereError BuildAll(const long &nb,
......@@ -55,6 +58,11 @@ class VecIndex {
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToGpu(const int64_t& device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0;
......@@ -66,8 +74,6 @@ class VecIndex {
virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
};
using VecIndexPtr = std::shared_ptr<VecIndex>;
extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr read_index(const std::string &location);
......
......@@ -3,10 +3,29 @@
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_srcs)
aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper/knowhere knowhere_src)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler scheduler_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/context scheduler_context_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/task scheduler_task_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/resource scheduler_resource_srcs)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/task scheduler_task_srcs)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler scheduler_srcs)
aux_source_directory(./ test_srcs)
set(util_files ${MILVUS_ENGINE_SRC}/utils/ValidationUtil.cpp)
set(db_scheduler_srcs
${scheduler_files}
${scheduler_context_files}
${scheduler_task_files})
include_directories(/usr/local/cuda/include)
link_directories("/usr/local/cuda/lib64")
......@@ -17,8 +36,20 @@ include_directories(/usr/include/mysql)
set(scheduler_test_src
${unittest_srcs}
${scheduler_resource_srcs}
${scheduler_task_srcs}
${scheduler_srcs}
${test_srcs}
${config_files}
${cache_srcs}
${db_main_files}
${db_engine_files}
${db_insert_files}
${db_meta_files}
${db_scheduler_srcs}
${wrapper_src}
${knowhere_src}
${util_files}
${require_files}
)
cuda_add_executable(scheduler_test ${scheduler_test_src})
......@@ -31,7 +62,29 @@ set(scheduler_libs
mysqlpp
)
target_link_libraries(scheduler_test ${scheduler_libs} ${unittest_libs})
set(knowhere_libs
knowhere
SPTAGLibStatic
arrow
jemalloc_pic
faiss
openblas
lapack
tbb
cudart
cublas
)
if (${BUILD_FAISS_WITH_MKL} STREQUAL "ON")
set(db_libs ${db_libs} ${MKL_LIBS} ${MKL_LIBS})
else ()
set(db_libs ${db_libs}
lapack
openblas)
endif ()
target_link_libraries(scheduler_test ${scheduler_libs} ${knowhere_libs} ${unittest_libs})
install(TARGETS scheduler_test DESTINATION bin)
......@@ -23,8 +23,8 @@ TEST(normal_test, DISABLED_test1) {
res_mgr->Start();
auto task1 = std::make_shared<Task>("123456789");
auto task2 = std::make_shared<Task>("222222222");
auto task1 = std::make_shared<XSearchTask>();
auto task2 = std::make_shared<XSearchTask>();
if (auto observe = disk.lock()) {
observe->task_table().Put(task1);
observe->task_table().Put(task2);
......
......@@ -43,8 +43,8 @@ protected:
void
SetUp() override {
invalid_task_ = nullptr;
task1_ = std::make_shared<Task>("1");
task2_ = std::make_shared<Task>("2");
task1_ = std::make_shared<XSearchTask>();
task2_ = std::make_shared<XSearchTask>();
empty_table_ = TaskTable();
}
......@@ -85,7 +85,7 @@ protected:
void
SetUp() override {
for (uint64_t i = 0; i < 8; ++i) {
auto task = std::make_shared<Task>(std::to_string(i));
auto task = std::make_shared<XSearchTask>();
table1_.Put(task);
}
......
......@@ -58,6 +58,14 @@ public:
}
engine::VecIndexPtr CopyToGpu(const int64_t &device_id, const engine::Config &cfg) override {
}
engine::VecIndexPtr CopyToCpu(const engine::Config &cfg) override {
}
virtual int64_t Dimension() {
return dimension_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册