提交 dbae0993 编写于 作者: W wxyu

MS-488 Improve code format in scheduler


Former-commit-id: e9e4c2a0271b49cd16b25f0c4f712c9b54b10da6
上级 b2e029c4
...@@ -83,6 +83,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -83,6 +83,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-455 - Distribute tasks by minimal cost in scheduler - MS-455 - Distribute tasks by minimal cost in scheduler
- MS-460 - Put transport speed as weight when choosing neighbour to execute task - MS-460 - Put transport speed as weight when choosing neighbour to execute task
- MS-459 - Add cache for pick function in tasktable - MS-459 - Add cache for pick function in tasktable
- MS-488 - Improve code format in scheduler
## New Feature ## New Feature
- MS-343 - Implement ResourceMgr - MS-343 - Implement ResourceMgr
......
...@@ -20,12 +20,12 @@ ShortestPath(const ResourcePtr &src, ...@@ -20,12 +20,12 @@ ShortestPath(const ResourcePtr &src,
std::vector<std::vector<std::string>> paths; std::vector<std::vector<std::string>> paths;
uint64_t num_of_resources = res_mgr->GetAllResouces().size(); uint64_t num_of_resources = res_mgr->GetAllResources().size();
std::unordered_map<uint64_t, std::string> id_name_map; std::unordered_map<uint64_t, std::string> id_name_map;
std::unordered_map<std::string, uint64_t> name_id_map; std::unordered_map<std::string, uint64_t> name_id_map;
for (uint64_t i = 0; i < num_of_resources; ++i) { for (uint64_t i = 0; i < num_of_resources; ++i) {
id_name_map.insert(std::make_pair(i, res_mgr->GetAllResouces().at(i)->Name())); id_name_map.insert(std::make_pair(i, res_mgr->GetAllResources().at(i)->name()));
name_id_map.insert(std::make_pair(res_mgr->GetAllResouces().at(i)->Name(), i)); name_id_map.insert(std::make_pair(res_mgr->GetAllResources().at(i)->name(), i));
} }
std::vector<std::vector<uint64_t> > dis_matrix; std::vector<std::vector<uint64_t> > dis_matrix;
...@@ -40,23 +40,23 @@ ShortestPath(const ResourcePtr &src, ...@@ -40,23 +40,23 @@ ShortestPath(const ResourcePtr &src,
std::vector<bool> vis(num_of_resources, false); std::vector<bool> vis(num_of_resources, false);
std::vector<uint64_t> dis(num_of_resources, MAXINT); std::vector<uint64_t> dis(num_of_resources, MAXINT);
for (auto &res : res_mgr->GetAllResouces()) { for (auto &res : res_mgr->GetAllResources()) {
auto cur_node = std::static_pointer_cast<Node>(res); auto cur_node = std::static_pointer_cast<Node>(res);
auto cur_neighbours = cur_node->GetNeighbours(); auto cur_neighbours = cur_node->GetNeighbours();
for (auto &neighbour : cur_neighbours) { for (auto &neighbour : cur_neighbours) {
auto neighbour_res = std::static_pointer_cast<Resource>(neighbour.neighbour_node.lock()); auto neighbour_res = std::static_pointer_cast<Resource>(neighbour.neighbour_node.lock());
dis_matrix[name_id_map.at(res->Name())][name_id_map.at(neighbour_res->Name())] = dis_matrix[name_id_map.at(res->name())][name_id_map.at(neighbour_res->name())] =
neighbour.connection.transport_cost(); neighbour.connection.transport_cost();
} }
} }
for (uint64_t i = 0; i < num_of_resources; ++i) { for (uint64_t i = 0; i < num_of_resources; ++i) {
dis[i] = dis_matrix[name_id_map.at(src->Name())][i]; dis[i] = dis_matrix[name_id_map.at(src->name())][i];
} }
vis[name_id_map.at(src->Name())] = true; vis[name_id_map.at(src->name())] = true;
std::vector<int64_t> parent(num_of_resources, -1); std::vector<int64_t> parent(num_of_resources, -1);
for (uint64_t i = 0; i < num_of_resources; ++i) { for (uint64_t i = 0; i < num_of_resources; ++i) {
...@@ -71,7 +71,7 @@ ShortestPath(const ResourcePtr &src, ...@@ -71,7 +71,7 @@ ShortestPath(const ResourcePtr &src,
vis[temp] = true; vis[temp] = true;
if (i == 0) { if (i == 0) {
parent[temp] = name_id_map.at(src->Name()); parent[temp] = name_id_map.at(src->name());
} }
for (uint64_t j = 0; j < num_of_resources; ++j) { for (uint64_t j = 0; j < num_of_resources; ++j) {
...@@ -82,15 +82,15 @@ ShortestPath(const ResourcePtr &src, ...@@ -82,15 +82,15 @@ ShortestPath(const ResourcePtr &src,
} }
} }
int64_t parent_idx = parent[name_id_map.at(dest->Name())]; int64_t parent_idx = parent[name_id_map.at(dest->name())];
if (parent_idx != -1) { if (parent_idx != -1) {
path.push_back(dest->Name()); path.push_back(dest->name());
} }
while (parent_idx != -1) { while (parent_idx != -1) {
path.push_back(id_name_map.at(parent_idx)); path.push_back(id_name_map.at(parent_idx));
parent_idx = parent[parent_idx]; parent_idx = parent[parent_idx];
} }
return dis[name_id_map.at(dest->Name())]; return dis[name_id_map.at(dest->name())];
} }
} }
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {
// dummy cache_mgr
class CacheMgr {
};
using CacheMgrPtr = std::shared_ptr<CacheMgr>;
}
}
}
...@@ -12,67 +12,31 @@ namespace zilliz { ...@@ -12,67 +12,31 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
ResourceMgr::ResourceMgr()
: running_(false) {
} void
ResourceMgr::Start() {
uint64_t std::lock_guard<std::mutex> lck(resources_mutex_);
ResourceMgr::GetNumOfComputeResource() {
uint64_t count = 0;
for (auto &res : resources_) {
if (res->HasExecutor()) {
++count;
}
}
return count;
}
std::vector<ResourcePtr>
ResourceMgr::GetComputeResource() {
std::vector<ResourcePtr > result;
for (auto &resource : resources_) { for (auto &resource : resources_) {
if (resource->HasExecutor()) { resource->Start();
result.emplace_back(resource);
}
}
return result;
}
uint64_t
ResourceMgr::GetNumGpuResource() const {
uint64_t num = 0;
for (auto &res : resources_) {
if (res->Type() == ResourceType::GPU) {
num++;
}
} }
return num; running_ = true;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
} }
ResourcePtr void
ResourceMgr::GetResource(ResourceType type, uint64_t device_id) { ResourceMgr::Stop() {
for (auto &resource : resources_) { {
if (resource->Type() == type && resource->DeviceId() == device_id) { std::lock_guard<std::mutex> lock(event_mutex_);
return resource; running_ = false;
} queue_.push(nullptr);
event_cv_.notify_one();
} }
return nullptr; worker_thread_.join();
}
ResourcePtr std::lock_guard<std::mutex> lck(resources_mutex_);
ResourceMgr::GetResourceByName(std::string name) {
for (auto &resource : resources_) { for (auto &resource : resources_) {
if (resource->Name() == name) { resource->Stop();
return resource;
}
} }
return nullptr;
}
std::vector<ResourcePtr>
ResourceMgr::GetAllResouces() {
return resources_;
} }
ResourceWPtr ResourceWPtr
...@@ -85,75 +49,85 @@ ResourceMgr::Add(ResourcePtr &&resource) { ...@@ -85,75 +49,85 @@ ResourceMgr::Add(ResourcePtr &&resource) {
return ret; return ret;
} }
if (resource->Type() == ResourceType::DISK) { resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1));
if (resource->type() == ResourceType::DISK) {
disk_resources_.emplace_back(ResourceWPtr(resource)); disk_resources_.emplace_back(ResourceWPtr(resource));
} }
resources_.emplace_back(resource); resources_.emplace_back(resource);
size_t index = resources_.size() - 1;
resource->RegisterSubscriber(std::bind(&ResourceMgr::PostEvent, this, std::placeholders::_1));
return ret; return ret;
} }
void void
ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connection &connection) { ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connection &connection) {
auto res1 = get_resource_by_name(name1); auto res1 = GetResource(name1);
auto res2 = get_resource_by_name(name2); auto res2 = GetResource(name2);
if (res1 && res2) { if (res1 && res2) {
res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection); res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection);
// TODO: enable when task balance supported
// res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection); // res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection);
} }
} }
void void
ResourceMgr::Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection) { ResourceMgr::Clear() {
if (auto observe_a = res1.lock()) { std::lock_guard<std::mutex> lck(resources_mutex_);
if (auto observe_b = res2.lock()) { disk_resources_.clear();
observe_a->AddNeighbour(std::static_pointer_cast<Node>(observe_b), connection); resources_.clear();
observe_b->AddNeighbour(std::static_pointer_cast<Node>(observe_a), connection);
}
}
} }
std::vector<ResourcePtr>
void ResourceMgr::GetComputeResource() {
ResourceMgr::Start() { std::vector<ResourcePtr> result;
std::lock_guard<std::mutex> lck(resources_mutex_);
for (auto &resource : resources_) { for (auto &resource : resources_) {
resource->Start(); if (resource->HasExecutor()) {
result.emplace_back(resource);
}
} }
running_ = true; return result;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
} }
void ResourcePtr
ResourceMgr::Stop() { ResourceMgr::GetResource(ResourceType type, uint64_t device_id) {
{ for (auto &resource : resources_) {
std::lock_guard<std::mutex> lock(event_mutex_); if (resource->type() == type && resource->device_id() == device_id) {
running_ = false; return resource;
queue_.push(nullptr); }
event_cv_.notify_one();
} }
worker_thread_.join(); return nullptr;
}
std::lock_guard<std::mutex> lck(resources_mutex_); ResourcePtr
ResourceMgr::GetResource(const std::string &name) {
for (auto &resource : resources_) { for (auto &resource : resources_) {
resource->Stop(); if (resource->name() == name) {
return resource;
}
} }
return nullptr;
} }
void uint64_t
ResourceMgr::Clear() { ResourceMgr::GetNumOfComputeResource() {
std::lock_guard<std::mutex> lck(resources_mutex_); uint64_t count = 0;
disk_resources_.clear(); for (auto &res : resources_) {
resources_.clear(); if (res->HasExecutor()) {
++count;
}
}
return count;
} }
void uint64_t
ResourceMgr::PostEvent(const EventPtr &event) { ResourceMgr::GetNumGpuResource() const {
std::lock_guard<std::mutex> lock(event_mutex_); uint64_t num = 0;
queue_.emplace(event); for (auto &res : resources_) {
event_cv_.notify_one(); if (res->type() == ResourceType::GPU) {
num++;
}
}
return num;
} }
std::string std::string
...@@ -180,14 +154,13 @@ ResourceMgr::DumpTaskTables() { ...@@ -180,14 +154,13 @@ ResourceMgr::DumpTaskTables() {
return ss.str(); return ss.str();
} }
ResourcePtr void
ResourceMgr::get_resource_by_name(const std::string &name) { ResourceMgr::post_event(const EventPtr &event) {
for (auto &res : resources_) { {
if (res->Name() == name) { std::lock_guard<std::mutex> lock(event_mutex_);
return res; queue_.emplace(event);
}
} }
return nullptr; event_cv_.notify_one();
} }
void void
...@@ -203,8 +176,6 @@ ResourceMgr::event_process() { ...@@ -203,8 +176,6 @@ ResourceMgr::event_process() {
break; break;
} }
// ENGINE_LOG_DEBUG << "ResourceMgr process " << *event;
if (subscriber_) { if (subscriber_) {
subscriber_(event); subscriber_(event);
} }
......
...@@ -22,78 +22,63 @@ namespace engine { ...@@ -22,78 +22,63 @@ namespace engine {
class ResourceMgr { class ResourceMgr {
public: public:
ResourceMgr(); ResourceMgr() = default;
public:
/******** Management Interface ********/ /******** Management Interface ********/
void
Start();
void
Stop();
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
void
Clear();
inline void inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) { RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber); subscriber_ = std::move(subscriber);
} }
std::vector<ResourceWPtr> & public:
/******** Management Interface ********/
inline std::vector<ResourceWPtr> &
GetDiskResources() { GetDiskResources() {
return disk_resources_; return disk_resources_;
} }
uint64_t // TODO: why return shared pointer
GetNumGpuResource() const; inline std::vector<ResourcePtr>
GetAllResources() {
return resources_;
}
std::vector<ResourcePtr>
GetComputeResource();
ResourcePtr ResourcePtr
GetResource(ResourceType type, uint64_t device_id); GetResource(ResourceType type, uint64_t device_id);
ResourcePtr ResourcePtr
GetResourceByName(std::string name); GetResource(const std::string &name);
std::vector<ResourcePtr>
GetAllResouces();
/*
* Return account of resource which enable executor;
*/
uint64_t uint64_t
GetNumOfComputeResource(); GetNumOfComputeResource();
std::vector<ResourcePtr> uint64_t
GetComputeResource(); GetNumGpuResource() const;
/*
* Add resource into Resource Management;
* Generate functions on events;
* Functions only modify bool variable, like event trigger;
*/
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
/*
* Create connection between A and B;
*/
void
Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection);
/*
* Synchronous start all resource;
* Last, start event process thread;
*/
void
Start();
void
Stop();
void
Clear();
void
PostEvent(const EventPtr &event);
public:
// TODO: add stats interface(low) // TODO: add stats interface(low)
public: public:
/******** Utlitity Functions ********/ /******** Utility Functions ********/
std::string std::string
Dump(); Dump();
...@@ -101,26 +86,26 @@ public: ...@@ -101,26 +86,26 @@ public:
DumpTaskTables(); DumpTaskTables();
private: private:
ResourcePtr void
get_resource_by_name(const std::string &name); post_event(const EventPtr &event);
void void
event_process(); event_process();
private: private:
std::queue<EventPtr> queue_; bool running_ = false;
std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_;
std::vector<ResourceWPtr> disk_resources_; std::vector<ResourceWPtr> disk_resources_;
std::vector<ResourcePtr> resources_; std::vector<ResourcePtr> resources_;
mutable std::mutex resources_mutex_; mutable std::mutex resources_mutex_;
std::thread worker_thread_;
std::queue<EventPtr> queue_;
std::function<void(EventPtr)> subscriber_ = nullptr;
std::mutex event_mutex_; std::mutex event_mutex_;
std::condition_variable event_cv_; std::condition_variable event_cv_;
std::thread worker_thread_;
}; };
using ResourceMgrPtr = std::shared_ptr<ResourceMgr>; using ResourceMgrPtr = std::shared_ptr<ResourceMgr>;
......
...@@ -38,7 +38,7 @@ StartSchedulerService() { ...@@ -38,7 +38,7 @@ StartSchedulerService() {
enable_loader, enable_loader,
enable_executor)); enable_executor));
if (res.lock()->Type() == ResourceType::GPU) { if (res.lock()->type() == ResourceType::GPU) {
auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300); auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300);
auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300); auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300);
auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2); auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2);
......
...@@ -143,7 +143,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { ...@@ -143,7 +143,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
auto task = load_completed_event->task_table_item_->task; auto task = load_completed_event->task_table_item_->task;
// if this resource is disk, assign it to smallest cost resource // if this resource is disk, assign it to smallest cost resource
if (self->Type() == ResourceType::DISK) { if (self->type() == ResourceType::DISK) {
// step 1: calculate shortest path per resource, from disk to compute resource // step 1: calculate shortest path per resource, from disk to compute resource
auto compute_resources = res_mgr_.lock()->GetComputeResource(); auto compute_resources = res_mgr_.lock()->GetComputeResource();
std::vector<std::vector<std::string>> paths; std::vector<std::vector<std::string>> paths;
...@@ -176,11 +176,11 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { ...@@ -176,11 +176,11 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
task->path() = task_path; task->path() = task_path;
} }
if(self->Name() == task->path().Last()) { if(self->name() == task->path().Last()) {
self->WakeupLoader(); self->WakeupLoader();
} else { } else {
auto next_res_name = task->path().Next(); auto next_res_name = task->path().Next();
auto next_res = res_mgr_.lock()->GetResourceByName(next_res_name); auto next_res = res_mgr_.lock()->GetResource(next_res_name);
load_completed_event->task_table_item_->Move(); load_completed_event->task_table_item_->Move();
next_res->task_table().Put(task); next_res->task_table().Put(task);
} }
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include "TaskTable.h" #include "TaskTable.h"
#include "event/TaskTableUpdatedEvent.h" #include "event/TaskTableUpdatedEvent.h"
#include "Utils.h"
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <ctime> #include <ctime>
...@@ -15,14 +17,6 @@ namespace zilliz { ...@@ -15,14 +17,6 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
uint64_t
get_now_timestamp() {
std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
return millis;
}
std::string std::string
ToString(TaskTableItemState state) { ToString(TaskTableItemState state) {
switch (state) { switch (state) {
...@@ -64,7 +58,7 @@ TaskTableItem::Load() { ...@@ -64,7 +58,7 @@ TaskTableItem::Load() {
if (state == TaskTableItemState::START) { if (state == TaskTableItemState::START) {
state = TaskTableItemState::LOADING; state = TaskTableItemState::LOADING;
lock.unlock(); lock.unlock();
timestamp.load = get_now_timestamp(); timestamp.load = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -75,7 +69,7 @@ TaskTableItem::Loaded() { ...@@ -75,7 +69,7 @@ TaskTableItem::Loaded() {
if (state == TaskTableItemState::LOADING) { if (state == TaskTableItemState::LOADING) {
state = TaskTableItemState::LOADED; state = TaskTableItemState::LOADED;
lock.unlock(); lock.unlock();
timestamp.loaded = get_now_timestamp(); timestamp.loaded = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -86,7 +80,7 @@ TaskTableItem::Execute() { ...@@ -86,7 +80,7 @@ TaskTableItem::Execute() {
if (state == TaskTableItemState::LOADED) { if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::EXECUTING; state = TaskTableItemState::EXECUTING;
lock.unlock(); lock.unlock();
timestamp.execute = get_now_timestamp(); timestamp.execute = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -97,8 +91,8 @@ TaskTableItem::Executed() { ...@@ -97,8 +91,8 @@ TaskTableItem::Executed() {
if (state == TaskTableItemState::EXECUTING) { if (state == TaskTableItemState::EXECUTING) {
state = TaskTableItemState::EXECUTED; state = TaskTableItemState::EXECUTED;
lock.unlock(); lock.unlock();
timestamp.executed = get_now_timestamp(); timestamp.executed = get_current_timestamp();
timestamp.finish = get_now_timestamp(); timestamp.finish = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -109,7 +103,7 @@ TaskTableItem::Move() { ...@@ -109,7 +103,7 @@ TaskTableItem::Move() {
if (state == TaskTableItemState::LOADED) { if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::MOVING; state = TaskTableItemState::MOVING;
lock.unlock(); lock.unlock();
timestamp.move = get_now_timestamp(); timestamp.move = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -120,8 +114,8 @@ TaskTableItem::Moved() { ...@@ -120,8 +114,8 @@ TaskTableItem::Moved() {
if (state == TaskTableItemState::MOVING) { if (state == TaskTableItemState::MOVING) {
state = TaskTableItemState::MOVED; state = TaskTableItemState::MOVED;
lock.unlock(); lock.unlock();
timestamp.moved = get_now_timestamp(); timestamp.moved = get_current_timestamp();
timestamp.finish = get_now_timestamp(); timestamp.finish = get_current_timestamp();
return true; return true;
} }
return false; return false;
...@@ -177,7 +171,7 @@ TaskTable::Put(TaskPtr task) { ...@@ -177,7 +171,7 @@ TaskTable::Put(TaskPtr task) {
item->id = id_++; item->id = id_++;
item->task = std::move(task); item->task = std::move(task);
item->state = TaskTableItemState::START; item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp(); item->timestamp.start = get_current_timestamp();
table_.push_back(item); table_.push_back(item);
if (subscriber_) { if (subscriber_) {
subscriber_(); subscriber_();
...@@ -192,7 +186,7 @@ TaskTable::Put(std::vector<TaskPtr> &tasks) { ...@@ -192,7 +186,7 @@ TaskTable::Put(std::vector<TaskPtr> &tasks) {
item->id = id_++; item->id = id_++;
item->task = std::move(task); item->task = std::move(task);
item->state = TaskTableItemState::START; item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp(); item->timestamp.start = get_current_timestamp();
table_.push_back(item); table_.push_back(item);
} }
if (subscriber_) { if (subscriber_) {
......
...@@ -40,20 +40,17 @@ struct TaskTimestamp { ...@@ -40,20 +40,17 @@ struct TaskTimestamp {
}; };
struct TaskTableItem { struct TaskTableItem {
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex(), priority(0) {} TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex() {}
TaskTableItem(const TaskTableItem &src) TaskTableItem(const TaskTableItem &src)
: id(src.id), state(src.state), mutex(), priority(src.priority) {} : id(src.id), state(src.state), mutex() {}
uint64_t id; // auto increment from 0; uint64_t id; // auto increment from 0;
// TODO: add tag into task
TaskPtr task; // the task; TaskPtr task; // the task;
TaskTableItemState state; // the state; TaskTableItemState state; // the state;
std::mutex mutex; std::mutex mutex;
TaskTimestamp timestamp; TaskTimestamp timestamp;
uint8_t priority; // just a number, meaningless;
bool bool
IsFinish(); IsFinish();
...@@ -113,7 +110,7 @@ public: ...@@ -113,7 +110,7 @@ public:
Get(uint64_t index); Get(uint64_t index);
/* /*
* TODO * TODO(wxyu): BIG GC
* Remove sequence task which is DONE or MOVED from front; * Remove sequence task which is DONE or MOVED from front;
* Called by ? * Called by ?
*/ */
...@@ -135,6 +132,7 @@ public: ...@@ -135,6 +132,7 @@ public:
Size() { Size() {
return table_.size(); return table_.size();
} }
public: public:
TaskTableItemPtr & TaskTableItemPtr &
operator[](uint64_t index) { operator[](uint64_t index) {
...@@ -225,7 +223,6 @@ public: ...@@ -225,7 +223,6 @@ public:
Dump(); Dump();
private: private:
// TODO: map better ?
std::uint64_t id_ = 0; std::uint64_t id_ = 0;
mutable std::mutex id_mutex_; mutable std::mutex id_mutex_;
std::deque<TaskTableItemPtr> table_; std::deque<TaskTableItemPtr> table_;
......
...@@ -4,16 +4,17 @@ ...@@ -4,16 +4,17 @@
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include <chrono>
#include "Utils.h" #include "Utils.h"
#include <chrono>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
uint64_t uint64_t
get_current_timestamp() get_current_timestamp() {
{
std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now(); std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
auto duration = now.time_since_epoch(); auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count(); auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include <cstdint> #include <cstdint>
......
...@@ -15,6 +15,7 @@ namespace engine { ...@@ -15,6 +15,7 @@ namespace engine {
class Connection { class Connection {
public: public:
// TODO: update construct function, speed: double->uint64_t
Connection(std::string name, double speed) Connection(std::string name, double speed)
: name_(std::move(name)), speed_(speed) {} : name_(std::move(name)), speed_(speed) {}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {
class RegisterHandler {
public:
virtual void Exec() = 0;
};
using RegisterHandlerPtr = std::shared_ptr<RegisterHandler>;
}
}
}
\ No newline at end of file
...@@ -12,7 +12,8 @@ namespace zilliz { ...@@ -12,7 +12,8 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
std::ostream &operator<<(std::ostream &out, const Resource &resource) { std::ostream &
operator<<(std::ostream &out, const Resource &resource) {
out << resource.Dump(); out << resource.Dump();
return out; return out;
} }
...@@ -25,11 +26,9 @@ Resource::Resource(std::string name, ...@@ -25,11 +26,9 @@ Resource::Resource(std::string name,
: name_(std::move(name)), : name_(std::move(name)),
type_(type), type_(type),
device_id_(device_id), device_id_(device_id),
running_(false),
enable_loader_(enable_loader), enable_loader_(enable_loader),
enable_executor_(enable_executor), enable_executor_(enable_executor) {
load_flag_(false), // register subscriber in tasktable
exec_flag_(false) {
task_table_.RegisterSubscriber([&] { task_table_.RegisterSubscriber([&] {
if (subscriber_) { if (subscriber_) {
auto event = std::make_shared<TaskTableUpdatedEvent>(shared_from_this()); auto event = std::make_shared<TaskTableUpdatedEvent>(shared_from_this());
...@@ -38,7 +37,8 @@ Resource::Resource(std::string name, ...@@ -38,7 +37,8 @@ Resource::Resource(std::string name,
}); });
} }
void Resource::Start() { void
Resource::Start() {
running_ = true; running_ = true;
if (enable_loader_) { if (enable_loader_) {
loader_thread_ = std::thread(&Resource::loader_function, this); loader_thread_ = std::thread(&Resource::loader_function, this);
...@@ -48,7 +48,8 @@ void Resource::Start() { ...@@ -48,7 +48,8 @@ void Resource::Start() {
} }
} }
void Resource::Stop() { void
Resource::Stop() {
running_ = false; running_ = false;
if (enable_loader_) { if (enable_loader_) {
WakeupLoader(); WakeupLoader();
...@@ -60,11 +61,8 @@ void Resource::Stop() { ...@@ -60,11 +61,8 @@ void Resource::Stop() {
} }
} }
TaskTable &Resource::task_table() { void
return task_table_; Resource::WakeupLoader() {
}
void Resource::WakeupLoader() {
{ {
std::lock_guard<std::mutex> lock(load_mutex_); std::lock_guard<std::mutex> lock(load_mutex_);
load_flag_ = true; load_flag_ = true;
...@@ -72,7 +70,8 @@ void Resource::WakeupLoader() { ...@@ -72,7 +70,8 @@ void Resource::WakeupLoader() {
load_cv_.notify_one(); load_cv_.notify_one();
} }
void Resource::WakeupExecutor() { void
Resource::WakeupExecutor() {
{ {
std::lock_guard<std::mutex> lock(exec_mutex_); std::lock_guard<std::mutex> lock(exec_mutex_);
exec_flag_ = true; exec_flag_ = true;
...@@ -80,6 +79,15 @@ void Resource::WakeupExecutor() { ...@@ -80,6 +79,15 @@ void Resource::WakeupExecutor() {
exec_cv_.notify_one(); exec_cv_.notify_one();
} }
uint64_t
Resource::NumOfTaskToExec() {
uint64_t count = 0;
for (auto &task : task_table_) {
if (task->state == TaskTableItemState::LOADED) ++count;
}
return count;
}
TaskTableItemPtr Resource::pick_task_load() { TaskTableItemPtr Resource::pick_task_load() {
auto indexes = task_table_.PickToLoad(10); auto indexes = task_table_.PickToLoad(10);
for (auto index : indexes) { for (auto index : indexes) {
...@@ -156,11 +164,6 @@ void Resource::executor_function() { ...@@ -156,11 +164,6 @@ void Resource::executor_function() {
} }
} }
RegisterHandlerPtr Resource::GetRegisterFunc(const RegisterType &type) {
// construct object each time.
return register_table_[type]();
}
} }
} }
} }
\ No newline at end of file
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "../task/Task.h" #include "../task/Task.h"
#include "Connection.h" #include "Connection.h"
#include "Node.h" #include "Node.h"
#include "RegisterHandler.h"
namespace zilliz { namespace zilliz {
...@@ -35,13 +34,6 @@ enum class ResourceType { ...@@ -35,13 +34,6 @@ enum class ResourceType {
GPU = 2 GPU = 2
}; };
enum class RegisterType {
START_UP,
ON_FINISH_TASK,
ON_COPY_COMPLETED,
ON_TASK_TABLE_UPDATED,
};
class Resource : public Node, public std::enable_shared_from_this<Resource> { class Resource : public Node, public std::enable_shared_from_this<Resource> {
public: public:
/* /*
...@@ -68,56 +60,51 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> { ...@@ -68,56 +60,51 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
void void
WakeupExecutor(); WakeupExecutor();
public:
template<typename T>
void Register_T(const RegisterType &type) {
register_table_.emplace(type, [] { return std::make_shared<T>(); });
}
RegisterHandlerPtr
GetRegisterFunc(const RegisterType &type);
inline void inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) { RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber); subscriber_ = std::move(subscriber);
} }
inline virtual std::string
Dump() const {
return "<Resource>";
}
public:
inline std::string inline std::string
Name() const { name() const {
return name_; return name_;
} }
inline ResourceType inline ResourceType
Type() const { type() const {
return type_; return type_;
} }
inline uint64_t inline uint64_t
DeviceId() { device_id() const {
return device_id_; return device_id_;
} }
// TODO: better name? TaskTable &
task_table() {
return task_table_;
}
public:
inline bool inline bool
HasLoader() { HasLoader() const {
return enable_loader_; return enable_loader_;
} }
// TODO: better name?
inline bool inline bool
HasExecutor() { HasExecutor() const {
return enable_executor_; return enable_executor_;
} }
// TODO: const // TODO: const
uint64_t uint64_t
NumOfTaskToExec() { NumOfTaskToExec();
uint64_t count = 0;
for (auto &task : task_table_) {
if (task->state == TaskTableItemState::LOADED) ++count;
}
return count;
}
// TODO: need double ? // TODO: need double ?
inline uint64_t inline uint64_t
...@@ -130,14 +117,6 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> { ...@@ -130,14 +117,6 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
return total_task_; return total_task_;
} }
TaskTable &
task_table();
inline virtual std::string
Dump() const {
return "<Resource>";
}
friend std::ostream &operator<<(std::ostream &out, const Resource &resource); friend std::ostream &operator<<(std::ostream &out, const Resource &resource);
protected: protected:
...@@ -198,6 +177,7 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> { ...@@ -198,6 +177,7 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
protected: protected:
uint64_t device_id_; uint64_t device_id_;
std::string name_; std::string name_;
private: private:
ResourceType type_; ResourceType type_;
...@@ -206,17 +186,16 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> { ...@@ -206,17 +186,16 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
uint64_t total_cost_ = 0; uint64_t total_cost_ = 0;
uint64_t total_task_ = 0; uint64_t total_task_ = 0;
std::map<RegisterType, std::function<RegisterHandlerPtr()>> register_table_;
std::function<void(EventPtr)> subscriber_ = nullptr; std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_; bool running_ = false;
bool enable_loader_ = true; bool enable_loader_ = true;
bool enable_executor_ = true; bool enable_executor_ = true;
std::thread loader_thread_; std::thread loader_thread_;
std::thread executor_thread_; std::thread executor_thread_;
bool load_flag_; bool load_flag_ = false;
bool exec_flag_; bool exec_flag_ = false;
std::mutex load_mutex_; std::mutex load_mutex_;
std::mutex exec_mutex_; std::mutex exec_mutex_;
std::condition_variable load_cv_; std::condition_variable load_cv_;
......
...@@ -24,12 +24,6 @@ XDeleteTask::Execute() { ...@@ -24,12 +24,6 @@ XDeleteTask::Execute() {
delete_context_ptr_->ResourceDone(); delete_context_ptr_->ResourceDone();
} }
TaskPtr
XDeleteTask::Clone() {
auto task = std::make_shared<XDeleteTask>(delete_context_ptr_);
return task;
}
} }
} }
} }
...@@ -24,9 +24,6 @@ public: ...@@ -24,9 +24,6 @@ public:
void void
Execute() override; Execute() override;
TaskPtr
Clone() override;
public: public:
DeleteContextPtr delete_context_ptr_; DeleteContextPtr delete_context_ptr_;
}; };
......
...@@ -193,16 +193,6 @@ XSearchTask::Execute() { ...@@ -193,16 +193,6 @@ XSearchTask::Execute() {
index_engine_ = nullptr; index_engine_ = nullptr;
} }
TaskPtr
XSearchTask::Clone() {
auto ret = std::make_shared<XSearchTask>(file_);
ret->index_id_ = index_id_;
ret->index_engine_ = index_engine_->Clone();
ret->search_contexts_ = search_contexts_;
ret->metric_l2 = metric_l2;
return ret;
}
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids, Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence, const std::vector<float> &output_distence,
uint64_t nq, uint64_t nq,
......
...@@ -23,9 +23,6 @@ public: ...@@ -23,9 +23,6 @@ public:
void void
Execute() override; Execute() override;
TaskPtr
Clone() override;
public: public:
static Status ClusterResult(const std::vector<long> &output_ids, static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence, const std::vector<float> &output_distence,
......
...@@ -68,14 +68,9 @@ public: ...@@ -68,14 +68,9 @@ public:
virtual void virtual void
Execute() = 0; Execute() = 0;
// TODO: dont use this method to support task move
virtual TaskPtr
Clone() = 0;
public: public:
Path task_path_; Path task_path_;
std::vector<SearchContextPtr> search_contexts_; std::vector<SearchContextPtr> search_contexts_;
ScheduleTaskPtr task_;
TaskType type_; TaskType type_;
TaskLabelPtr label_ = nullptr; TaskLabelPtr label_ = nullptr;
}; };
......
...@@ -21,7 +21,6 @@ TaskConvert(const ScheduleTaskPtr &schedule_task) { ...@@ -21,7 +21,6 @@ TaskConvert(const ScheduleTaskPtr &schedule_task) {
auto task = std::make_shared<XSearchTask>(load_task->file_); auto task = std::make_shared<XSearchTask>(load_task->file_);
task->label() = std::make_shared<DefaultLabel>(); task->label() = std::make_shared<DefaultLabel>();
task->search_contexts_ = load_task->search_contexts_; task->search_contexts_ = load_task->search_contexts_;
task->task_ = schedule_task;
return task; return task;
} }
case ScheduleTaskType::kDelete: { case ScheduleTaskType::kDelete: {
......
...@@ -27,15 +27,6 @@ TestTask::Execute() { ...@@ -27,15 +27,6 @@ TestTask::Execute() {
done_ = true; done_ = true;
} }
TaskPtr
TestTask::Clone() {
TableFileSchemaPtr dummy = nullptr;
auto ret = std::make_shared<TestTask>(dummy);
ret->load_count_ = load_count_;
ret->exec_count_ = exec_count_;
return ret;
}
void void
TestTask::Wait() { TestTask::Wait() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
......
...@@ -23,9 +23,6 @@ public: ...@@ -23,9 +23,6 @@ public:
void void
Execute() override; Execute() override;
TaskPtr
Clone() override;
void void
Wait(); Wait();
......
...@@ -25,6 +25,7 @@ public: ...@@ -25,6 +25,7 @@ public:
} }
protected: protected:
explicit
TaskLabel(TaskLabelType type) : type_(type) {} TaskLabel(TaskLabelType type) : type_(type) {}
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册