提交 58a3f57c 编写于 作者: P peng.xu

Merge branch 'branch-0.4.0' into 'branch-0.4.0'

update scheduler_test

See merge request megasearch/milvus!507

Former-commit-id: 8c40fa431acc1f6ba2d8abebd0e449b4fb8d8f45
...@@ -93,7 +93,10 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -93,7 +93,10 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-488 - Improve code format in scheduler - MS-488 - Improve code format in scheduler
- MS-495 - cmake: integrated knowhere - MS-495 - cmake: integrated knowhere
- MS-496 - Change the top_k limitation from 1024 to 2048 - MS-496 - Change the top_k limitation from 1024 to 2048
- MS-502 - Update tasktable_test in scheduler
- MS-504 - Update node_test in scheduler
- MS-505 - Install core unit test and add to coverage - MS-505 - Install core unit test and add to coverage
- MS-508 - Update normal_test in scheduler
## New Feature ## New Feature
- MS-343 - Implement ResourceMgr - MS-343 - Implement ResourceMgr
......
...@@ -108,6 +108,7 @@ void ...@@ -108,6 +108,7 @@ void
Scheduler::OnFinishTask(const EventPtr &event) { Scheduler::OnFinishTask(const EventPtr &event) {
} }
// TODO: refactor the function
void void
Scheduler::OnLoadCompleted(const EventPtr &event) { Scheduler::OnLoadCompleted(const EventPtr &event) {
auto load_completed_event = std::static_pointer_cast<LoadCompletedEvent>(event); auto load_completed_event = std::static_pointer_cast<LoadCompletedEvent>(event);
...@@ -120,18 +121,23 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { ...@@ -120,18 +121,23 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
if (not resource->HasExecutor() && load_completed_event->task_table_item_->Move()) { if (not resource->HasExecutor() && load_completed_event->task_table_item_->Move()) {
auto task = load_completed_event->task_table_item_->task; auto task = load_completed_event->task_table_item_->task;
auto search_task = std::static_pointer_cast<XSearchTask>(task); auto search_task = std::static_pointer_cast<XSearchTask>(task);
auto location = search_task->index_engine_->GetLocation();
bool moved = false; bool moved = false;
for (auto i = 0; i < res_mgr_.lock()->GetNumGpuResource(); ++i) { // to support test task, REFACTOR
auto index = zilliz::milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location); if (auto index_engine = search_task->index_engine_) {
if (index != nullptr) { auto location = index_engine->GetLocation();
moved = true;
auto dest_resource = res_mgr_.lock()->GetResource(ResourceType::GPU, i); for (auto i = 0; i < res_mgr_.lock()->GetNumGpuResource(); ++i) {
Action::PushTaskToResource(load_completed_event->task_table_item_->task, dest_resource); auto index = zilliz::milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location);
break; if (index != nullptr) {
moved = true;
auto dest_resource = res_mgr_.lock()->GetResource(ResourceType::GPU, i);
Action::PushTaskToResource(load_completed_event->task_table_item_->task, dest_resource);
break;
}
} }
} }
if (not moved) { if (not moved) {
Action::PushTaskToNeighbourRandomly(task, resource); Action::PushTaskToNeighbourRandomly(task, resource);
} }
...@@ -147,7 +153,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { ...@@ -147,7 +153,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
// step 1: calculate shortest path per resource, from disk to compute resource // step 1: calculate shortest path per resource, from disk to compute resource
auto compute_resources = res_mgr_.lock()->GetComputeResource(); auto compute_resources = res_mgr_.lock()->GetComputeResource();
std::vector<std::vector<std::string>> paths; std::vector<std::vector<std::string>> paths;
std::vector<uint64_t > transport_costs; std::vector<uint64_t> transport_costs;
for (auto &res : compute_resources) { for (auto &res : compute_resources) {
std::vector<std::string> path; std::vector<std::string> path;
uint64_t transport_cost = ShortestPath(self, res, res_mgr_.lock(), path); uint64_t transport_cost = ShortestPath(self, res, res_mgr_.lock(), path);
...@@ -176,7 +182,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { ...@@ -176,7 +182,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
task->path() = task_path; task->path() = task_path;
} }
if(self->name() == task->path().Last()) { if (self->name() == task->path().Last()) {
self->WakeupLoader(); self->WakeupLoader();
} else { } else {
auto next_res_name = task->path().Next(); auto next_res_name = task->path().Next();
......
...@@ -21,6 +21,7 @@ namespace milvus { ...@@ -21,6 +21,7 @@ namespace milvus {
namespace engine { namespace engine {
// TODO: refactor, not friendly to unittest, logical in framework code
class Scheduler { class Scheduler {
public: public:
explicit explicit
......
...@@ -136,7 +136,7 @@ std::vector<uint64_t> ...@@ -136,7 +136,7 @@ std::vector<uint64_t>
TaskTable::PickToLoad(uint64_t limit) { TaskTable::PickToLoad(uint64_t limit) {
std::vector<uint64_t> indexes; std::vector<uint64_t> indexes;
bool cross = false; bool cross = false;
for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { for (uint64_t i = last_finish_ + 1, count = 0; i < table_.size() && count < limit; ++i) {
if (not cross && table_[i]->IsFinish()) { if (not cross && table_[i]->IsFinish()) {
last_finish_ = i; last_finish_ = i;
} else if (table_[i]->state == TaskTableItemState::START) { } else if (table_[i]->state == TaskTableItemState::START) {
...@@ -152,7 +152,7 @@ std::vector<uint64_t> ...@@ -152,7 +152,7 @@ std::vector<uint64_t>
TaskTable::PickToExecute(uint64_t limit) { TaskTable::PickToExecute(uint64_t limit) {
std::vector<uint64_t> indexes; std::vector<uint64_t> indexes;
bool cross = false; bool cross = false;
for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { for (uint64_t i = last_finish_ + 1, count = 0; i < table_.size() && count < limit; ++i) {
if (not cross && table_[i]->IsFinish()) { if (not cross && table_[i]->IsFinish()) {
last_finish_ = i; last_finish_ = i;
} else if (table_[i]->state == TaskTableItemState::LOADED) { } else if (table_[i]->state == TaskTableItemState::LOADED) {
...@@ -200,15 +200,15 @@ TaskTable::Get(uint64_t index) { ...@@ -200,15 +200,15 @@ TaskTable::Get(uint64_t index) {
return table_[index]; return table_[index];
} }
void //void
TaskTable::Clear() { //TaskTable::Clear() {
// find first task is NOT (done or moved), erase from begin to it; //// find first task is NOT (done or moved), erase from begin to it;
// auto iterator = table_.begin(); //// auto iterator = table_.begin();
// while (iterator->state == TaskTableItemState::EXECUTED or //// while (iterator->state == TaskTableItemState::EXECUTED or
// iterator->state == TaskTableItemState::MOVED) //// iterator->state == TaskTableItemState::MOVED)
// iterator++; //// iterator++;
// table_.erase(table_.begin(), iterator); //// table_.erase(table_.begin(), iterator);
} //}
std::string std::string
......
...@@ -40,10 +40,10 @@ struct TaskTimestamp { ...@@ -40,10 +40,10 @@ struct TaskTimestamp {
}; };
struct TaskTableItem { struct TaskTableItem {
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex() {} TaskTableItem() : id(0), task(nullptr), state(TaskTableItemState::INVALID), mutex() {}
TaskTableItem(const TaskTableItem &src) TaskTableItem(const TaskTableItem &src) = delete;
: id(src.id), state(src.state), mutex() {} TaskTableItem(TaskTableItem &&) = delete;
uint64_t id; // auto increment from 0; uint64_t id; // auto increment from 0;
TaskPtr task; // the task; TaskPtr task; // the task;
...@@ -114,8 +114,8 @@ public: ...@@ -114,8 +114,8 @@ public:
* Remove sequence task which is DONE or MOVED from front; * Remove sequence task which is DONE or MOVED from front;
* Called by ? * Called by ?
*/ */
void // void
Clear(); // Clear();
/* /*
* Return true if task table empty, otherwise false; * Return true if task table empty, otherwise false;
...@@ -229,7 +229,9 @@ private: ...@@ -229,7 +229,9 @@ private:
std::function<void(void)> subscriber_ = nullptr; std::function<void(void)> subscriber_ = nullptr;
// cache last finish avoid Pick task from begin always // cache last finish avoid Pick task from begin always
uint64_t last_finish_ = 0; // pick from (last_finish_ + 1)
// init with -1, pick from (last_finish_ + 1) = 0
uint64_t last_finish_ = -1;
}; };
......
...@@ -17,27 +17,6 @@ Node::Node() { ...@@ -17,27 +17,6 @@ Node::Node() {
id_ = counter++; id_ = counter++;
} }
void Node::DelNeighbour(const NeighbourNodePtr &neighbour_ptr) {
std::lock_guard<std::mutex> lk(mutex_);
if (auto s = neighbour_ptr.lock()) {
auto search = neighbours_.find(s->id_);
if (search != neighbours_.end()) {
neighbours_.erase(search);
}
}
}
bool Node::IsNeighbour(const NeighbourNodePtr &neighbour_ptr) {
std::lock_guard<std::mutex> lk(mutex_);
if (auto s = neighbour_ptr.lock()) {
auto search = neighbours_.find(s->id_);
if (search != neighbours_.end()) {
return true;
}
}
return false;
}
std::vector<Neighbour> Node::GetNeighbours() { std::vector<Neighbour> Node::GetNeighbours() {
std::lock_guard<std::mutex> lk(mutex_); std::lock_guard<std::mutex> lk(mutex_);
std::vector<Neighbour> ret; std::vector<Neighbour> ret;
...@@ -48,8 +27,13 @@ std::vector<Neighbour> Node::GetNeighbours() { ...@@ -48,8 +27,13 @@ std::vector<Neighbour> Node::GetNeighbours() {
} }
std::string Node::Dump() { std::string Node::Dump() {
// TODO(linxj): what's that? std::stringstream ss;
return std::__cxx11::string(); ss << "<Node, id=" << std::to_string(id_) << ">::neighbours:" << std::endl;
for (auto &neighbour : neighbours_) {
ss << "\t<Neighbour, id=" << std::to_string(neighbour.first);
ss << ", connection: " << neighbour.second.connection.Dump() << ">" << std::endl;
}
return ss.str();
} }
void Node::AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection) { void Node::AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection) {
......
...@@ -37,12 +37,6 @@ public: ...@@ -37,12 +37,6 @@ public:
void void
AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection); AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection);
void
DelNeighbour(const NeighbourNodePtr &neighbour_ptr);
bool
IsNeighbour(const NeighbourNodePtr& neighbour_ptr);
std::vector<Neighbour> std::vector<Neighbour>
GetNeighbours(); GetNeighbours();
......
...@@ -83,11 +83,14 @@ CollectFileMetrics(int file_type, size_t file_size) { ...@@ -83,11 +83,14 @@ CollectFileMetrics(int file_type, size_t file_size) {
XSearchTask::XSearchTask(TableFileSchemaPtr file) XSearchTask::XSearchTask(TableFileSchemaPtr file)
: Task(TaskType::SearchTask), file_(file) { : Task(TaskType::SearchTask), file_(file) {
index_engine_ = EngineFactory::Build(file_->dimension_, if (file_) {
file_->location_, index_engine_ = EngineFactory::Build(file_->dimension_,
(EngineType) file_->engine_type_, file_->location_,
(MetricType) file_->metric_type_, (EngineType) file_->engine_type_,
file_->nlist_); (MetricType) file_->metric_type_,
file_->nlist_);
}
} }
void void
......
...@@ -12,6 +12,7 @@ namespace zilliz { ...@@ -12,6 +12,7 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
// TODO: rewrite
class XSearchTask : public Task { class XSearchTask : public Task {
public: public:
explicit explicit
......
...@@ -34,6 +34,7 @@ class Task; ...@@ -34,6 +34,7 @@ class Task;
using TaskPtr = std::shared_ptr<Task>; using TaskPtr = std::shared_ptr<Task>;
// TODO: re-design
class Task { class Task {
public: public:
explicit explicit
......
...@@ -13,7 +13,7 @@ namespace milvus { ...@@ -13,7 +13,7 @@ namespace milvus {
namespace engine { namespace engine {
TestTask::TestTask(TableFileSchemaPtr& file) : XSearchTask(file) {} TestTask::TestTask(TableFileSchemaPtr &file) : XSearchTask(file) {}
void void
TestTask::Load(LoadType type, uint8_t device_id) { TestTask::Load(LoadType type, uint8_t device_id) {
...@@ -22,9 +22,12 @@ TestTask::Load(LoadType type, uint8_t device_id) { ...@@ -22,9 +22,12 @@ TestTask::Load(LoadType type, uint8_t device_id) {
void void
TestTask::Execute() { TestTask::Execute() {
std::lock_guard<std::mutex> lock(mutex_); {
exec_count_++; std::lock_guard<std::mutex> lock(mutex_);
done_ = true; exec_count_++;
done_ = true;
}
cv_.notify_one();
} }
void void
......
...@@ -11,58 +11,71 @@ protected: ...@@ -11,58 +11,71 @@ protected:
node1_ = std::make_shared<Node>(); node1_ = std::make_shared<Node>();
node2_ = std::make_shared<Node>(); node2_ = std::make_shared<Node>();
node3_ = std::make_shared<Node>(); node3_ = std::make_shared<Node>();
node4_ = std::make_shared<Node>(); isolated_node1_ = std::make_shared<Node>();
isolated_node2_ = std::make_shared<Node>();
auto pcie = Connection("PCIe", 11.0); auto pcie = Connection("PCIe", 11.0);
node1_->AddNeighbour(node2_, pcie); node1_->AddNeighbour(node2_, pcie);
node1_->AddNeighbour(node3_, pcie);
node2_->AddNeighbour(node1_, pcie); node2_->AddNeighbour(node1_, pcie);
} }
NodePtr node1_; NodePtr node1_;
NodePtr node2_; NodePtr node2_;
NodePtr node3_; NodePtr node3_;
NodePtr node4_; NodePtr isolated_node1_;
NodePtr isolated_node2_;
}; };
TEST_F(NodeTest, add_neighbour) { TEST_F(NodeTest, add_neighbour) {
ASSERT_EQ(node3_->GetNeighbours().size(), 0); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0);
ASSERT_EQ(node4_->GetNeighbours().size(), 0); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0);
auto pcie = Connection("PCIe", 11.0); auto pcie = Connection("PCIe", 11.0);
node3_->AddNeighbour(node4_, pcie); isolated_node1_->AddNeighbour(isolated_node2_, pcie);
node4_->AddNeighbour(node3_, pcie); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1);
ASSERT_EQ(node3_->GetNeighbours().size(), 1); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0);
ASSERT_EQ(node4_->GetNeighbours().size(), 1);
} }
TEST_F(NodeTest, del_neighbour) { TEST_F(NodeTest, repeat_add_neighbour) {
ASSERT_EQ(node1_->GetNeighbours().size(), 1); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0);
ASSERT_EQ(node2_->GetNeighbours().size(), 1); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0);
ASSERT_EQ(node3_->GetNeighbours().size(), 0); auto pcie = Connection("PCIe", 11.0);
node1_->DelNeighbour(node2_); isolated_node1_->AddNeighbour(isolated_node2_, pcie);
node2_->DelNeighbour(node2_); isolated_node1_->AddNeighbour(isolated_node2_, pcie);
node3_->DelNeighbour(node2_); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1);
ASSERT_EQ(node1_->GetNeighbours().size(), 0); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0);
ASSERT_EQ(node2_->GetNeighbours().size(), 1);
ASSERT_EQ(node3_->GetNeighbours().size(), 0);
} }
TEST_F(NodeTest, is_neighbour) { TEST_F(NodeTest, get_neighbours) {
ASSERT_TRUE(node1_->IsNeighbour(node2_)); {
ASSERT_TRUE(node2_->IsNeighbour(node1_)); bool n2 = false, n3 = false;
auto node1_neighbours = node1_->GetNeighbours();
ASSERT_EQ(node1_neighbours.size(), 2);
for (auto &n : node1_neighbours) {
if (n.neighbour_node.lock() == node2_) n2 = true;
if (n.neighbour_node.lock() == node3_) n3 = true;
}
ASSERT_TRUE(n2);
ASSERT_TRUE(n3);
}
ASSERT_FALSE(node1_->IsNeighbour(node3_)); {
ASSERT_FALSE(node2_->IsNeighbour(node3_)); auto node2_neighbours = node2_->GetNeighbours();
ASSERT_FALSE(node3_->IsNeighbour(node1_)); ASSERT_EQ(node2_neighbours.size(), 1);
ASSERT_FALSE(node3_->IsNeighbour(node2_)); ASSERT_EQ(node2_neighbours[0].neighbour_node.lock(), node1_);
}
{
auto node3_neighbours = node3_->GetNeighbours();
ASSERT_EQ(node3_neighbours.size(), 0);
}
} }
TEST_F(NodeTest, get_neighbours) { TEST_F(NodeTest, dump) {
auto node1_neighbours = node1_->GetNeighbours(); std::cout << node1_->Dump();
ASSERT_EQ(node1_neighbours.size(), 1); ASSERT_FALSE(node1_->Dump().empty());
ASSERT_EQ(node1_neighbours[0].neighbour_node.lock(), node2_);
auto node2_neighbours = node2_->GetNeighbours(); std::cout << node2_->Dump();
ASSERT_EQ(node2_neighbours.size(), 1); ASSERT_FALSE(node2_->Dump().empty());
ASSERT_EQ(node2_neighbours[0].neighbour_node.lock(), node1_);
} }
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "scheduler/ResourceMgr.h" #include "scheduler/ResourceMgr.h"
#include "scheduler/Scheduler.h" #include "scheduler/Scheduler.h"
#include "scheduler/task/TestTask.h" #include "scheduler/task/TestTask.h"
#include "scheduler/tasklabel/DefaultLabel.h"
#include "scheduler/SchedInst.h" #include "scheduler/SchedInst.h"
#include "utils/Log.h" #include "utils/Log.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -9,48 +10,44 @@ ...@@ -9,48 +10,44 @@
using namespace zilliz::milvus::engine; using namespace zilliz::milvus::engine;
TEST(normal_test, test1) {
TEST(normal_test, inst_test) {
// ResourceMgr only compose resources, provide unified event // ResourceMgr only compose resources, provide unified event
// auto res_mgr = std::make_shared<ResourceMgr>();
auto res_mgr = ResMgrInst::GetInstance(); auto res_mgr = ResMgrInst::GetInstance();
auto disk = res_mgr->Add(ResourceFactory::Create("disk", "ssd", true, false));
auto cpu = res_mgr->Add(ResourceFactory::Create("cpu", "CPU", 0));
auto gpu1 = res_mgr->Add(ResourceFactory::Create("gpu", "gpu0", false, false));
auto gpu2 = res_mgr->Add(ResourceFactory::Create("gpu", "gpu2", false, false));
auto IO = Connection("IO", 500.0); res_mgr->Add(ResourceFactory::Create("disk", "DISK", 0, true, false));
auto PCIE = Connection("IO", 11000.0); res_mgr->Add(ResourceFactory::Create("cpu", "CPU", 0, true, true));
res_mgr->Connect(disk, cpu, IO);
res_mgr->Connect(cpu, gpu1, PCIE);
res_mgr->Connect(cpu, gpu2, PCIE);
res_mgr->Start(); auto IO = Connection("IO", 500.0);
res_mgr->Connect("disk", "cpu", IO);
// auto scheduler = new Scheduler(res_mgr);
auto scheduler = SchedInst::GetInstance(); auto scheduler = SchedInst::GetInstance();
res_mgr->Start();
scheduler->Start(); scheduler->Start();
const uint64_t NUM_TASK = 1000; const uint64_t NUM_TASK = 1000;
std::vector<std::shared_ptr<TestTask>> tasks; std::vector<std::shared_ptr<TestTask>> tasks;
TableFileSchemaPtr dummy = nullptr; TableFileSchemaPtr dummy = nullptr;
for (uint64_t i = 0; i < NUM_TASK; ++i) { auto disks = res_mgr->GetDiskResources();
if (auto observe = disk.lock()) { ASSERT_FALSE(disks.empty());
if (auto observe = disks[0].lock()) {
for (uint64_t i = 0; i < NUM_TASK; ++i) {
auto task = std::make_shared<TestTask>(dummy); auto task = std::make_shared<TestTask>(dummy);
task->label() = std::make_shared<DefaultLabel>();
tasks.push_back(task); tasks.push_back(task);
observe->task_table().Put(task); observe->task_table().Put(task);
} }
} }
sleep(1); for (auto &task : tasks) {
task->Wait();
ASSERT_EQ(task->load_count_, 1);
ASSERT_EQ(task->exec_count_, 1);
}
scheduler->Stop(); scheduler->Stop();
res_mgr->Stop(); res_mgr->Stop();
auto pcpu = cpu.lock();
for (uint64_t i = 0; i < NUM_TASK; ++i) {
auto task = std::static_pointer_cast<TestTask>(pcpu->task_table()[i]->task);
ASSERT_EQ(task->load_count_, 1);
ASSERT_EQ(task->exec_count_, 1);
}
} }
...@@ -5,30 +5,37 @@ ...@@ -5,30 +5,37 @@
using namespace zilliz::milvus::engine; using namespace zilliz::milvus::engine;
/************ TaskTableBaseTest ************/
class TaskTableItemTest : public ::testing::Test { class TaskTableItemTest : public ::testing::Test {
protected: protected:
void void
SetUp() override { SetUp() override {
item1_.id = 0; std::vector<TaskTableItemState> states{
item1_.state = TaskTableItemState::MOVED; TaskTableItemState::INVALID,
item1_.priority = 10; TaskTableItemState::START,
TaskTableItemState::LOADING,
TaskTableItemState::LOADED,
TaskTableItemState::EXECUTING,
TaskTableItemState::EXECUTED,
TaskTableItemState::MOVING,
TaskTableItemState::MOVED};
for (auto &state : states) {
auto item = std::make_shared<TaskTableItem>();
item->state = state;
items_.emplace_back(item);
}
} }
TaskTableItem default_; TaskTableItem default_;
TaskTableItem item1_; std::vector<TaskTableItemPtr> items_;
}; };
TEST_F(TaskTableItemTest, construct) { TEST_F(TaskTableItemTest, construct) {
ASSERT_EQ(default_.id, 0); ASSERT_EQ(default_.id, 0);
ASSERT_EQ(default_.task, nullptr);
ASSERT_EQ(default_.state, TaskTableItemState::INVALID); ASSERT_EQ(default_.state, TaskTableItemState::INVALID);
ASSERT_EQ(default_.priority, 0);
}
TEST_F(TaskTableItemTest, copy) {
TaskTableItem another(item1_);
ASSERT_EQ(another.id, item1_.id);
ASSERT_EQ(another.state, item1_.state);
ASSERT_EQ(another.priority, item1_.priority);
} }
TEST_F(TaskTableItemTest, destruct) { TEST_F(TaskTableItemTest, destruct) {
...@@ -36,6 +43,107 @@ TEST_F(TaskTableItemTest, destruct) { ...@@ -36,6 +43,107 @@ TEST_F(TaskTableItemTest, destruct) {
delete p_item; delete p_item;
} }
TEST_F(TaskTableItemTest, is_finish) {
for (auto &item : items_) {
if (item->state == TaskTableItemState::EXECUTED
|| item->state == TaskTableItemState::MOVED) {
ASSERT_TRUE(item->IsFinish());
} else {
ASSERT_FALSE(item->IsFinish());
}
}
}
TEST_F(TaskTableItemTest, dump) {
for (auto &item : items_) {
ASSERT_FALSE(item->Dump().empty());
}
}
TEST_F(TaskTableItemTest, load) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Load();
if (before_state == TaskTableItemState::START) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::LOADING);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
TEST_F(TaskTableItemTest, loaded) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Loaded();
if (before_state == TaskTableItemState::LOADING) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::LOADED);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
TEST_F(TaskTableItemTest, execute) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Execute();
if (before_state == TaskTableItemState::LOADED) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::EXECUTING);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
TEST_F(TaskTableItemTest, executed) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Executed();
if (before_state == TaskTableItemState::EXECUTING) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::EXECUTED);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
TEST_F(TaskTableItemTest, move) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Move();
if (before_state == TaskTableItemState::LOADED) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::MOVING);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
TEST_F(TaskTableItemTest, moved) {
for (auto &item : items_) {
auto before_state = item->state;
auto ret = item->Moved();
if (before_state == TaskTableItemState::MOVING) {
ASSERT_TRUE(ret);
ASSERT_EQ(item->state, TaskTableItemState::MOVED);
} else {
ASSERT_FALSE(ret);
ASSERT_EQ(item->state, before_state);
}
}
}
/************ TaskTableBaseTest ************/ /************ TaskTableBaseTest ************/
...@@ -55,6 +163,16 @@ protected: ...@@ -55,6 +163,16 @@ protected:
TaskTable empty_table_; TaskTable empty_table_;
}; };
TEST_F(TaskTableBaseTest, subscriber) {
bool flag = false;
auto callback = [&]() {
flag = true;
};
empty_table_.RegisterSubscriber(callback);
empty_table_.Put(task1_);
ASSERT_TRUE(flag);
}
TEST_F(TaskTableBaseTest, put_task) { TEST_F(TaskTableBaseTest, put_task) {
empty_table_.Put(task1_); empty_table_.Put(task1_);
...@@ -78,6 +196,125 @@ TEST_F(TaskTableBaseTest, put_empty_batch) { ...@@ -78,6 +196,125 @@ TEST_F(TaskTableBaseTest, put_empty_batch) {
empty_table_.Put(tasks); empty_table_.Put(tasks);
} }
TEST_F(TaskTableBaseTest, empty) {
ASSERT_TRUE(empty_table_.Empty());
empty_table_.Put(task1_);
ASSERT_FALSE(empty_table_.Empty());
}
TEST_F(TaskTableBaseTest, size) {
ASSERT_EQ(empty_table_.Size(), 0);
empty_table_.Put(task1_);
ASSERT_EQ(empty_table_.Size(), 1);
}
TEST_F(TaskTableBaseTest, operator_) {
empty_table_.Put(task1_);
ASSERT_EQ(empty_table_.Get(0), empty_table_[0]);
}
TEST_F(TaskTableBaseTest, pick_to_load) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
auto indexes = empty_table_.PickToLoad(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
}
TEST_F(TaskTableBaseTest, pick_to_load_limit) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
auto indexes = empty_table_.PickToLoad(3);
ASSERT_EQ(indexes.size(), 3);
ASSERT_EQ(indexes[0], 2);
ASSERT_EQ(indexes[1], 3);
ASSERT_EQ(indexes[2], 4);
}
TEST_F(TaskTableBaseTest, pick_to_load_cache) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
// first pick, non-cache
auto indexes = empty_table_.PickToLoad(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
// second pick, iterate from 2
// invalid state change
empty_table_[1]->state = TaskTableItemState::START;
indexes = empty_table_.PickToLoad(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
}
TEST_F(TaskTableBaseTest, pick_to_execute) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
empty_table_[2]->state = TaskTableItemState::LOADED;
auto indexes = empty_table_.PickToExecute(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
}
TEST_F(TaskTableBaseTest, pick_to_execute_limit) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
empty_table_[2]->state = TaskTableItemState::LOADED;
empty_table_[3]->state = TaskTableItemState::LOADED;
auto indexes = empty_table_.PickToExecute(3);
ASSERT_EQ(indexes.size(), 2);
ASSERT_EQ(indexes[0], 2);
ASSERT_EQ(indexes[1], 3);
}
TEST_F(TaskTableBaseTest, pick_to_execute_cache) {
const size_t NUM_TASKS = 10;
for (size_t i = 0; i < NUM_TASKS; ++i) {
empty_table_.Put(task1_);
}
empty_table_[0]->state = TaskTableItemState::MOVED;
empty_table_[1]->state = TaskTableItemState::EXECUTED;
empty_table_[2]->state = TaskTableItemState::LOADED;
// first pick, non-cache
auto indexes = empty_table_.PickToExecute(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
// second pick, iterate from 2
// invalid state change
empty_table_[1]->state = TaskTableItemState::START;
indexes = empty_table_.PickToExecute(1);
ASSERT_EQ(indexes.size(), 1);
ASSERT_EQ(indexes[0], 2);
}
/************ TaskTableAdvanceTest ************/ /************ TaskTableAdvanceTest ************/
class TaskTableAdvanceTest : public ::testing::Test { class TaskTableAdvanceTest : public ::testing::Test {
...@@ -104,25 +341,116 @@ protected: ...@@ -104,25 +341,116 @@ protected:
}; };
TEST_F(TaskTableAdvanceTest, load) { TEST_F(TaskTableAdvanceTest, load) {
table1_.Load(1); std::vector<TaskTableItemState> before_state;
table1_.Loaded(2); for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Load(i);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
if (before_state[i] == TaskTableItemState::START) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADING);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
}
TEST_F(TaskTableAdvanceTest, loaded) {
std::vector<TaskTableItemState> before_state;
for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Loaded(i);
}
ASSERT_EQ(table1_.Get(1)->state, TaskTableItemState::LOADING); for (size_t i = 0; i < table1_.Size(); ++i) {
ASSERT_EQ(table1_.Get(2)->state, TaskTableItemState::LOADED); if (before_state[i] == TaskTableItemState::LOADING) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADED);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
} }
TEST_F(TaskTableAdvanceTest, execute) { TEST_F(TaskTableAdvanceTest, execute) {
table1_.Execute(3); std::vector<TaskTableItemState> before_state;
table1_.Executed(4); for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Execute(i);
}
ASSERT_EQ(table1_.Get(3)->state, TaskTableItemState::EXECUTING); for (size_t i = 0; i < table1_.Size(); ++i) {
ASSERT_EQ(table1_.Get(4)->state, TaskTableItemState::EXECUTED); if (before_state[i] == TaskTableItemState::LOADED) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTING);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
}
TEST_F(TaskTableAdvanceTest, executed) {
std::vector<TaskTableItemState> before_state;
for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Executed(i);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
if (before_state[i] == TaskTableItemState::EXECUTING) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTED);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
} }
TEST_F(TaskTableAdvanceTest, move) { TEST_F(TaskTableAdvanceTest, move) {
table1_.Move(3); std::vector<TaskTableItemState> before_state;
table1_.Moved(6); for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Move(i);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
if (before_state[i] == TaskTableItemState::LOADED) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVING);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
}
ASSERT_EQ(table1_.Get(3)->state, TaskTableItemState::MOVING); TEST_F(TaskTableAdvanceTest, moved) {
ASSERT_EQ(table1_.Get(6)->state, TaskTableItemState::MOVED); std::vector<TaskTableItemState> before_state;
for (auto &task : table1_) {
before_state.push_back(task->state);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
table1_.Moved(i);
}
for (size_t i = 0; i < table1_.Size(); ++i) {
if (before_state[i] == TaskTableItemState::MOVING) {
ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVED);
} else {
ASSERT_EQ(table1_.Get(i)->state, before_state[i]);
}
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册