提交 cc2ee9fc 编写于 作者: W willzhang4a58

naive readable register manager

上级 6f0754f6
...@@ -36,6 +36,8 @@ class Actor { ...@@ -36,6 +36,8 @@ class Actor {
int64_t actor_id() const { return actor_id_; } int64_t actor_id() const { return actor_id_; }
protected: protected:
friend class NaiveReadableRegstMgr;
struct ExecKernel { struct ExecKernel {
std::unique_ptr<const Kernel> kernel; std::unique_ptr<const Kernel> kernel;
HashMap<std::string, int64_t> bn_in_op2regst_desc_id; HashMap<std::string, int64_t> bn_in_op2regst_desc_id;
......
...@@ -4,11 +4,8 @@ ...@@ -4,11 +4,8 @@
namespace oneflow { namespace oneflow {
void BoxingActor::VirtualActorInit(const TaskProto& task_proto) { void BoxingActor::VirtualActorInit(const TaskProto& task_proto) {
for (const auto& pair : task_proto.consumed_regst_desc_id()) { readable_regst_mgr_.Init(task_proto);
readable_regst_[pair.second] = {};
}
previous_pid_cid_ = new HashMap<int64_t, std::pair<int64_t, int32_t>>; previous_pid_cid_ = new HashMap<int64_t, std::pair<int64_t, int32_t>>;
readable_regst_cnt_ = 0;
col_id_order_ = ColIdOrder::kUnCertain; col_id_order_ = ColIdOrder::kUnCertain;
is_eord_ = false; is_eord_ = false;
OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal); OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal);
...@@ -48,9 +45,7 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) { ...@@ -48,9 +45,7 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
&& col_id_order_ == ColIdOrder::kUnCertain) { && col_id_order_ == ColIdOrder::kUnCertain) {
TrySetColIdOrder(msg.regst()); TrySetColIdOrder(msg.regst());
} }
std::queue<Regst*>& rq = readable_regst_.at(msg.regst()->regst_desc_id()); readable_regst_mgr_.Push(msg.regst());
if (rq.empty()) { readable_regst_cnt_ += 1; }
rq.push(msg.regst());
} }
ActUntilFail(); ActUntilFail();
} else { } else {
...@@ -60,57 +55,53 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) { ...@@ -60,57 +55,53 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
} }
void BoxingActor::Act() { void BoxingActor::Act() {
int64_t piece_id = readable_regst_.begin()->second.front()->piece_id(); int64_t piece_id = readable_regst_mgr_.GetFirstCurReadable()->piece_id();
AsyncLaunchKernel(GenDefaultKernelCtx(), AsyncLaunchKernel(
[this](int64_t regst_desc_id) -> Regst* { GenDefaultKernelCtx(), [this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id); Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) { if (regst == nullptr) {
return readable_regst_.at(regst_desc_id).front(); return readable_regst_mgr_.GetCurReadable(regst_desc_id);
} else { } else {
return regst; return regst;
} }
}); });
AsyncSendRegstMsgToConsumer([&](Regst* regst) { AsyncSendRegstMsgToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id); regst->set_piece_id(piece_id);
return regst->col_id() <= regst->max_col_id(); return regst->col_id() <= regst->max_col_id();
}); });
int32_t cur_max_cid = 0; int32_t cur_max_cid = 0;
int32_t cur_max_maxcid = 0; int32_t cur_max_maxcid = 0;
for (const auto& pair : readable_regst_) { readable_regst_mgr_.ForEachCurReadableRegst([&](Regst* regst) {
cur_max_cid = std::max(cur_max_cid, pair.second.front()->col_id()); cur_max_cid = std::max(cur_max_cid, regst->col_id());
cur_max_maxcid = cur_max_maxcid = std::max(cur_max_maxcid, regst->max_col_id());
std::max(cur_max_maxcid, pair.second.front()->max_col_id()); });
} readable_regst_mgr_.ReturnToProducerAndPopCurReadable(
for (auto& pair : readable_regst_) { this, [&](Regst* regst) {
if (col_id_order_ == ColIdOrder::kAscending) { if (col_id_order_ == ColIdOrder::kAscending) {
if (pair.second.front()->IsMaxCol() && cur_max_cid < cur_max_maxcid) { if (regst->IsMaxCol() && cur_max_cid < cur_max_maxcid) {
continue; return false;
} }
} else if (col_id_order_ == ColIdOrder::kDescending) { } else if (col_id_order_ == ColIdOrder::kDescending) {
if (pair.second.front()->col_id() < cur_max_cid) { continue; } if (regst->col_id() < cur_max_cid) { return false; }
} else { // do nothing } else { // do nothing
} }
AsyncSendRegstMsgToProducer(pair.second.front()); return true;
pair.second.pop(); });
if (pair.second.empty()) { readable_regst_cnt_ -= 1; }
}
} }
bool BoxingActor::IsReadReady() { bool BoxingActor::IsReadReady() { return readable_regst_mgr_.IsReadReady(); }
return readable_regst_.size() == readable_regst_cnt_;
}
bool BoxingActor::IsReadAlwaysUnReadyFromNow() { bool BoxingActor::IsReadAlwaysUnReadyFromNow() {
return is_eord_ && readable_regst_cnt_ == 0; return is_eord_ && readable_regst_mgr_.IsEmpty();
} }
void BoxingActor::AsyncReturnAllReadableRegst() { void BoxingActor::AsyncReturnAllReadableRegst() {
CHECK_EQ(readable_regst_cnt_, 0); CHECK(readable_regst_mgr_.IsEmpty());
} }
void BoxingActor::ForEachCurReadableRegst( void BoxingActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) { std::function<void(const Regst*)> func) {
for (const auto& pair : readable_regst_) { handler(pair.second.front()); } readable_regst_mgr_.ForEachCurReadableRegst(func);
} }
REGISTER_ACTOR(TaskType::kBoxing, BoxingActor); REGISTER_ACTOR(TaskType::kBoxing, BoxingActor);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_ #define ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
#include "oneflow/core/actor/actor.h" #include "oneflow/core/actor/actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow { namespace oneflow {
...@@ -25,11 +26,9 @@ class BoxingActor final : public Actor { ...@@ -25,11 +26,9 @@ class BoxingActor final : public Actor {
void TrySetColIdOrder(const Regst*); void TrySetColIdOrder(const Regst*);
// <regst_desc_id, regst*> NaiveReadableRegstMgr readable_regst_mgr_;
HashMap<int64_t, std::queue<Regst*>> readable_regst_;
// <regst_desc_id, <pid, cid>> // <regst_desc_id, <pid, cid>>
HashMap<int64_t, std::pair<int64_t, int32_t>>* previous_pid_cid_; HashMap<int64_t, std::pair<int64_t, int32_t>>* previous_pid_cid_;
int8_t readable_regst_cnt_;
ColIdOrder col_id_order_; ColIdOrder col_id_order_;
bool is_eord_; bool is_eord_;
}; };
......
...@@ -14,11 +14,7 @@ void MdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { ...@@ -14,11 +14,7 @@ void MdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
related_save_model_actor_id_ = task_proto.related_save_model_task_id(); related_save_model_actor_id_ = task_proto.related_save_model_task_id();
related_init_model_actor_id_ = task_proto.related_init_model_task_id(); related_init_model_actor_id_ = task_proto.related_init_model_task_id();
pre_model_regst_ = nullptr; pre_model_regst_ = nullptr;
for (const auto& kv : task_proto.consumed_regst_desc_id()) { readable_regst_mgr_.Init(task_proto);
CHECK(
model_diff_acc_regsts_.emplace(kv.second, std::queue<Regst*>()).second);
}
readable_model_diff_acc_cnt_ = 0;
OF_SET_MSG_HANDLER(&MdUpdtCompActor::HandlerInitModelAndModelTmp); OF_SET_MSG_HANDLER(&MdUpdtCompActor::HandlerInitModelAndModelTmp);
} }
...@@ -71,9 +67,7 @@ int MdUpdtCompActor::HandlerNormal(const ActorMsg& actor_msg) { ...@@ -71,9 +67,7 @@ int MdUpdtCompActor::HandlerNormal(const ActorMsg& actor_msg) {
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
Regst* regst = actor_msg.regst(); Regst* regst = actor_msg.regst();
if (TryUpdtStateAsProducedRegst(regst) != 0) { if (TryUpdtStateAsProducedRegst(regst) != 0) {
auto it = model_diff_acc_regsts_.find(regst->regst_desc_id()); readable_regst_mgr_.Push(regst);
if (it->second.empty()) { readable_model_diff_acc_cnt_ += 1; }
it->second.push(regst);
} }
ActUntilFail(); ActUntilFail();
} else { } else {
...@@ -93,16 +87,12 @@ void MdUpdtCompActor::Act() { ...@@ -93,16 +87,12 @@ void MdUpdtCompActor::Act() {
AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* { AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id); Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) { if (regst == nullptr) {
return model_diff_acc_regsts_.at(regst_desc_id).front(); return readable_regst_mgr_.GetCurReadable(regst_desc_id);
} else { } else {
return regst; return regst;
} }
}); });
for (auto& kv : model_diff_acc_regsts_) { readable_regst_mgr_.ReturnToProducerAndPopCurReadable(this);
AsyncSendRegstMsgToProducer(kv.second.front());
kv.second.pop();
if (kv.second.empty()) { readable_model_diff_acc_cnt_ -= 1; }
}
const JobDesc* job_desc = JobDesc::Singleton(); const JobDesc* job_desc = JobDesc::Singleton();
auto RegstPreProcess = [&](Regst* regst) { return regst == cur_model_regst; }; auto RegstPreProcess = [&](Regst* regst) { return regst == cur_model_regst; };
if (next_model_version_id_ == job_desc->TotalBatchNum()) { if (next_model_version_id_ == job_desc->TotalBatchNum()) {
...@@ -122,11 +112,11 @@ void MdUpdtCompActor::Act() { ...@@ -122,11 +112,11 @@ void MdUpdtCompActor::Act() {
} }
bool MdUpdtCompActor::IsReadReady() { bool MdUpdtCompActor::IsReadReady() {
return readable_model_diff_acc_cnt_ == model_diff_acc_regsts_.size(); return readable_regst_mgr_.IsReadReady();
} }
bool MdUpdtCompActor::IsReadAlwaysUnReadyFromNow() { bool MdUpdtCompActor::IsReadAlwaysUnReadyFromNow() {
return is_model_diff_acc_eord_ && readable_model_diff_acc_cnt_ == 0; return is_model_diff_acc_eord_ && readable_regst_mgr_.IsEmpty();
} }
bool MdUpdtCompActor::IsWriteReady() { bool MdUpdtCompActor::IsWriteReady() {
...@@ -134,14 +124,12 @@ bool MdUpdtCompActor::IsWriteReady() { ...@@ -134,14 +124,12 @@ bool MdUpdtCompActor::IsWriteReady() {
} }
void MdUpdtCompActor::AsyncReturnAllReadableRegst() { void MdUpdtCompActor::AsyncReturnAllReadableRegst() {
CHECK_EQ(0, readable_model_diff_acc_cnt_); CHECK(readable_regst_mgr_.IsEmpty());
} }
void MdUpdtCompActor::ForEachCurReadableRegst( void MdUpdtCompActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) { std::function<void(const Regst*)> func) {
for (const auto& pair : model_diff_acc_regsts_) { readable_regst_mgr_.ForEachCurReadableRegst(func);
handler(pair.second.front());
}
} }
REGISTER_ACTOR(TaskType::kMdUpdt, MdUpdtCompActor); REGISTER_ACTOR(TaskType::kMdUpdt, MdUpdtCompActor);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMPUTE_ACTOR_H_ #define ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/compute_actor.h" #include "oneflow/core/actor/compute_actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow { namespace oneflow {
...@@ -32,8 +33,7 @@ class MdUpdtCompActor final : public CompActor { ...@@ -32,8 +33,7 @@ class MdUpdtCompActor final : public CompActor {
int64_t model_tmp_regst_desc_id_; int64_t model_tmp_regst_desc_id_;
int8_t init_remaining_cnt_; int8_t init_remaining_cnt_;
bool is_model_diff_acc_eord_; bool is_model_diff_acc_eord_;
int64_t readable_model_diff_acc_cnt_; NaiveReadableRegstMgr readable_regst_mgr_;
HashMap<int64_t, std::queue<Regst*>> model_diff_acc_regsts_;
int64_t next_model_version_id_; int64_t next_model_version_id_;
int64_t related_save_model_actor_id_; int64_t related_save_model_actor_id_;
int64_t related_init_model_actor_id_; int64_t related_init_model_actor_id_;
......
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow {
void NaiveReadableRegstMgr::Init(const TaskProto& task_proto) {
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
readable_regst_[pair.second] = {};
}
readable_regst_cnt_ = 0;
}
void NaiveReadableRegstMgr::Push(Regst* regst) {
std::queue<Regst*>& rq = readable_regst_.at(regst->regst_desc_id());
if (rq.empty()) { readable_regst_cnt_ += 1; }
rq.push(regst);
}
void NaiveReadableRegstMgr::ReturnToProducerAndPopCurReadable(
Actor* actor, std::function<bool(Regst*)> IsAllowed) {
for (auto& pair : readable_regst_) {
CHECK_EQ(pair.second.empty(), false);
if (IsAllowed(pair.second.front()) == false) { continue; }
actor->AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
if (pair.second.empty()) { readable_regst_cnt_ -= 1; }
}
}
void NaiveReadableRegstMgr::ReturnToProducerAndPopCurReadable(Actor* actor) {
ReturnToProducerAndPopCurReadable(actor, [](Regst*) { return true; });
}
Regst* NaiveReadableRegstMgr::GetCurReadable(int64_t regst_desc_id) {
auto it = readable_regst_.find(regst_desc_id);
if (it != readable_regst_.end() && it->second.empty() == false) {
return it->second.front();
} else {
return nullptr;
}
}
void NaiveReadableRegstMgr::ForEachCurReadableRegst(
std::function<void(Regst*)> func) {
for (const auto& pair : readable_regst_) {
if (pair.second.empty() == false) { func(pair.second.front()); }
}
}
bool NaiveReadableRegstMgr::IsReadReady() {
return readable_regst_.size() == readable_regst_cnt_;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#define ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#include "oneflow/core/actor/actor.h"
namespace oneflow {
class NaiveReadableRegstMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(NaiveReadableRegstMgr);
NaiveReadableRegstMgr() : readable_regst_cnt_(0) {}
~NaiveReadableRegstMgr() = default;
void Init(const TaskProto& task_proto);
void Push(Regst* regst);
void ReturnToProducerAndPopCurReadable(Actor* actor,
std::function<bool(Regst*)> IsAllowed);
void ReturnToProducerAndPopCurReadable(Actor* actor);
Regst* GetCurReadable(int64_t regst_desc_id);
Regst* GetFirstCurReadable() {
return readable_regst_.begin()->second.front();
}
void ForEachCurReadableRegst(std::function<void(Regst*)> func);
bool IsReadReady();
bool IsEmpty() { return readable_regst_cnt_ == 0; }
private:
HashMap<int64_t, std::queue<Regst*>> readable_regst_; // regst_desc_id
size_t readable_regst_cnt_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册