diff --git a/oneflow/core/actor/copy_comm_net_actor.cpp b/oneflow/core/actor/copy_comm_net_actor.cpp index 24693edf535bd7d6d35dbe4241c43500176d88ac..82abb8b5d974914c10e7800210a22e20cc71e64b 100644 --- a/oneflow/core/actor/copy_comm_net_actor.cpp +++ b/oneflow/core/actor/copy_comm_net_actor.cpp @@ -8,6 +8,8 @@ void CopyCommNetActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) { Actor::Init(task_proto, thread_ctx); CHECK(thread_ctx.cpu_stream); + set_num_of_not_eord(1); + mut_num_of_read_empty() = 1; mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream)); OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleNormal); } @@ -15,16 +17,19 @@ void CopyCommNetActor::Init(const TaskProto& task_proto, int CopyCommNetActor::HandleNormal(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); - OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilNoReadableRegst); + ProcessEord(); } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { auto regst_wp = msg.regst_wrapper(); if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) { + mut_num_of_read_empty() = 0; CHECK(piece_id2waiting_in_regst_.emplace(regst_wp->piece_id(), regst_wp) .second); + } else { + // do nothing } + ActUntilFail(); } - ActUntilFail(); - return 0; + return msg_handle() == nullptr; } int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) { @@ -33,13 +38,7 @@ int CopyCommNetActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) { ActUntilFail(); if (piece_id2waiting_in_regst_.empty()) { AsyncSendEORDMsgForAllProducedRegstDesc(); - if (total_reading_cnt() == 0) { - OF_SET_MSG_HANDLE(nullptr); - return 1; - } else { - OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilReadingCntEqualZero); - return 0; - } + OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleWaitUntilReadingCntEqualZero); } return 0; } @@ -64,6 +63,7 @@ void CopyCommNetActor::Act() { }); AsyncSendRegstMsgToProducer(regst_wp); piece_id2waiting_in_regst_.erase(next_regst_it); + mut_num_of_read_empty() = piece_id2waiting_in_regst_.empty(); } REGISTER_ACTOR(kCopyCommNetTask, true, CopyCommNetActor); diff --git a/oneflow/core/actor/copy_hd_actor.cpp b/oneflow/core/actor/copy_hd_actor.cpp index c7957a8a90715e81b6085d9ede5d2c1005316852..e42f265feaa56a2a7f4f536686653c1d01f995f6 100644 --- a/oneflow/core/actor/copy_hd_actor.cpp +++ b/oneflow/core/actor/copy_hd_actor.cpp @@ -10,36 +10,35 @@ void CopyHdActor::Init(const TaskProto& task_proto, CHECK(thread_ctx.copy_hd_cuda_stream); mut_device_ctx().reset( new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr)); + set_num_of_not_eord(1); + mut_num_of_read_empty() = 1; OF_SET_MSG_HANDLE(&CopyHdActor::HandleNormal); } int CopyHdActor::HandleNormal(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); - OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilNoReadableRegst); + ProcessEord(); } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()) != 0) { + mut_num_of_read_empty() = 0; waiting_in_regst_.push(msg.regst_wrapper()); + } else { + // do nothing } + ActUntilFail(); } - ActUntilFail(); - return 0; + return msg_handle() == nullptr; } int CopyHdActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) { CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()), 0); ActUntilFail(); - if (waiting_in_regst_.empty()) { + if (mut_num_of_read_empty()) { AsyncSendEORDMsgForAllProducedRegstDesc(); - if (total_reading_cnt() == 0) { - OF_SET_MSG_HANDLE(nullptr); - return 1; - } else { - OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilReadingCntEqualZero); - return 0; - } + OF_SET_MSG_HANDLE(&CopyHdActor::HandleWaitUntilReadingCntEqualZero); } return 0; } @@ -64,6 +63,7 @@ void CopyHdActor::Act() { }); AsyncSendRegstMsgToProducer(regst_wp); waiting_in_regst_.pop(); + mut_num_of_read_empty() = waiting_in_regst_.empty(); } REGISTER_ACTOR(kCopyHdTask, true, CopyHdActor); diff --git a/oneflow/core/actor/model_diff_accumulate_actor.cpp b/oneflow/core/actor/model_diff_accumulate_actor.cpp index 6f969f546723ae2c6e6aba7e66343841ff7bf887..f00b377721405f4d019e764d54396d9339caab8a 100644 --- a/oneflow/core/actor/model_diff_accumulate_actor.cpp +++ b/oneflow/core/actor/model_diff_accumulate_actor.cpp @@ -14,6 +14,8 @@ void MdDiffAccActor::Init(const TaskProto& task_proto, cuda_handle_.cublas_handle(), cuda_handle_.cudnn_handle())); } + set_num_of_not_eord(1); + mut_num_of_read_empty() = 1; OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleNormal); ForEachCurWriteableRegst( [this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; }); @@ -22,15 +24,18 @@ void MdDiffAccActor::Init(const TaskProto& task_proto, int MdDiffAccActor::HandleNormal(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD); - OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilNoReadableRegst); + ProcessEord(); } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()) != 0) { + mut_num_of_read_empty() = 0; waiting_in_regst_.push(msg.regst_wrapper()); + } else { + // do nothing } + ActUntilFail(); } - ActUntilFail(); - return 0; + return msg_handle() == nullptr; } int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) { @@ -39,13 +44,7 @@ int MdDiffAccActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) { ActUntilFail(); if (waiting_in_regst_.empty()) { AsyncSendEORDMsgForAllProducedRegstDesc(); - if (total_reading_cnt() == 0) { - OF_SET_MSG_HANDLE(nullptr); - return 1; - } else { - OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilReadingCntEqualZero); - return 0; - } + OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleWaitUntilReadingCntEqualZero); } return 0; } @@ -81,6 +80,7 @@ void MdDiffAccActor::Act() { }); AsyncSendRegstMsgToProducer(regst_wp); waiting_in_regst_.pop(); + mut_num_of_read_empty() = waiting_in_regst_.empty(); } } // namespace oneflow