提交 ad0783cb 编写于 作者: W willzhang4a58

Move msg_handle into the base Actor class

上级 8fcade0c
......@@ -60,6 +60,15 @@ KernelCtx Actor::GenDefaultKernelCtx() const {
return ctx;
}
// Message handler shared by all actors for the final "draining" phase.
// Each incoming message is passed to TryUpdtStateAsProducedRegst, which is
// required (via CHECK_EQ) to return 0. Once total_reading_cnt_ reaches zero
// the handler is cleared and 1 is returned; otherwise 0 is returned.
int Actor::HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg) {
  CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
  if (total_reading_cnt_ != 0) {
    // Registers are still being read: keep this handler installed.
    return 0;
  }
  msg_handle_ = nullptr;
  return 1;
}
void Actor::AsyncWardKernel(
const KernelCtx& kernel_ctx,
std::function<std::shared_ptr<RegstWarpper>(int64_t)> Regst4RegstDescId) {
......
......@@ -24,7 +24,9 @@ class Actor {
virtual void Init(const TaskProto&, const ThreadCtx&) = 0;
// 1: success, and actor finish
// 0: success, and actor not finish
virtual int ProcessMsg(const ActorMsg&) = 0;
// Forwards msg to the handler currently stored in msg_handle_ and returns
// that handler's result unchanged.
int ProcessMsg(const ActorMsg& msg) {
  const auto current_handle = msg_handle_;
  return (this->*current_handle)(msg);
}
int64_t actor_id() const { return actor_id_; }
......@@ -40,6 +42,15 @@ class Actor {
std::unique_ptr<DeviceCtx>& mut_device_ctx() { return device_ctx_; }
KernelCtx GenDefaultKernelCtx() const;
// Msg Handle
// MsgHandle is a pointer to an Actor member function that takes the incoming
// message and returns an int status.
using MsgHandle = int (Actor::*)(const ActorMsg&);
// Stores val as the handler that will receive subsequent messages.
void set_msg_handle(MsgHandle val) { msg_handle_ = val; }
// Convenience wrapper around set_msg_handle: static_casts val to MsgHandle
// first, so a pointer to a member function of a derived actor class (or
// nullptr) can be passed directly. Wrapped in do/while(0) so it behaves as a
// single statement.
#define OF_SET_MSG_HANDLE(val) \
do { \
set_msg_handle(static_cast<MsgHandle>(val)); \
} while(0)
// Handler that waits for total_reading_cnt_ to reach zero (defined in the
// corresponding source file).
int HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg);
// Status of Produced Registers
int64_t expected_piece_id() const { return expected_piece_id_; }
void AsyncWardKernel(
......@@ -70,6 +81,8 @@ class Actor {
HashMap<std::string, int64_t> name2regst_desc_id_;
std::unique_ptr<DeviceCtx> device_ctx_;
MsgHandle msg_handle_;
// Status of Produced Registers
int64_t expected_piece_id_;
......
......@@ -11,11 +11,7 @@ void BoxingActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx)
num_of_eord_ = 0;
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
cur_msg_handle_ = &BoxingActor::HandleBoxing;
}
int BoxingActor::ProcessMsg(const ActorMsg& msg) {
return (this->*cur_msg_handle_)(msg);
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxing);
}
int BoxingActor::HandleBoxing(const ActorMsg& msg) {
......@@ -23,7 +19,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == num_of_subscribed_regsts_) {
cur_msg_handle_ = &BoxingActor::HandleBoxingWhenNoReadableRegstMsg;
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxingWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
......@@ -44,25 +40,16 @@ int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
if (num_of_read_empty_ == num_of_subscribed_regsts_) {
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
cur_msg_handle_ = nullptr;
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
cur_msg_handle_ = &BoxingActor::HandleWaitUntilReadingCntEqualZero;
OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
}
return 0;
}
// Draining-phase handler: passes the message to TryUpdtStateAsProducedRegst
// (which must return 0), then returns 1 and clears the handler once
// total_reading_cnt() is zero; returns 0 otherwise.
int BoxingActor::HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg) {
  CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
  if (total_reading_cnt() != 0) {
    return 0;
  }
  cur_msg_handle_ = nullptr;
  return 1;
}
void BoxingActor::TryWardKernelAndSendMsg() {
if (!num_of_read_empty_ && IsWriteReady()) {
int64_t piece_id = expected_piece_id();
......
......@@ -12,17 +12,14 @@ class BoxingActor final : public Actor {
~BoxingActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
int ProcessMsg(const ActorMsg&) override;
private:
int HandleInitDeviceCtx(const ActorMsg&);
int HandleBoxing(const ActorMsg&);
int HandleBoxingWhenNoReadableRegstMsg(const ActorMsg&);
int HandleWaitUntilReadingCntEqualZero(const ActorMsg&);
void TryWardKernelAndSendMsg();
int (BoxingActor::*cur_msg_handle_)(const ActorMsg&);
int num_of_subscribed_regsts_;
int num_of_read_empty_;
int num_of_eord_;
......
......@@ -20,7 +20,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
cur_msg_handle_ = &BpDataCompActor::HandleBpComp;
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpComp);
}
bool BpDataCompActor::IsReadReady() {
......@@ -36,16 +36,12 @@ bool BpDataCompActor::IsReadReady() {
return !num_of_read_empty_;
}
int BpDataCompActor::ProcessMsg(const ActorMsg& msg) {
return (this->*cur_msg_handle_)(msg);
}
int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == 6) {
cur_msg_handle_ = &BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg;
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
......@@ -78,25 +74,16 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
AsyncSendEORDMsgForAllProducedRegstDesc();
num_of_read_empty_ = 6;
if (total_reading_cnt() == 0) {
  // Nothing is being read any more: clear the handler and report "finished".
  cur_msg_handle_ = nullptr;
  OF_SET_MSG_HANDLE(nullptr);
  return 1;
} else {
  // Registers are still being read, and 0 is returned ("not finished"), so
  // further messages will arrive. A handler must therefore stay installed:
  // ProcessMsg dispatches through the stored pointer, and a null handler
  // would be invoked on the next message. Wait for the reading count to
  // drain, exactly as the other actors do in this situation.
  cur_msg_handle_ = &BpDataCompActor::HandleWaitUntilReadingCntEqualZero;
  OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilReadingCntEqualZero);
  return 0;
}
}
return 0;
}
// Draining-phase handler: passes the message to TryUpdtStateAsProducedRegst
// (which must return 0), then returns 1 and clears the handler once
// total_reading_cnt() is zero; returns 0 otherwise.
int BpDataCompActor::HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg) {
  CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
  if (total_reading_cnt() != 0) {
    return 0;
  }
  cur_msg_handle_ = nullptr;
  return 1;
}
void BpDataCompActor::TryWardKernelAndSendMsg() {
while (IsReadReady() && IsWriteReady()) {
int64_t cur_model = read_regst_.at(model_regst_desc_id_).front()->model_version_id();
......
......@@ -12,19 +12,16 @@ public:
~BpDataCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
int ProcessMsg(const ActorMsg&) override;
private:
int HandleInitDeviceCtx(const ActorMsg&);
int HandleBpComp(const ActorMsg&);
int HandleBpCompWhenNoReadableRegstMsg(const ActorMsg&);
int HandleWaitUntilReadingCntEqualZero(const ActorMsg&);
bool IsReadReady();
void TryWardKernelAndSendMsg();
CudaStreamHandle cuda_handle_;
int (BpDataCompActor::*cur_msg_handle_)(const ActorMsg&);
int num_of_read_empty_;
int num_of_eord_;
int64_t expected_model_version_id_;
......
......@@ -18,7 +18,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
cur_msg_handle_ = &FwDataCompActor::HandleFwComp;
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwComp);
}
bool FwDataCompActor::IsReadReady() {
......@@ -33,16 +33,12 @@ bool FwDataCompActor::IsReadReady() {
return false;
}
int FwDataCompActor::ProcessMsg(const ActorMsg& msg) {
return (this->*cur_msg_handle_)(msg);
}
int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == 3) {
cur_msg_handle_ = &FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg;
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
......@@ -78,25 +74,16 @@ int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
model_tmp_regst_ = nullptr;
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
cur_msg_handle_ = nullptr;
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
cur_msg_handle_ = &FwDataCompActor::HandleWaitUntilReadingCntEqualZero;
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
}
return 0;
}
// Draining-phase handler: passes the message to TryUpdtStateAsProducedRegst
// (which must return 0), then returns 1 and clears the handler once
// total_reading_cnt() is zero; returns 0 otherwise.
int FwDataCompActor::HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg) {
  CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
  if (total_reading_cnt() != 0) {
    return 0;
  }
  cur_msg_handle_ = nullptr;
  return 1;
}
void FwDataCompActor::TryWardKernelAndSendMsg() {
while (IsReadReady() && IsWriteReady()) {
CHECK_EQ(in_.front()->piece_id(), expected_piece_id());
......
......@@ -12,18 +12,15 @@ public:
~FwDataCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
int ProcessMsg(const ActorMsg&) override;
private:
int HandleFwComp(const ActorMsg&);
int HandleFwCompWhenNoReadableRegstMsg(const ActorMsg&);
int HandleWaitUntilReadingCntEqualZero(const ActorMsg&);
bool IsReadReady();
void TryWardKernelAndSendMsg();
CudaStreamHandle cuda_handle_;
int (FwDataCompActor::*cur_msg_handle_)(const ActorMsg&);
int num_of_eord_;
int64_t expected_model_version_id_;
int64_t model_regst_desc_id_;
......
......@@ -8,11 +8,7 @@ void MdSaveCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
model_regst_desc_id_ = RegstDescId4Name("model");
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
cur_msg_handle_ = &MdSaveCompActor::HandleSaveModel;
}
int MdSaveCompActor::ProcessMsg(const ActorMsg& actor_msg) {
return (this->*cur_msg_handle_)(actor_msg);
OF_SET_MSG_HANDLE(&MdSaveCompActor::HandleSaveModel);
}
int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
......
......@@ -12,11 +12,9 @@ class MdSaveCompActor final : public CompActor {
~MdSaveCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
int ProcessMsg(const ActorMsg&) override;
private:
int HandleSaveModel(const ActorMsg&);
int (MdSaveCompActor::*cur_msg_handle_)(const ActorMsg&);
int HandleSaveModel(const ActorMsg&);
int64_t model_regst_desc_id_;
};
......
......@@ -15,11 +15,7 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
cur_msg_handle_ = &MdUpdtCompActor::HandleBeforeInitializeModel;
}
int MdUpdtCompActor::ProcessMsg(const ActorMsg& actor_msg) {
return (this->*cur_msg_handle_)(actor_msg);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleBeforeInitializeModel);
}
int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
......@@ -50,7 +46,7 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
return ret;
});
}
cur_msg_handle_ = &MdUpdtCompActor::HandleBeforeSendInitialModel;
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleBeforeSendInitialModel);
return 0;
}
......@@ -60,10 +56,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_);
AsyncSendEORDMsgToSubscribers(model_tmp_regst_desc_id_);
if (JobDesc::Singleton().is_train()) {
cur_msg_handle_ = &MdUpdtCompActor::HandleUpdateModel;
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdateModel);
} else {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
cur_msg_handle_ = &MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero;
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -71,7 +67,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
cur_msg_handle_ = &MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg;
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg);
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_warpper = actor_msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_warpper->regst_raw_ptr()) != 0) {
......@@ -92,27 +88,16 @@ int MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg(
if (waiting_model_diff_acc_queue_.empty()) {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
if (total_reading_cnt() == 0) {
cur_msg_handle_ = nullptr;
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
cur_msg_handle_ = &MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero;
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
}
return 0;
}
// Draining-phase handler: passes the message to TryUpdtStateAsProducedRegst
// (which must return 0), then returns 1 and clears the handler once
// total_reading_cnt() is zero; returns 0 otherwise.
int MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero(
    const ActorMsg& actor_msg) {
  CHECK_EQ(TryUpdtStateAsProducedRegst(
      actor_msg.regst_warpper()->regst_raw_ptr()), 0);
  if (total_reading_cnt() != 0) {
    return 0;
  }
  cur_msg_handle_ = nullptr;
  return 1;
}
void MdUpdtCompActor::TryWardKernelAndSendMsg() {
if (!waiting_model_diff_acc_queue_.empty() && IsWriteReady()) {
auto model_diff_acc_wpr = waiting_model_diff_acc_queue_.front();
......
......@@ -12,7 +12,6 @@ class MdUpdtCompActor final : public CompActor {
~MdUpdtCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
int ProcessMsg(const ActorMsg&) override;
private:
int HandleBeforeInitDeviceCtx(const ActorMsg&);
......@@ -20,12 +19,10 @@ class MdUpdtCompActor final : public CompActor {
int HandleBeforeSendInitialModel(const ActorMsg&);
int HandleUpdateModel(const ActorMsg&);
int HandleUpdtModelWhenNoReadableRegstMsg(const ActorMsg&);
int HandleWaitUntilReadingCntEqualZero(const ActorMsg&);
void TryWardKernelAndSendMsg();
CudaStreamHandle cuda_handle_;
int (MdUpdtCompActor::*cur_msg_handle_)(const ActorMsg&);
int64_t model_regst_desc_id_;
int64_t model_tmp_regst_desc_id_;
std::queue<std::shared_ptr<RegstWarpper>> waiting_model_diff_acc_queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册