From 997ec81a8dd7d2a716acacbe600147fae559a689 Mon Sep 17 00:00:00 2001 From: Jingwu Chen <1163643073@qq.com> Date: Sun, 9 Jul 2017 18:02:24 +0800 Subject: [PATCH] model update (#191) * total_piece_num * fw * fix for review * code style * fix bug * fix bug for data loader * kStart in FW for data loader * fix bug * bp * bp && fw * unexpected_run * boxing state * fix bug * ProcessEord & Boxing fixed * ProcessEord in Actor and fix boxing * remove useless override * total_reading_cnt_ * use mut * redefine fw * redefine bp * remove useless variable * redefine model update --- oneflow/core/actor/fw_data_comp_actor.cpp | 4 ++-- oneflow/core/actor/model_update_comp_actor.cpp | 16 +++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/oneflow/core/actor/fw_data_comp_actor.cpp b/oneflow/core/actor/fw_data_comp_actor.cpp index f7bb1a672f..840a4b9699 100644 --- a/oneflow/core/actor/fw_data_comp_actor.cpp +++ b/oneflow/core/actor/fw_data_comp_actor.cpp @@ -93,8 +93,8 @@ int FwDataCompActor::HandleNormal(const ActorMsg& msg) { readable_regst_[model_regst_desc_id_] = regst_wp; expected_model_version_id_ += 1; } else { - mut_num_of_read_empty() -= in_.empty(); in_.push(regst_wp); + mut_num_of_read_empty() = 0; } } ActUntilFail(); @@ -141,7 +141,7 @@ void FwDataCompActor::Act() { if (!in_.empty()) { AsyncSendRegstMsgToProducer(in_.front()); in_.pop(); - mut_num_of_read_empty() += in_.empty(); + mut_num_of_read_empty() = in_.empty(); } if (bp_actor_id_ != -1) { ActorMsg msg; diff --git a/oneflow/core/actor/model_update_comp_actor.cpp b/oneflow/core/actor/model_update_comp_actor.cpp index 7205fc8d97..ca2f3749c6 100644 --- a/oneflow/core/actor/model_update_comp_actor.cpp +++ b/oneflow/core/actor/model_update_comp_actor.cpp @@ -10,6 +10,8 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto, model_regst_desc_id_ = RegstDescId4Name("model"); model_tmp_regst_desc_id_ = RegstDescId4Name("model_tmp"); next_model_version_id_ = 0; + set_num_of_not_eord(1); + mut_num_of_read_empty() = 1; if (thread_ctx.cpu_stream) { mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream)); } else { @@ -67,17 +69,18 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) { int MdUpdtCompActor::HandleNormal(const ActorMsg& actor_msg) { if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD); - OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilNoReadableRegst); + ProcessEord(); } else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) { auto regst_wrapper = actor_msg.regst_wrapper(); if (TryUpdtStateAsProducedRegst(regst_wrapper->regst_raw_ptr()) != 0) { waiting_model_diff_acc_queue_.push(regst_wrapper); + mut_num_of_read_empty() = 0; } ActUntilFail(); } else { UNEXPECTED_RUN(); } - return 0; + return msg_handle() == nullptr; } int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) { @@ -87,13 +90,7 @@ int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) { ActUntilFail(); if (waiting_model_diff_acc_queue_.empty()) { AsyncSendEORDMsgToSubscribers(model_regst_desc_id_); - if (total_reading_cnt() == 0) { - OF_SET_MSG_HANDLE(nullptr); - return 1; - } else { - OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero); - return 0; - } + OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero); } return 0; } @@ -101,6 +98,7 @@ int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) { void MdUpdtCompActor::Act() { auto model_diff_acc_wpr = waiting_model_diff_acc_queue_.front(); waiting_model_diff_acc_queue_.pop(); + mut_num_of_read_empty() = waiting_model_diff_acc_queue_.empty(); Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_); auto model_wpr = std::make_shared(model_regst); model_regst->set_model_version_id(next_model_version_id_++); -- GitLab