提交 37b519ff 编写于 作者: J jinhai

Merge branch 'branch-0.5.0' into 'branch-0.5.0'

MS-609 Update task construct function

See merge request megasearch/milvus!658

Former-commit-id: efa7467a04e40727dd47ce8db8a85a6169c774ea
...@@ -23,6 +23,8 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -23,6 +23,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-574 - Milvus configuration refactor - MS-574 - Milvus configuration refactor
- MS-578 - Make sure milvus5.0 don't crack 0.3.1 data - MS-578 - Make sure milvus5.0 don't crack 0.3.1 data
- MS-585 - Update namespace in scheduler - MS-585 - Update namespace in scheduler
- MS-608 - Update TODO names
- MS-609 - Update task construct function
## New Feature ## New Feature
......
...@@ -74,7 +74,7 @@ ResourceMgr::Connect(const std::string& name1, const std::string& name2, Connect ...@@ -74,7 +74,7 @@ ResourceMgr::Connect(const std::string& name1, const std::string& name2, Connect
auto res2 = GetResource(name2); auto res2 = GetResource(name2);
if (res1 && res2) { if (res1 && res2) {
res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection); res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection);
// TODO(wxy): enable when task balance supported // TODO(wxyu): enable when task balance supported
// res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection); // res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection);
return true; return true;
} }
......
...@@ -64,7 +64,7 @@ class ResourceMgr { ...@@ -64,7 +64,7 @@ class ResourceMgr {
return disk_resources_; return disk_resources_;
} }
// TODO(wxy): why return shared pointer // TODO(wxyu): why return shared pointer
inline std::vector<ResourcePtr> inline std::vector<ResourcePtr>
GetAllResources() { GetAllResources() {
return resources_; return resources_;
...@@ -89,7 +89,7 @@ class ResourceMgr { ...@@ -89,7 +89,7 @@ class ResourceMgr {
GetNumGpuResource() const; GetNumGpuResource() const;
public: public:
// TODO(wxy): add stats interface(low) // TODO(wxyu): add stats interface(low)
public: public:
/******** Utility Functions ********/ /******** Utility Functions ********/
......
...@@ -146,7 +146,7 @@ load_advance_config() { ...@@ -146,7 +146,7 @@ load_advance_config() {
// } // }
// } catch (const char *msg) { // } catch (const char *msg) {
// SERVER_LOG_ERROR << msg; // SERVER_LOG_ERROR << msg;
// // TODO(wxy): throw exception instead // // TODO(wxyu): throw exception instead
// exit(-1); // exit(-1);
//// throw std::exception(); //// throw std::exception();
// } // }
......
...@@ -92,7 +92,7 @@ Scheduler::Process(const EventPtr& event) { ...@@ -92,7 +92,7 @@ Scheduler::Process(const EventPtr& event) {
process_event(event); process_event(event);
} }
// TODO(wxy): refactor the function // TODO(wxyu): refactor the function
void void
Scheduler::OnLoadCompleted(const EventPtr& event) { Scheduler::OnLoadCompleted(const EventPtr& event) {
auto load_completed_event = std::static_pointer_cast<LoadCompletedEvent>(event); auto load_completed_event = std::static_pointer_cast<LoadCompletedEvent>(event);
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
// TODO(wxy): refactor, not friendly to unittest, logical in framework code // TODO(wxyu): refactor, not friendly to unittest, logical in framework code
class Scheduler { class Scheduler {
public: public:
explicit Scheduler(ResourceMgrWPtr res_mgr); explicit Scheduler(ResourceMgrWPtr res_mgr);
......
...@@ -32,7 +32,7 @@ TaskCreator::Create(const JobPtr& job) { ...@@ -32,7 +32,7 @@ TaskCreator::Create(const JobPtr& job) {
return Create(std::static_pointer_cast<DeleteJob>(job)); return Create(std::static_pointer_cast<DeleteJob>(job));
} }
default: { default: {
// TODO(wxy): error // TODO(wxyu): error
return std::vector<TaskPtr>(); return std::vector<TaskPtr>();
} }
} }
...@@ -42,8 +42,8 @@ std::vector<TaskPtr> ...@@ -42,8 +42,8 @@ std::vector<TaskPtr>
TaskCreator::Create(const SearchJobPtr& job) { TaskCreator::Create(const SearchJobPtr& job) {
std::vector<TaskPtr> tasks; std::vector<TaskPtr> tasks;
for (auto& index_file : job->index_files()) { for (auto& index_file : job->index_files()) {
auto task = std::make_shared<XSearchTask>(index_file.second); auto label = std::make_shared<DefaultLabel>();
task->label() = std::make_shared<DefaultLabel>(); auto task = std::make_shared<XSearchTask>(index_file.second, label);
task->job_ = job; task->job_ = job;
tasks.emplace_back(task); tasks.emplace_back(task);
} }
...@@ -54,8 +54,8 @@ TaskCreator::Create(const SearchJobPtr& job) { ...@@ -54,8 +54,8 @@ TaskCreator::Create(const SearchJobPtr& job) {
std::vector<TaskPtr> std::vector<TaskPtr>
TaskCreator::Create(const DeleteJobPtr& job) { TaskCreator::Create(const DeleteJobPtr& job) {
std::vector<TaskPtr> tasks; std::vector<TaskPtr> tasks;
auto task = std::make_shared<XDeleteTask>(job); auto label = std::make_shared<BroadcastLabel>();
task->label() = std::make_shared<BroadcastLabel>(); auto task = std::make_shared<XDeleteTask>(job, label);
task->job_ = job; task->job_ = job;
tasks.emplace_back(task); tasks.emplace_back(task);
......
...@@ -125,7 +125,7 @@ class TaskTable { ...@@ -125,7 +125,7 @@ class TaskTable {
Get(uint64_t index); Get(uint64_t index);
/* /*
* TODO(wxy): BIG GC * TODO(wxyu): BIG GC
* Remove sequence task which is DONE or MOVED from front; * Remove sequence task which is DONE or MOVED from front;
* Called by ? * Called by ?
*/ */
...@@ -173,7 +173,7 @@ class TaskTable { ...@@ -173,7 +173,7 @@ class TaskTable {
public: public:
/******** Action ********/ /******** Action ********/
// TODO(wxy): bool to Status // TODO(wxyu): bool to Status
/* /*
* Load a task; * Load a task;
* Set state loading; * Set state loading;
......
...@@ -82,7 +82,7 @@ Action::PushTaskToNeighbourRandomly(const TaskPtr& task, const ResourcePtr& self ...@@ -82,7 +82,7 @@ Action::PushTaskToNeighbourRandomly(const TaskPtr& task, const ResourcePtr& self
} }
} else { } else {
// TODO(wxy): process // TODO(wxyu): process
} }
} }
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
// TODO(wxy): Storage, Route, Executor // TODO(wxyu): Storage, Route, Executor
enum class ResourceType { enum class ResourceType {
DISK = 0, DISK = 0,
CPU = 1, CPU = 1,
...@@ -114,11 +114,11 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> { ...@@ -114,11 +114,11 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
return enable_executor_; return enable_executor_;
} }
// TODO(wxy): const // TODO(wxyu): const
uint64_t uint64_t
NumOfTaskToExec(); NumOfTaskToExec();
// TODO(wxy): need double ? // TODO(wxyu): need double ?
inline uint64_t inline uint64_t
TaskAvgCost() const { TaskAvgCost() const {
return total_cost_ / total_task_; return total_cost_ / total_task_;
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job) XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label)
: Task(TaskType::DeleteTask), delete_job_(delete_job) { : Task(TaskType::DeleteTask, std::move(label)), delete_job_(delete_job) {
} }
void void
......
...@@ -25,7 +25,7 @@ namespace scheduler { ...@@ -25,7 +25,7 @@ namespace scheduler {
class XDeleteTask : public Task { class XDeleteTask : public Task {
public: public:
explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job); explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label);
void void
Load(LoadType type, uint8_t device_id) override; Load(LoadType type, uint8_t device_id) override;
......
...@@ -95,7 +95,8 @@ CollectFileMetrics(int file_type, size_t file_size) { ...@@ -95,7 +95,8 @@ CollectFileMetrics(int file_type, size_t file_size) {
} }
} }
XSearchTask::XSearchTask(TableFileSchemaPtr file) : Task(TaskType::SearchTask), file_(file) { XSearchTask::XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label)
: Task(TaskType::SearchTask, std::move(label)), file_(file) {
if (file_) { if (file_) {
if (file_->metric_type_ != static_cast<int>(MetricType::L2)) { if (file_->metric_type_ != static_cast<int>(MetricType::L2)) {
metric_l2 = false; metric_l2 = false;
......
...@@ -26,10 +26,10 @@ ...@@ -26,10 +26,10 @@
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
// TODO(wxy): rewrite // TODO(wxyu): rewrite
class XSearchTask : public Task { class XSearchTask : public Task {
public: public:
explicit XSearchTask(TableFileSchemaPtr file); explicit XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label);
void void
Load(LoadType type, uint8_t device_id) override; Load(LoadType type, uint8_t device_id) override;
......
...@@ -48,7 +48,7 @@ using TaskPtr = std::shared_ptr<Task>; ...@@ -48,7 +48,7 @@ using TaskPtr = std::shared_ptr<Task>;
// TODO: re-design // TODO: re-design
class Task { class Task {
public: public:
explicit Task(TaskType type) : type_(type) { explicit Task(TaskType type, TaskLabelPtr label) : type_(type), label_(std::move(label)) {
} }
/* /*
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
TestTask::TestTask(TableFileSchemaPtr& file) : XSearchTask(file) { TestTask::TestTask(TableFileSchemaPtr& file, TaskLabelPtr label)
: XSearchTask(file, std::move(label)) {
} }
void void
...@@ -42,7 +43,9 @@ TestTask::Execute() { ...@@ -42,7 +43,9 @@ TestTask::Execute() {
void void
TestTask::Wait() { TestTask::Wait() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return done_; }); cv_.wait(lock, [&] {
return done_;
});
} }
} // namespace scheduler } // namespace scheduler
......
...@@ -24,7 +24,7 @@ namespace scheduler { ...@@ -24,7 +24,7 @@ namespace scheduler {
class TestTask : public XSearchTask { class TestTask : public XSearchTask {
public: public:
explicit TestTask(TableFileSchemaPtr& file); explicit TestTask(TableFileSchemaPtr& file, TaskLabelPtr label);
public: public:
void void
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <functional>
namespace milvus { namespace milvus {
namespace engine { namespace engine {
......
...@@ -24,7 +24,7 @@ namespace milvus { ...@@ -24,7 +24,7 @@ namespace milvus {
namespace scheduler { namespace scheduler {
TEST(TaskTest, INVALID_INDEX) { TEST(TaskTest, INVALID_INDEX) {
auto search_task = std::make_shared<XSearchTask>(nullptr); auto search_task = std::make_shared<XSearchTask>(nullptr, nullptr);
search_task->Load(LoadType::TEST, 10); search_task->Load(LoadType::TEST, 10);
} }
......
...@@ -54,7 +54,8 @@ TEST(NormalTest, INST_TEST) { ...@@ -54,7 +54,8 @@ TEST(NormalTest, INST_TEST) {
ASSERT_FALSE(disks.empty()); ASSERT_FALSE(disks.empty());
if (auto observe = disks[0].lock()) { if (auto observe = disks[0].lock()) {
for (uint64_t i = 0; i < NUM_TASK; ++i) { for (uint64_t i = 0; i < NUM_TASK; ++i) {
auto task = std::make_shared<ms::TestTask>(dummy); auto label = std::make_shared<ms::DefaultLabel>();
auto task = std::make_shared<ms::TestTask>(dummy, label);
task->label() = std::make_shared<ms::DefaultLabel>(); task->label() = std::make_shared<ms::DefaultLabel>();
tasks.push_back(task); tasks.push_back(task);
observe->task_table().Put(task); observe->task_table().Put(task);
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "scheduler/resource/TestResource.h" #include "scheduler/resource/TestResource.h"
#include "scheduler/task/Task.h" #include "scheduler/task/Task.h"
#include "scheduler/task/TestTask.h" #include "scheduler/task/TestTask.h"
#include "scheduler/tasklabel/DefaultLabel.h"
#include "scheduler/ResourceFactory.h" #include "scheduler/ResourceFactory.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -185,7 +186,8 @@ TEST_F(ResourceAdvanceTest, DISK_RESOURCE_TEST) { ...@@ -185,7 +186,8 @@ TEST_F(ResourceAdvanceTest, DISK_RESOURCE_TEST) {
std::vector<std::shared_ptr<TestTask>> tasks; std::vector<std::shared_ptr<TestTask>> tasks;
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy, label);
tasks.push_back(task); tasks.push_back(task);
disk_resource_->task_table().Put(task); disk_resource_->task_table().Put(task);
} }
...@@ -210,7 +212,8 @@ TEST_F(ResourceAdvanceTest, CPU_RESOURCE_TEST) { ...@@ -210,7 +212,8 @@ TEST_F(ResourceAdvanceTest, CPU_RESOURCE_TEST) {
std::vector<std::shared_ptr<TestTask>> tasks; std::vector<std::shared_ptr<TestTask>> tasks;
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy, label);
tasks.push_back(task); tasks.push_back(task);
cpu_resource_->task_table().Put(task); cpu_resource_->task_table().Put(task);
} }
...@@ -235,7 +238,8 @@ TEST_F(ResourceAdvanceTest, GPU_RESOURCE_TEST) { ...@@ -235,7 +238,8 @@ TEST_F(ResourceAdvanceTest, GPU_RESOURCE_TEST) {
std::vector<std::shared_ptr<TestTask>> tasks; std::vector<std::shared_ptr<TestTask>> tasks;
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy, label);
tasks.push_back(task); tasks.push_back(task);
gpu_resource_->task_table().Put(task); gpu_resource_->task_table().Put(task);
} }
...@@ -260,7 +264,8 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) { ...@@ -260,7 +264,8 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) {
std::vector<std::shared_ptr<TestTask>> tasks; std::vector<std::shared_ptr<TestTask>> tasks;
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy, label);
tasks.push_back(task); tasks.push_back(task);
test_resource_->task_table().Put(task); test_resource_->task_table().Put(task);
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "scheduler/resource/DiskResource.h" #include "scheduler/resource/DiskResource.h"
#include "scheduler/resource/TestResource.h" #include "scheduler/resource/TestResource.h"
#include "scheduler/task/TestTask.h" #include "scheduler/task/TestTask.h"
#include "scheduler/tasklabel/DefaultLabel.h"
#include "scheduler/ResourceMgr.h" #include "scheduler/ResourceMgr.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -184,7 +185,8 @@ TEST_F(ResourceMgrAdvanceTest, REGISTER_SUBSCRIBER) { ...@@ -184,7 +185,8 @@ TEST_F(ResourceMgrAdvanceTest, REGISTER_SUBSCRIBER) {
}; };
mgr1_->RegisterSubscriber(callback); mgr1_->RegisterSubscriber(callback);
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
disk_res->task_table().Put(std::make_shared<TestTask>(dummy)); auto label = std::make_shared<DefaultLabel>();
disk_res->task_table().Put(std::make_shared<TestTask>(dummy, label));
sleep(1); sleep(1);
ASSERT_TRUE(flag); ASSERT_TRUE(flag);
} }
......
...@@ -155,7 +155,8 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) { ...@@ -155,7 +155,8 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) {
insert_dummy_index_into_gpu_cache(1); insert_dummy_index_into_gpu_cache(1);
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy, label);
task->label() = std::make_shared<DefaultLabel>(); task->label() = std::make_shared<DefaultLabel>();
tasks.push_back(task); tasks.push_back(task);
cpu_resource_.lock()->task_table().Put(task); cpu_resource_.lock()->task_table().Put(task);
...@@ -174,7 +175,8 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) { ...@@ -174,7 +175,8 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) {
tasks.clear(); tasks.clear();
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
auto task = std::make_shared<TestTask>(dummy1); auto label = std::make_shared<DefaultLabel>();
auto task = std::make_shared<TestTask>(dummy1, label);
task->label() = std::make_shared<DefaultLabel>(); task->label() = std::make_shared<DefaultLabel>();
tasks.push_back(task); tasks.push_back(task);
cpu_resource_.lock()->task_table().Put(task); cpu_resource_.lock()->task_table().Put(task);
...@@ -242,7 +244,8 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) { ...@@ -242,7 +244,8 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) {
dummy->location_ = "location"; dummy->location_ = "location";
for (uint64_t i = 0; i < NUM; ++i) { for (uint64_t i = 0; i < NUM; ++i) {
std::shared_ptr<TestTask> task = std::make_shared<TestTask>(dummy); auto label = std::make_shared<DefaultLabel>();
std::shared_ptr<TestTask> task = std::make_shared<TestTask>(dummy, label);
task->label() = std::make_shared<SpecResLabel>(disk_); task->label() = std::make_shared<SpecResLabel>(disk_);
tasks.push_back(task); tasks.push_back(task);
disk_.lock()->task_table().Put(task); disk_.lock()->task_table().Put(task);
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "scheduler/TaskTable.h" #include "scheduler/TaskTable.h"
#include "scheduler/task/TestTask.h" #include "scheduler/task/TestTask.h"
#include "scheduler/tasklabel/DefaultLabel.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace { namespace {
...@@ -172,8 +173,9 @@ class TaskTableBaseTest : public ::testing::Test { ...@@ -172,8 +173,9 @@ class TaskTableBaseTest : public ::testing::Test {
SetUp() override { SetUp() override {
ms::TableFileSchemaPtr dummy = nullptr; ms::TableFileSchemaPtr dummy = nullptr;
invalid_task_ = nullptr; invalid_task_ = nullptr;
task1_ = std::make_shared<ms::TestTask>(dummy); auto label = std::make_shared<ms::DefaultLabel>();
task2_ = std::make_shared<ms::TestTask>(dummy); task1_ = std::make_shared<ms::TestTask>(dummy, label);
task2_ = std::make_shared<ms::TestTask>(dummy, label);
} }
ms::TaskPtr invalid_task_; ms::TaskPtr invalid_task_;
...@@ -340,7 +342,8 @@ class TaskTableAdvanceTest : public ::testing::Test { ...@@ -340,7 +342,8 @@ class TaskTableAdvanceTest : public ::testing::Test {
SetUp() override { SetUp() override {
ms::TableFileSchemaPtr dummy = nullptr; ms::TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < 8; ++i) { for (uint64_t i = 0; i < 8; ++i) {
auto task = std::make_shared<ms::TestTask>(dummy); auto label = std::make_shared<ms::DefaultLabel>();
auto task = std::make_shared<ms::TestTask>(dummy, label);
table1_.Put(task); table1_.Put(task);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册