提交 d551ea6b 编写于 作者: A azural 提交者: Zhongjun Ni

framework: Fix future_error while getting result of the task.

上级 1ba2e85f
......@@ -26,7 +26,7 @@
#include "cybertron/common/global_data.h"
#include "cybertron/croutine/croutine.h"
#include "cybertron/croutine/routine_factory.h"
#include "cybertron/data/data_dispatcher.h"
#include "cybertron/data/data_notifier.h"
#include "cybertron/init.h"
#include "cybertron/scheduler/scheduler.h"
......@@ -36,8 +36,7 @@ namespace cybertron {
static const char* task_prefix = "/internal/task/";
template <typename Type, typename Ret>
class TaskData {
public:
struct TaskData {
std::shared_ptr<Type> raw_data;
std::promise<Ret> prom;
};
......@@ -45,52 +44,45 @@ class TaskData {
template <typename T, typename R, typename Derived>
class TaskBase {
public:
~TaskBase();
virtual ~TaskBase();
void Stop();
bool IsRunning() const;
std::future<R> Execute(const std::shared_ptr<T>& val);
std::shared_ptr<TaskData<T, R>> GetTaskData();
protected:
void RegisterCallback(
std::function<void(const std::shared_ptr<TaskData<T, R>>&)>&& func);
void RegisterCallback(std::function<void()>&& func);
uint32_t num_threads_;
std::string name_;
std::vector<uint64_t> name_ids_;
private:
TaskBase(const std::string& name, const uint8_t& num_threads = 1);
friend Derived;
std::string name_;
bool running_;
uint32_t num_threads_;
uint64_t task_id_;
mutable std::mutex mutex_;
std::list<std::shared_ptr<TaskData<T, R>>> data_list_;
};
template <typename T, typename R, typename Derived>
TaskBase<T, R, Derived>::TaskBase(const std::string& name,
const uint8_t& num_threads)
: num_threads_(num_threads) {
name_ = task_prefix + name;
name_ids_.reserve(num_threads);
}
template <typename T, typename R, typename Derived>
void TaskBase<T, R, Derived>::RegisterCallback(
std::function<void(const std::shared_ptr<TaskData<T, R>>&)>&& func) {
for (int i = 0; i < num_threads_; ++i) {
auto channel_name = name_ + std::to_string(i);
auto name_id = common::GlobalData::RegisterChannel(channel_name);
name_ids_.push_back(std::move(name_id));
auto dv = std::make_shared<data::DataVisitor<TaskData<T, R>>>(name_id, 1);
croutine::RoutineFactory factory =
croutine::CreateRoutineFactory<TaskData<T, R>>(func, dv);
scheduler::Scheduler::Instance()->CreateTask(factory, channel_name);
: num_threads_(num_threads), running_(true) {
if (num_threads_ > scheduler::Scheduler::ProcessorNum()) {
num_threads_ = scheduler::Scheduler::ProcessorNum();
}
name_ = task_prefix + name;
task_id_ = common::GlobalData::RegisterChannel(name_);
}
template <typename T, typename R, typename Derived>
void TaskBase<T, R, Derived>::RegisterCallback(std::function<void()>&& func) {
croutine::RoutineFactory factory = croutine::CreateRoutineFactory(func);
for (int i = 0; i < num_threads_; ++i) {
auto channel_name = name_ + std::to_string(i);
auto name_id = common::GlobalData::RegisterChannel(channel_name);
name_ids_.push_back(std::move(name_id));
croutine::RoutineFactory factory = croutine::CreateRoutineFactory(func);
scheduler::Scheduler::Instance()->CreateTask(factory, channel_name);
scheduler::Scheduler::Instance()->CreateTask(factory,
name_ + std::to_string(i));
}
}
......@@ -106,9 +98,16 @@ template <typename Function>
Task<T, R>::Task(const std::string& name, Function&& f,
const uint8_t& num_threads)
: TaskBase<T, R, Task<T, R>>(name, num_threads) {
auto func = [f = std::forward<Function&&>(f)](
const std::shared_ptr<TaskData<T, R>>& msg) {
msg->prom.set_value(f(msg->raw_data));
auto func = [ f = std::forward<Function&&>(f), this ]() {
while (this->IsRunning()) {
auto msg = this->GetTaskData();
if (msg == nullptr) {
auto routine = croutine::CRoutine::GetCurrentRoutine();
routine->Sleep(1000);
continue;
}
msg->prom.set_value(f(msg->raw_data));
}
};
this->RegisterCallback(std::move(func));
}
......@@ -128,35 +127,69 @@ template <typename Function>
Task<T, void>::Task(const std::string& name, Function&& f,
const uint8_t& num_threads)
: TaskBase<T, void, Task<T, void>>(name, num_threads) {
auto func = [f = std::forward<Function&&>(f)](
const std::shared_ptr<TaskData<T, void>>& msg) {
f(msg->raw_data);
msg->prom.set_value();
auto func = [ f = std::forward<Function&&>(f), this ]() {
while (this->IsRunning()) {
auto msg = this->GetTaskData();
if (msg == nullptr) {
auto routine = croutine::CRoutine::GetCurrentRoutine();
routine->Sleep(1000);
continue;
}
f(msg->raw_data);
msg->prom.set_value();
}
};
this->RegisterCallback(std::move(func));
}
template <typename T, typename R, typename Derived>
bool TaskBase<T, R, Derived>::IsRunning() const {
std::lock_guard<std::mutex> lg(mutex_);
return running_;
}
template <typename T, typename R, typename Derived>
void TaskBase<T, R, Derived>::Stop() {
std::lock_guard<std::mutex> lg(mutex_);
data_list_.clear();
running_ = false;
}
template <typename T, typename R, typename Derived>
TaskBase<T, R, Derived>::~TaskBase() {
Stop();
for (int i = 0; i < num_threads_; ++i) {
auto channel_name = task_prefix + name_ + std::to_string(i);
scheduler::Scheduler::Instance()->RemoveTask(channel_name);
scheduler::Scheduler::Instance()->RemoveTask(name_ + std::to_string(i));
}
}
template <typename T, typename R, typename Derived>
std::future<R> TaskBase<T, R, Derived>::Execute(const std::shared_ptr<T>& val) {
static auto it = name_ids_.begin();
if (it == name_ids_.end()) {
it = name_ids_.begin();
}
auto task = std::make_shared<TaskData<T, R>>();
task->raw_data = val;
data::DataDispatcher<TaskData<T, R>>::Instance()->Dispatch(*it, task);
++it;
{
std::lock_guard<std::mutex> lg(mutex_);
data_list_.emplace_back(task);
}
// data::DataNotifier::Instance()->Notify(task_id_);
return task->prom.get_future();
}
template <typename T, typename R, typename Derived>
std::shared_ptr<TaskData<T, R>> TaskBase<T, R, Derived>::GetTaskData() {
std::lock_guard<std::mutex> lg(mutex_);
if (!running_) {
return nullptr;
}
if (data_list_.empty()) {
return nullptr;
}
auto task = data_list_.front();
data_list_.pop_front();
return task;
}
template <>
template <typename Function>
Task<void, void>::Task(const std::string& name, Function&& f,
......
......@@ -6,24 +6,78 @@
#include "cybertron/message/raw_message.h"
#include "cybertron/proto/driver.pb.h"
#include "gtest/gtest.h"
using apollo::cybertron::proto::CarStatus;
using apollo::cybertron::Task;
namespace apollo {
namespace cybertron {
namespace scheduler {
void VoidTask() { AINFO << "VoidTask running"; }
int UserTask(const std::shared_ptr<CarStatus>& msg) {
struct Message {
uint64_t id;
};
void Task1() { ADEBUG << "Task1 running"; }
void Task2(const std::shared_ptr<Message>& input) { ADEBUG << "Task2 running"; }
uint64_t Task3(const std::shared_ptr<Message>& input) {
ADEBUG << "Task3 running";
return input->id;
}
uint64_t Task4(const std::shared_ptr<Message>& input) {
ADEBUG << "Task4 running";
usleep(10000);
return input->id;
}
/*
Message UserTask(const std::shared_ptr<CarStatus>& msg) {
AINFO << "receive msg";
return 0;
}
*/
TEST(TaskTest, create_task) {
auto task_1 = std::make_shared<Task<>>("task1", &Task1);
task_1.reset();
auto task_2 = std::make_shared<Task<void>>("task2", &Task1);
task_2.reset();
auto task_3 = std::make_shared<Task<void, void>>("task3", &Task1);
task_3.reset();
}
TEST(TaskTest, return_value) {
auto msg = std::make_shared<Message>();
msg->id = 1;
auto task_1 = std::make_shared<Task<>>("task1", &Task1);
usleep(100000);
task_1.reset();
auto task_2 = std::make_shared<Task<Message, void>>("task2", &Task2);
auto ret_2 = task_2->Execute(msg);
ret_2.get();
auto msg3 = std::make_shared<Message>();
msg3->id = 1;
auto task_3 = std::make_shared<Task<Message, int>>("task3", &Task3);
auto ret_3 = task_3->Execute(msg3);
EXPECT_EQ(ret_3.get(), 1);
ret_3 = task_3->Execute(msg3);
usleep(100000);
EXPECT_EQ(ret_3.get(), 1);
TEST(TaskTest, all) {
std::shared_ptr<apollo::cybertron::Task<CarStatus, int>> task_ = nullptr;
std::shared_ptr<apollo::cybertron::Task<>> void_task_ = nullptr;
task_.reset(new apollo::cybertron::Task<CarStatus, int>("task", &UserTask));
void_task_.reset(
new apollo::cybertron::Task<void, void>("void_task", &VoidTask));
auto task_4 = std::make_shared<Task<Message, uint64_t>>("task4", &Task4, 20);
std::vector<std::future<uint64_t>> results;
for (int i = 0; i < 1000; i++) {
results.emplace_back(task_4->Execute(msg));
}
for (auto& result : results) {
result.get();
}
}
} // namespace scheduler
......
......@@ -44,7 +44,7 @@ int TaskProcessor(const std::shared_ptr<Message>& msg) {
}
void VoidTaskProcessor(const std::shared_ptr<Message>& msg) {
AINFO << "Task Processor[" << msg->task_id
ADEBUG << "Task Processor[" << msg->task_id
<< "] is running: " << msg->msg_id;
}
......@@ -63,7 +63,7 @@ int main(int argc, char* argv[]) {
uint64_t i = 0;
while (!apollo::cybertron::IsShutdown()) {
std::vector<std::future<int>> futures;
for (int j = 0; j < num_threads; ++j) {
for (int j = 0; j < 1000; ++j) {
auto msg = std::make_shared<Message>();
msg->msg_id = i++;
msg->task_id = j;
......@@ -76,7 +76,7 @@ int main(int argc, char* argv[]) {
break;
}
for (auto& future: futures) {
AINFO << "Finish task:" << future.get();
//AINFO << "Finish task:" << future.get();
}
AINFO << "All task are finished.";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册