提交 997ec81a 作者: Jingwu Chen 提交者: Liu Guo

model update (#191)

* total_piece_num

* fw

* fix for review

* code style

* fix bug

* fix bug for data loader

* kStart in FW for data loader

* fix bug

* bp

* bp && fw

* unexpected_run

* boxing state

* fix bug

* ProcessEord & Boxing fixed

* ProcessEord in Actor and fix boxing

* remove useless override

* total_reading_cnt_

* use mut

* redefine fw

* redefine bp

* remove useless variable

* redefine model update
上级 87db4f65
......@@ -93,8 +93,8 @@ int FwDataCompActor::HandleNormal(const ActorMsg& msg) {
readable_regst_[model_regst_desc_id_] = regst_wp;
expected_model_version_id_ += 1;
} else {
- mut_num_of_read_empty() -= in_.empty();
in_.push(regst_wp);
+ mut_num_of_read_empty() = 0;
}
}
ActUntilFail();
......@@ -141,7 +141,7 @@ void FwDataCompActor::Act() {
if (!in_.empty()) {
AsyncSendRegstMsgToProducer(in_.front());
in_.pop();
- mut_num_of_read_empty() += in_.empty();
+ mut_num_of_read_empty() = in_.empty();
}
if (bp_actor_id_ != -1) {
ActorMsg msg;
......
......@@ -10,6 +10,8 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto,
model_regst_desc_id_ = RegstDescId4Name("model");
model_tmp_regst_desc_id_ = RegstDescId4Name("model_tmp");
next_model_version_id_ = 0;
+ set_num_of_not_eord(1);
+ mut_num_of_read_empty() = 1;
if (thread_ctx.cpu_stream) {
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
} else {
......@@ -67,17 +69,18 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
int MdUpdtCompActor::HandleNormal(const ActorMsg& actor_msg) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(actor_msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_wrapper = actor_msg.regst_wrapper();
if (TryUpdtStateAsProducedRegst(regst_wrapper->regst_raw_ptr()) != 0) {
waiting_model_diff_acc_queue_.push(regst_wrapper);
+ mut_num_of_read_empty() = 0;
}
ActUntilFail();
} else {
UNEXPECTED_RUN();
}
- return 0;
+ return msg_handle() == nullptr;
}
int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) {
......@@ -87,13 +90,7 @@ int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) {
ActUntilFail();
if (waiting_model_diff_acc_queue_.empty()) {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
- if (total_reading_cnt() == 0) {
- OF_SET_MSG_HANDLE(nullptr);
- return 1;
- } else {
- OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
- return 0;
- }
+ OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -101,6 +98,7 @@ int MdUpdtCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& actor_msg) {
void MdUpdtCompActor::Act() {
auto model_diff_acc_wpr = waiting_model_diff_acc_queue_.front();
waiting_model_diff_acc_queue_.pop();
+ mut_num_of_read_empty() = waiting_model_diff_acc_queue_.empty();
Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_);
auto model_wpr = std::make_shared<LocalRegstWrapper>(model_regst);
model_regst->set_model_version_id(next_model_version_id_++);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册