提交 cf5ed195 编写于 作者: W willzhang4a58

Define Common Handle: WaitUntilNoReadableRegst, Common

上级 0c550414
...@@ -50,6 +50,8 @@ class Actor { ...@@ -50,6 +50,8 @@ class Actor {
} while (0) } while (0)
// Common Handles // Common Handles
virtual int HandleNormal(const ActorMsg& msg) = 0;
virtual int HandleWaitUntilNoReadableRegst(const ActorMsg& msg) = 0;
int HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg); int HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg);
// Status of Produced Registers // Status of Produced Registers
......
...@@ -12,15 +12,15 @@ void BoxingActor::Init(const TaskProto& task_proto, ...@@ -12,15 +12,15 @@ void BoxingActor::Init(const TaskProto& task_proto,
num_of_eord_ = 0; num_of_eord_ = 0;
CHECK(thread_ctx.cpu_stream); CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream)); mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxing); OF_SET_MSG_HANDLE(&BoxingActor::HandleNormal);
} }
int BoxingActor::HandleBoxing(const ActorMsg& msg) { int BoxingActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1; num_of_eord_ += 1;
if (num_of_eord_ == num_of_subscribed_regsts_) { if (num_of_eord_ == num_of_subscribed_regsts_) {
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxingWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilNoReadableRegst);
} }
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
...@@ -36,7 +36,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) { ...@@ -36,7 +36,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
return 0; return 0;
} }
int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) { int BoxingActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
......
...@@ -14,8 +14,8 @@ class BoxingActor final : public Actor { ...@@ -14,8 +14,8 @@ class BoxingActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleBoxing(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleBoxingWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
......
...@@ -21,7 +21,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto, ...@@ -21,7 +21,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(), cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle())); cuda_handle_.cudnn_handle()));
} }
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpComp); OF_SET_MSG_HANDLE(&BpDataCompActor::HandleNormal);
} }
bool BpDataCompActor::IsReadReady() { bool BpDataCompActor::IsReadReady() {
...@@ -37,12 +37,12 @@ bool BpDataCompActor::IsReadReady() { ...@@ -37,12 +37,12 @@ bool BpDataCompActor::IsReadReady() {
return !num_of_read_empty_; return !num_of_read_empty_;
} }
int BpDataCompActor::HandleBpComp(const ActorMsg& msg) { int BpDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1; num_of_eord_ += 1;
if (num_of_eord_ == 6) { if (num_of_eord_ == 6) {
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilNoReadableRegst);
} }
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
...@@ -63,7 +63,7 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) { ...@@ -63,7 +63,7 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
return 0; return 0;
} }
int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) { int BpDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
...@@ -81,7 +81,7 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) { ...@@ -81,7 +81,7 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
OF_SET_MSG_HANDLE(nullptr); OF_SET_MSG_HANDLE(nullptr);
return 1; return 1;
} else { } else {
OF_SET_MSG_HANDLE(nullptr); OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilReadingCntEqualZero);
return 0; return 0;
} }
} }
......
...@@ -14,8 +14,8 @@ class BpDataCompActor final : public Actor { ...@@ -14,8 +14,8 @@ class BpDataCompActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleBpComp(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleBpCompWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
bool IsReadReady(); bool IsReadReady();
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
......
...@@ -9,14 +9,13 @@ void CopyCommNetActor::Init(const TaskProto& task_proto, ...@@ -9,14 +9,13 @@ void CopyCommNetActor::Init(const TaskProto& task_proto,
Actor::Init(task_proto, thread_ctx); Actor::Init(task_proto, thread_ctx);
CHECK(thread_ctx.cpu_stream); CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream)); mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleCopyCommNet); OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleNormal);
} }
int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) { int CopyCommNetActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE( OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilNoReadableRegst);
&CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_wp = msg.regst_warpper(); auto regst_wp = msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) { if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) {
...@@ -28,8 +27,7 @@ int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) { ...@@ -28,8 +27,7 @@ int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) {
return 0; return 0;
} }
int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg( int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
......
...@@ -14,8 +14,8 @@ class CopyCommNetActor final : public Actor { ...@@ -14,8 +14,8 @@ class CopyCommNetActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleCopyCommNet(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
HashMap<int64_t, std::shared_ptr<RegstWarpper>> piece_id2waiting_in_regst_; HashMap<int64_t, std::shared_ptr<RegstWarpper>> piece_id2waiting_in_regst_;
......
...@@ -10,13 +10,13 @@ void CopyHdActor::Init(const TaskProto& task_proto, ...@@ -10,13 +10,13 @@ void CopyHdActor::Init(const TaskProto& task_proto,
CHECK(thread_ctx.copy_hd_cuda_stream); CHECK(thread_ctx.copy_hd_cuda_stream);
mut_device_ctx().reset( mut_device_ctx().reset(
new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr)); new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr));
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHd); OF_SET_MSG_HANDLE(&CopyHdActor::HandleNormal);
} }
int CopyHdActor::HandleCopyHd(const ActorMsg& msg) { int CopyHdActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilNoReadableRegst);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) { != 0) {
...@@ -27,7 +27,7 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) { ...@@ -27,7 +27,7 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
return 0; return 0;
} }
int CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg& msg) { int CopyHdActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
......
...@@ -14,8 +14,8 @@ class CopyHdActor final : public Actor { ...@@ -14,8 +14,8 @@ class CopyHdActor final : public Actor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleCopyHd(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
std::queue<std::shared_ptr<RegstWarpper>> waiting_in_regst_; std::queue<std::shared_ptr<RegstWarpper>> waiting_in_regst_;
......
...@@ -27,7 +27,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto, ...@@ -27,7 +27,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
} else { } else {
num_of_not_eord_ = num_of_not_eord_ =
1 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1); 1 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwComp); OF_SET_MSG_HANDLE(&FwDataCompActor::HandleNormal);
} }
} }
...@@ -53,16 +53,16 @@ bool FwDataCompActor::IsReadReady() { ...@@ -53,16 +53,16 @@ bool FwDataCompActor::IsReadReady() {
int FwDataCompActor::WaitToStart(const ActorMsg& msg) { int FwDataCompActor::WaitToStart(const ActorMsg& msg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart); CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilNoReadableRegst);
return 0; return 0;
} }
int FwDataCompActor::HandleFwComp(const ActorMsg& msg) { int FwDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_not_eord_ -= 1; num_of_not_eord_ -= 1;
if (!num_of_not_eord_) { if (!num_of_not_eord_) {
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilNoReadableRegst);
} }
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
...@@ -87,7 +87,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) { ...@@ -87,7 +87,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
return 0; return 0;
} }
int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) { int FwDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
......
...@@ -15,8 +15,8 @@ class FwDataCompActor final : public CompActor { ...@@ -15,8 +15,8 @@ class FwDataCompActor final : public CompActor {
private: private:
int WaitToStart(const ActorMsg&); int WaitToStart(const ActorMsg&);
int HandleFwComp(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleFwCompWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
bool IsReadReady(); bool IsReadReady();
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
......
...@@ -14,15 +14,15 @@ void MdDiffAccActor::Init(const TaskProto& task_proto, ...@@ -14,15 +14,15 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(), cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle())); cuda_handle_.cudnn_handle()));
} }
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAcc); OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleNormal);
ForEachCurWriteableRegst( ForEachCurWriteableRegst(
[this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; }); [this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; });
} }
int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) { int MdDiffAccActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilNoReadableRegst);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) { != 0) {
...@@ -33,7 +33,7 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) { ...@@ -33,7 +33,7 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
return 0; return 0;
} }
int MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg& msg) { int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0); 0);
TryLaunchKernelAndSendMsg(); TryLaunchKernelAndSendMsg();
......
...@@ -14,8 +14,8 @@ class MdDiffAccActor final : public CompActor { ...@@ -14,8 +14,8 @@ class MdDiffAccActor final : public CompActor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleMdDiffAcc(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
......
...@@ -9,12 +9,12 @@ void MdSaveCompActor::Init(const TaskProto& task_proto, ...@@ -9,12 +9,12 @@ void MdSaveCompActor::Init(const TaskProto& task_proto,
model_regst_desc_id_ = RegstDescId4Name("model"); model_regst_desc_id_ = RegstDescId4Name("model");
CHECK(thread_ctx.cpu_stream); CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream)); mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&MdSaveCompActor::HandleSaveModel); OF_SET_MSG_HANDLE(&MdSaveCompActor::HandleNormal);
} }
int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) { int MdSaveCompActor::HandleNormal(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) { if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK(actor_msg.actor_cmd() == ActorCmd::kEORD); CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
return 1; return 1;
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
std::shared_ptr<RegstWarpper> regst_warpper = actor_msg.regst_warpper(); std::shared_ptr<RegstWarpper> regst_warpper = actor_msg.regst_warpper();
......
...@@ -14,7 +14,10 @@ class MdSaveCompActor final : public CompActor { ...@@ -14,7 +14,10 @@ class MdSaveCompActor final : public CompActor {
void Init(const TaskProto&, const ThreadCtx&) override; void Init(const TaskProto&, const ThreadCtx&) override;
private: private:
int HandleSaveModel(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg& msg) override {
UNEXPECTED_RUN();
}
int64_t model_regst_desc_id_; int64_t model_regst_desc_id_;
}; };
......
...@@ -56,7 +56,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) { ...@@ -56,7 +56,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_); SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_);
AsyncSendEORDMsgToSubscribers(model_tmp_regst_desc_id_); AsyncSendEORDMsgToSubscribers(model_tmp_regst_desc_id_);
if (JobDesc::Singleton()->is_train()) { if (JobDesc::Singleton()->is_train()) {
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdateModel); OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleNormal);
} else { } else {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_); AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero); OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
...@@ -64,10 +64,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) { ...@@ -64,10 +64,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
return 0; return 0;
} }
int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) { int MdUpdtCompActor::HandleNormal(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) { if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD); CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg); OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilNoReadableRegst);
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_warpper = actor_msg.regst_warpper(); auto regst_warpper = actor_msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_warpper->regst_raw_ptr()) != 0) { if (TryUpdtStateAsProducedRegst(regst_warpper->regst_raw_ptr()) != 0) {
...@@ -80,8 +80,7 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) { ...@@ -80,8 +80,7 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
return 0; return 0;
} }
int MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg( int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) {
const ActorMsg& actor_msg) {
CHECK_EQ( CHECK_EQ(
TryUpdtStateAsProducedRegst(actor_msg.regst_warpper()->regst_raw_ptr()), TryUpdtStateAsProducedRegst(actor_msg.regst_warpper()->regst_raw_ptr()),
0); 0);
......
...@@ -17,8 +17,8 @@ class MdUpdtCompActor final : public CompActor { ...@@ -17,8 +17,8 @@ class MdUpdtCompActor final : public CompActor {
int HandleBeforeInitDeviceCtx(const ActorMsg&); int HandleBeforeInitDeviceCtx(const ActorMsg&);
int HandleBeforeInitializeModel(const ActorMsg&); int HandleBeforeInitializeModel(const ActorMsg&);
int HandleBeforeSendInitialModel(const ActorMsg&); int HandleBeforeSendInitialModel(const ActorMsg&);
int HandleUpdateModel(const ActorMsg&); int HandleNormal(const ActorMsg&) override;
int HandleUpdtModelWhenNoReadableRegstMsg(const ActorMsg&); int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
void TryLaunchKernelAndSendMsg(); void TryLaunchKernelAndSendMsg();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册