提交 cf5ed195 编写于 作者: W willzhang4a58

Define Common Handle: WaitUntilNoReadableRegst, Common

上级 0c550414
......@@ -50,6 +50,8 @@ class Actor {
} while (0)
// Common Handles
virtual int HandleNormal(const ActorMsg& msg) = 0;
virtual int HandleWaitUntilNoReadableRegst(const ActorMsg& msg) = 0;
int HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg);
// Status of Produced Registers
......
......@@ -12,15 +12,15 @@ void BoxingActor::Init(const TaskProto& task_proto,
num_of_eord_ = 0;
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxing);
OF_SET_MSG_HANDLE(&BoxingActor::HandleNormal);
}
int BoxingActor::HandleBoxing(const ActorMsg& msg) {
int BoxingActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == num_of_subscribed_regsts_) {
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxingWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilNoReadableRegst);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
......@@ -36,7 +36,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
return 0;
}
int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
int BoxingActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......
......@@ -14,8 +14,8 @@ class BoxingActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleBoxing(const ActorMsg&);
int HandleBoxingWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg();
......
......@@ -21,7 +21,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpComp);
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleNormal);
}
bool BpDataCompActor::IsReadReady() {
......@@ -37,12 +37,12 @@ bool BpDataCompActor::IsReadReady() {
return !num_of_read_empty_;
}
int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
int BpDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == 6) {
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilNoReadableRegst);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
......@@ -63,7 +63,7 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
return 0;
}
int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
int BpDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......@@ -81,7 +81,7 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(nullptr);
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
}
......
......@@ -14,8 +14,8 @@ class BpDataCompActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleBpComp(const ActorMsg&);
int HandleBpCompWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
bool IsReadReady();
void TryLaunchKernelAndSendMsg();
......
......@@ -9,14 +9,13 @@ void CopyCommNetActor::Init(const TaskProto& task_proto,
Actor::Init(task_proto, thread_ctx);
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleCopyCommNet);
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleNormal);
}
int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) {
int CopyCommNetActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(
&CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilNoReadableRegst);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_wp = msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) {
......@@ -28,8 +27,7 @@ int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) {
return 0;
}
int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg(
const ActorMsg& msg) {
int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......
......@@ -14,8 +14,8 @@ class CopyCommNetActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleCopyCommNet(const ActorMsg&);
int HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg();
HashMap<int64_t, std::shared_ptr<RegstWarpper>> piece_id2waiting_in_regst_;
......
......@@ -10,13 +10,13 @@ void CopyHdActor::Init(const TaskProto& task_proto,
CHECK(thread_ctx.copy_hd_cuda_stream);
mut_device_ctx().reset(
new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr));
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHd);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleNormal);
}
int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
int CopyHdActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilNoReadableRegst);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
......@@ -27,7 +27,7 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
return 0;
}
int CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg& msg) {
int CopyHdActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......
......@@ -14,8 +14,8 @@ class CopyHdActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleCopyHd(const ActorMsg&);
int HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg();
std::queue<std::shared_ptr<RegstWarpper>> waiting_in_regst_;
......
......@@ -27,7 +27,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
} else {
num_of_not_eord_ =
1 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwComp);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleNormal);
}
}
......@@ -53,16 +53,16 @@ bool FwDataCompActor::IsReadReady() {
int FwDataCompActor::WaitToStart(const ActorMsg& msg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);
TryLaunchKernelAndSendMsg();
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilNoReadableRegst);
return 0;
}
int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
int FwDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_not_eord_ -= 1;
if (!num_of_not_eord_) {
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilNoReadableRegst);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
......@@ -87,7 +87,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
return 0;
}
int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
int FwDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......
......@@ -15,8 +15,8 @@ class FwDataCompActor final : public CompActor {
private:
int WaitToStart(const ActorMsg&);
int HandleFwComp(const ActorMsg&);
int HandleFwCompWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
bool IsReadReady();
void TryLaunchKernelAndSendMsg();
......
......@@ -14,15 +14,15 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAcc);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleNormal);
ForEachCurWriteableRegst(
[this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; });
}
int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
int MdDiffAccActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilNoReadableRegst);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
......@@ -33,7 +33,7 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
return 0;
}
int MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg& msg) {
int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryLaunchKernelAndSendMsg();
......
......@@ -14,8 +14,8 @@ class MdDiffAccActor final : public CompActor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleMdDiffAcc(const ActorMsg&);
int HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg();
......
......@@ -9,12 +9,12 @@ void MdSaveCompActor::Init(const TaskProto& task_proto,
model_regst_desc_id_ = RegstDescId4Name("model");
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&MdSaveCompActor::HandleSaveModel);
OF_SET_MSG_HANDLE(&MdSaveCompActor::HandleNormal);
}
int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
int MdSaveCompActor::HandleNormal(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK(actor_msg.actor_cmd() == ActorCmd::kEORD);
CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
return 1;
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
std::shared_ptr<RegstWarpper> regst_warpper = actor_msg.regst_warpper();
......
......@@ -14,7 +14,10 @@ class MdSaveCompActor final : public CompActor {
void Init(const TaskProto&, const ThreadCtx&) override;
private:
int HandleSaveModel(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg& msg) override {
UNEXPECTED_RUN();
}
int64_t model_regst_desc_id_;
};
......
......@@ -56,7 +56,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_);
AsyncSendEORDMsgToSubscribers(model_tmp_regst_desc_id_);
if (JobDesc::Singleton()->is_train()) {
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdateModel);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleNormal);
} else {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
......@@ -64,10 +64,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
return 0;
}
int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
int MdUpdtCompActor::HandleNormal(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilNoReadableRegst);
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_warpper = actor_msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_warpper->regst_raw_ptr()) != 0) {
......@@ -80,8 +80,7 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
return 0;
}
int MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg(
const ActorMsg& actor_msg) {
int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) {
CHECK_EQ(
TryUpdtStateAsProducedRegst(actor_msg.regst_warpper()->regst_raw_ptr()),
0);
......
......@@ -17,8 +17,8 @@ class MdUpdtCompActor final : public CompActor {
int HandleBeforeInitDeviceCtx(const ActorMsg&);
int HandleBeforeInitializeModel(const ActorMsg&);
int HandleBeforeSendInitialModel(const ActorMsg&);
int HandleUpdateModel(const ActorMsg&);
int HandleUpdtModelWhenNoReadableRegstMsg(const ActorMsg&);
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册