提交 a79017ef 编写于 作者: W wxyu

MS-389 Add clone interface in Task


Former-commit-id: bf4681eec6cc6e6b21087b82e5bb5676ec3418d9
上级 3b31d63b
......@@ -34,6 +34,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-380 - Update resource loader and executor, work util all finished
- MS-383 - Modify condition variable usage in scheduler
- MS-384 - Add global instance of ResourceMgr and Scheduler
- MS-389 - Add clone interface in Task
## New Feature
- MS-343 - Implement ResourceMgr
......
......@@ -20,6 +20,11 @@ XDeleteTask::Execute() {
}
TaskPtr
XDeleteTask::Clone() {
return nullptr;
}
}
}
}
......@@ -19,6 +19,9 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
};
}
......
......@@ -99,16 +99,27 @@ CollectDurationMetrics(int index_type, double total_time) {
}
}
XSearchTask::XSearchTask(TableFileSchemaPtr file) : file_(file) {
index_engine_ = EngineFactory::Build(file_->dimension_,
file_->location_,
(EngineType) file_->engine_type_);
}
void
XSearchTask::Load(LoadType type, uint8_t device_id) {
server::TimeRecorder rc("");
//step 1: load index
ExecutionEnginePtr index_ptr = EngineFactory::Build(file_->dimension_,
file_->location_,
(EngineType) file_->engine_type_);
try {
index_ptr->Load();
if (type == LoadType::DISK2CPU) {
index_engine_->Load();
} else if (type == LoadType::CPU2GPU) {
index_engine_->Load();
index_engine_->CopyToGpu(device_id);
} else if (type == LoadType::GPU2CPU) {
index_engine_->CopyToCpu();
} else {
// TODO: exception
}
} catch (std::exception &ex) {
//typical error: out of disk space or permition denied
std::string msg = "Failed to load index file: " + std::string(ex.what());
......@@ -121,7 +132,7 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
return;
}
size_t file_size = index_ptr->PhysicalSize();
size_t file_size = index_engine_->PhysicalSize();
std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + std::to_string(file_->file_type_)
+ " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost";
......@@ -135,7 +146,6 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
//step 2: return search task for later execution
index_id_ = file_->id_;
index_type_ = file_->file_type_;
index_engine_ = index_ptr;
search_contexts_.swap(search_contexts_);
}
......@@ -157,12 +167,13 @@ XSearchTask::Execute() {
for (auto &context : search_contexts_) {
//step 1: allocate memory
auto inner_k = context->topk();
auto nprobe = context->nprobe();
output_ids.resize(inner_k * context->nq());
output_distence.resize(inner_k * context->nq());
try {
//step 2: search
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(),
output_ids.data());
double span = rc.RecordSection("do search for context:" + context->Identity());
......@@ -199,6 +210,16 @@ XSearchTask::Execute() {
rc.ElapseFromBegin("totally cost");
}
TaskPtr
XSearchTask::Clone() {
auto ret = std::make_shared<XSearchTask>(file_);
ret->index_id_ = index_id_;
ret->index_engine_ = index_engine_->Clone();
ret->search_contexts_ = search_contexts_;
ret->metric_l2 = metric_l2;
return ret;
}
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
......@@ -343,6 +364,7 @@ Status XSearchTask::TopkResult(SearchContext::ResultSet &result_src,
return Status::OK();
}
}
}
}
......@@ -14,12 +14,18 @@ namespace engine {
class XSearchTask : public Task {
public:
explicit
XSearchTask(TableFileSchemaPtr file);
void
Load(LoadType type, uint8_t device_id) override;
void
Execute() override;
TaskPtr
Clone() override;
public:
static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
......
......@@ -35,6 +35,9 @@ public:
virtual void
Execute() = 0;
virtual TaskPtr
Clone() = 0;
public:
std::vector<SearchContextPtr> search_contexts_;
ScheduleTaskPtr task_;
......
......@@ -16,8 +16,7 @@ TaskConvert(const ScheduleTaskPtr &schedule_task) {
switch (schedule_task->type()) {
case ScheduleTaskType::kIndexLoad: {
auto load_task = std::static_pointer_cast<IndexLoadTask>(schedule_task);
auto task = std::make_shared<XSearchTask>();
task->file_ = load_task->file_;
auto task = std::make_shared<XSearchTask>(load_task->file_);
task->search_contexts_ = load_task->search_contexts_;
task->task_ = schedule_task;
return task;
......
#include "scheduler/TaskTable.h"
#include "scheduler/Cost.h"
#include <gtest/gtest.h>
#include "scheduler/task/TestTask.h"
using namespace zilliz::milvus::engine;
......@@ -10,7 +11,7 @@ protected:
void
SetUp() override {
for (uint64_t i = 0; i < 8; ++i) {
auto task = std::make_shared<XSearchTask>();
auto task = std::make_shared<TestTask>();
table_.Put(task);
}
table_.Get(0)->state = TaskTableItemState::INVALID;
......
#include "scheduler/TaskTable.h"
#include "scheduler/task/TestTask.h"
#include <gtest/gtest.h>
......@@ -43,8 +44,8 @@ protected:
void
SetUp() override {
invalid_task_ = nullptr;
task1_ = std::make_shared<XSearchTask>();
task2_ = std::make_shared<XSearchTask>();
task1_ = std::make_shared<TestTask>();
task2_ = std::make_shared<TestTask>();
}
TaskPtr invalid_task_;
......@@ -83,7 +84,7 @@ protected:
void
SetUp() override {
for (uint64_t i = 0; i < 8; ++i) {
auto task = std::make_shared<XSearchTask>();
auto task = std::make_shared<TestTask>();
table1_.Put(task);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册