提交 6832d2f0 编写于 作者: W willzhang4a58

process regst msg in ModelUpdateCompActor

上级 3f4ee5c9
......@@ -9,6 +9,8 @@
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/register/register.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/thread/thread_context.h"
......
......@@ -58,15 +58,38 @@ void MdUpdtCompActor::HandleForUpdateModel(
const KernelContext& kernel_ctx) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK(actor_msg.actor_cmd() == ActorCmd::kStop);
cur_handle_ = nullptr;
TODO();
cur_handle_ = nullptr;
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
TODO();
ProcessRegstFromMsg(actor_msg.regst_warpper(), kernel_ctx);
} else {
UNEXPECTED_RUN();
}
}
void MdUpdtCompActor::ProcessRegstFromMsg(
std::shared_ptr<RegstWarpper> regst_warpper,
const KernelContext& kernel_ctx) {
if (TryUpdtStateAsFromRegstReader(regst_warpper->regst_raw_ptr()) != 0) {
waiting_model_diff_acc_queue_.push(regst_warpper);
}
if (!waiting_model_diff_acc_queue_.empty() && IsWriteReady()) {
auto model_diff_acc_wpr = waiting_model_diff_acc_queue_.front();
waiting_model_diff_acc_queue_.pop();
Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_);
auto model_wpr = std::make_shared<LocalRegstWarpper>(model_regst);
WardKernel(kernel_ctx,
[&](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
if (regst_desc_id == model_regst_desc_id_) {
return model_wpr;
} else {
return model_diff_acc_wpr;
}
});
model_regst->set_model_version_id(model_diff_acc_wpr->model_version_id());
}
}
REGISTER_ACTOR(kMdUpdtCompTask, true, MdUpdtCompActor);
} // namespace oneflow
......@@ -15,14 +15,16 @@ class MdUpdtCompActor final : public CompActor {
void ProcessMsg(const ActorMsg&, const ThreadContext&) override;
private:
void HandleBeforeInitializeModel(const ActorMsg&, const KernelContext&);
void HandleBeforeSendInitialModel(const ActorMsg&, const KernelContext&);
void HandleForUpdateModel(const ActorMsg&, const KernelContext&);
void ProcessRegstFromMsg(std::shared_ptr<RegstWarpper>, const KernelContext&);
void (MdUpdtCompActor::*cur_handle_)(const ActorMsg&, const KernelContext&);
uint64_t model_regst_desc_id_;
uint64_t model_tmp_regst_desc_id_;
std::queue<std::shared_ptr<RegstWarpper>> waiting_model_diff_acc_queue_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册