提交 87db4f65 编写于 作者: L Liu Guo 提交者: Jingwu Chen

refine copy actors and mode diff actor (#189)

* refine copy actors and model diff acc actor

* use set_num_of_not_eord
上级 74acbc55
......@@ -8,6 +8,8 @@ void CopyCommNetActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
CHECK(thread_ctx.cpu_stream);
set_num_of_not_eord(1);
mut_num_of_read_empty() = 1;
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleNormal);
}
......@@ -15,16 +17,19 @@ void CopyCommNetActor::Init(const TaskProto& task_proto,
int CopyCommNetActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_wp = msg.regst_wrapper();
if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) {
mut_num_of_read_empty() = 0;
CHECK(piece_id2waiting_in_regst_.emplace(regst_wp->piece_id(), regst_wp)
.second);
} else {
// do nothing
}
ActUntilFail();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
......@@ -33,13 +38,7 @@ int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
ActUntilFail();
if (piece_id2waiting_in_regst_.empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -64,6 +63,7 @@ void CopyCommNetActor::Act() {
});
AsyncSendRegstMsgToProducer(regst_wp);
piece_id2waiting_in_regst_.erase(next_regst_it);
mut_num_of_read_empty() = piece_id2waiting_in_regst_.empty();
}
REGISTER_ACTOR(kCopyCommNetTask, true, CopyCommNetActor);
......
......@@ -10,36 +10,35 @@ void CopyHdActor::Init(const TaskProto& task_proto,
CHECK(thread_ctx.copy_hd_cuda_stream);
mut_device_ctx().reset(
new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr));
set_num_of_not_eord(1);
mut_num_of_read_empty() = 1;
OF_SET_MSG_HANDLE(&CopyHdActor::HandleNormal);
}
int CopyHdActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr())
!= 0) {
mut_num_of_read_empty() = 0;
waiting_in_regst_.push(msg.regst_wrapper());
} else {
// do nothing
}
ActUntilFail();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int CopyHdActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()),
0);
ActUntilFail();
if (waiting_in_regst_.empty()) {
if (mut_num_of_read_empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -64,6 +63,7 @@ void CopyHdActor::Act() {
});
AsyncSendRegstMsgToProducer(regst_wp);
waiting_in_regst_.pop();
mut_num_of_read_empty() = waiting_in_regst_.empty();
}
REGISTER_ACTOR(kCopyHdTask, true, CopyHdActor);
......
......@@ -14,6 +14,8 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
}
set_num_of_not_eord(1);
mut_num_of_read_empty() = 1;
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleNormal);
ForEachCurWriteableRegst(
[this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; });
......@@ -22,15 +24,18 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
int MdDiffAccActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr())
!= 0) {
mut_num_of_read_empty() = 0;
waiting_in_regst_.push(msg.regst_wrapper());
} else {
// do nothing
}
ActUntilFail();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
......@@ -39,13 +44,7 @@ int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
ActUntilFail();
if (waiting_in_regst_.empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -81,6 +80,7 @@ void MdDiffAccActor::Act() {
});
AsyncSendRegstMsgToProducer(regst_wp);
waiting_in_regst_.pop();
mut_num_of_read_empty() = waiting_in_regst_.empty();
}
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册