提交 cc2ee9fc 编写于 作者: W willzhang4a58

naive readable register manager

上级 6f0754f6
......@@ -36,6 +36,8 @@ class Actor {
int64_t actor_id() const { return actor_id_; }
protected:
friend class NaiveReadableRegstMgr;
struct ExecKernel {
std::unique_ptr<const Kernel> kernel;
HashMap<std::string, int64_t> bn_in_op2regst_desc_id;
......
......@@ -4,11 +4,8 @@
namespace oneflow {
void BoxingActor::VirtualActorInit(const TaskProto& task_proto) {
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
readable_regst_[pair.second] = {};
}
readable_regst_mgr_.Init(task_proto);
previous_pid_cid_ = new HashMap<int64_t, std::pair<int64_t, int32_t>>;
readable_regst_cnt_ = 0;
col_id_order_ = ColIdOrder::kUnCertain;
is_eord_ = false;
OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal);
......@@ -48,9 +45,7 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
&& col_id_order_ == ColIdOrder::kUnCertain) {
TrySetColIdOrder(msg.regst());
}
std::queue<Regst*>& rq = readable_regst_.at(msg.regst()->regst_desc_id());
if (rq.empty()) { readable_regst_cnt_ += 1; }
rq.push(msg.regst());
readable_regst_mgr_.Push(msg.regst());
}
ActUntilFail();
} else {
......@@ -60,57 +55,53 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
}
void BoxingActor::Act() {
int64_t piece_id = readable_regst_.begin()->second.front()->piece_id();
AsyncLaunchKernel(GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return readable_regst_.at(regst_desc_id).front();
} else {
return regst;
}
});
int64_t piece_id = readable_regst_mgr_.GetFirstCurReadable()->piece_id();
AsyncLaunchKernel(
GenDefaultKernelCtx(), [this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return readable_regst_mgr_.GetCurReadable(regst_desc_id);
} else {
return regst;
}
});
AsyncSendRegstMsgToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id);
return regst->col_id() <= regst->max_col_id();
});
int32_t cur_max_cid = 0;
int32_t cur_max_maxcid = 0;
for (const auto& pair : readable_regst_) {
cur_max_cid = std::max(cur_max_cid, pair.second.front()->col_id());
cur_max_maxcid =
std::max(cur_max_maxcid, pair.second.front()->max_col_id());
}
for (auto& pair : readable_regst_) {
if (col_id_order_ == ColIdOrder::kAscending) {
if (pair.second.front()->IsMaxCol() && cur_max_cid < cur_max_maxcid) {
continue;
}
} else if (col_id_order_ == ColIdOrder::kDescending) {
if (pair.second.front()->col_id() < cur_max_cid) { continue; }
} else { // do nothing
}
AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
if (pair.second.empty()) { readable_regst_cnt_ -= 1; }
}
readable_regst_mgr_.ForEachCurReadableRegst([&](Regst* regst) {
cur_max_cid = std::max(cur_max_cid, regst->col_id());
cur_max_maxcid = std::max(cur_max_maxcid, regst->max_col_id());
});
readable_regst_mgr_.ReturnToProducerAndPopCurReadable(
this, [&](Regst* regst) {
if (col_id_order_ == ColIdOrder::kAscending) {
if (regst->IsMaxCol() && cur_max_cid < cur_max_maxcid) {
return false;
}
} else if (col_id_order_ == ColIdOrder::kDescending) {
if (regst->col_id() < cur_max_cid) { return false; }
} else { // do nothing
}
return true;
});
}
bool BoxingActor::IsReadReady() {
return readable_regst_.size() == readable_regst_cnt_;
}
bool BoxingActor::IsReadReady() { return readable_regst_mgr_.IsReadReady(); }
bool BoxingActor::IsReadAlwaysUnReadyFromNow() {
return is_eord_ && readable_regst_cnt_ == 0;
return is_eord_ && readable_regst_mgr_.IsEmpty();
}
void BoxingActor::AsyncReturnAllReadableRegst() {
CHECK_EQ(readable_regst_cnt_, 0);
CHECK(readable_regst_mgr_.IsEmpty());
}
void BoxingActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) {
for (const auto& pair : readable_regst_) { handler(pair.second.front()); }
std::function<void(const Regst*)> func) {
readable_regst_mgr_.ForEachCurReadableRegst(func);
}
REGISTER_ACTOR(TaskType::kBoxing, BoxingActor);
......
......@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow {
......@@ -25,11 +26,9 @@ class BoxingActor final : public Actor {
void TrySetColIdOrder(const Regst*);
// <regst_desc_id, regst*>
HashMap<int64_t, std::queue<Regst*>> readable_regst_;
NaiveReadableRegstMgr readable_regst_mgr_;
// <regst_desc_id, <pid, cid>>
HashMap<int64_t, std::pair<int64_t, int32_t>>* previous_pid_cid_;
int8_t readable_regst_cnt_;
ColIdOrder col_id_order_;
bool is_eord_;
};
......
......@@ -14,11 +14,7 @@ void MdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
related_save_model_actor_id_ = task_proto.related_save_model_task_id();
related_init_model_actor_id_ = task_proto.related_init_model_task_id();
pre_model_regst_ = nullptr;
for (const auto& kv : task_proto.consumed_regst_desc_id()) {
CHECK(
model_diff_acc_regsts_.emplace(kv.second, std::queue<Regst*>()).second);
}
readable_model_diff_acc_cnt_ = 0;
readable_regst_mgr_.Init(task_proto);
OF_SET_MSG_HANDLER(&MdUpdtCompActor::HandlerInitModelAndModelTmp);
}
......@@ -71,9 +67,7 @@ int MdUpdtCompActor::HandlerNormal(const ActorMsg& actor_msg) {
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
Regst* regst = actor_msg.regst();
if (TryUpdtStateAsProducedRegst(regst) != 0) {
auto it = model_diff_acc_regsts_.find(regst->regst_desc_id());
if (it->second.empty()) { readable_model_diff_acc_cnt_ += 1; }
it->second.push(regst);
readable_regst_mgr_.Push(regst);
}
ActUntilFail();
} else {
......@@ -93,16 +87,12 @@ void MdUpdtCompActor::Act() {
AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return model_diff_acc_regsts_.at(regst_desc_id).front();
return readable_regst_mgr_.GetCurReadable(regst_desc_id);
} else {
return regst;
}
});
for (auto& kv : model_diff_acc_regsts_) {
AsyncSendRegstMsgToProducer(kv.second.front());
kv.second.pop();
if (kv.second.empty()) { readable_model_diff_acc_cnt_ -= 1; }
}
readable_regst_mgr_.ReturnToProducerAndPopCurReadable(this);
const JobDesc* job_desc = JobDesc::Singleton();
auto RegstPreProcess = [&](Regst* regst) { return regst == cur_model_regst; };
if (next_model_version_id_ == job_desc->TotalBatchNum()) {
......@@ -122,11 +112,11 @@ void MdUpdtCompActor::Act() {
}
bool MdUpdtCompActor::IsReadReady() {
return readable_model_diff_acc_cnt_ == model_diff_acc_regsts_.size();
return readable_regst_mgr_.IsReadReady();
}
bool MdUpdtCompActor::IsReadAlwaysUnReadyFromNow() {
return is_model_diff_acc_eord_ && readable_model_diff_acc_cnt_ == 0;
return is_model_diff_acc_eord_ && readable_regst_mgr_.IsEmpty();
}
bool MdUpdtCompActor::IsWriteReady() {
......@@ -134,14 +124,12 @@ bool MdUpdtCompActor::IsWriteReady() {
}
void MdUpdtCompActor::AsyncReturnAllReadableRegst() {
CHECK_EQ(0, readable_model_diff_acc_cnt_);
CHECK(readable_regst_mgr_.IsEmpty());
}
void MdUpdtCompActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) {
for (const auto& pair : model_diff_acc_regsts_) {
handler(pair.second.front());
}
std::function<void(const Regst*)> func) {
readable_regst_mgr_.ForEachCurReadableRegst(func);
}
REGISTER_ACTOR(TaskType::kMdUpdt, MdUpdtCompActor);
......
......@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/compute_actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow {
......@@ -32,8 +33,7 @@ class MdUpdtCompActor final : public CompActor {
int64_t model_tmp_regst_desc_id_;
int8_t init_remaining_cnt_;
bool is_model_diff_acc_eord_;
int64_t readable_model_diff_acc_cnt_;
HashMap<int64_t, std::queue<Regst*>> model_diff_acc_regsts_;
NaiveReadableRegstMgr readable_regst_mgr_;
int64_t next_model_version_id_;
int64_t related_save_model_actor_id_;
int64_t related_init_model_actor_id_;
......
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace oneflow {
void NaiveReadableRegstMgr::Init(const TaskProto& task_proto) {
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
readable_regst_[pair.second] = {};
}
readable_regst_cnt_ = 0;
}
// Appends `regst` to the queue registered for its regst_desc_id. When that
// queue goes from empty to non-empty, the non-empty-queue counter is bumped.
// The id must already have a queue (see Init); lookup goes through at().
// NOTE(review): presumably at() throws for an unknown id if HashMap wraps
// std::unordered_map — confirm against the HashMap definition.
void NaiveReadableRegstMgr::Push(Regst* regst) {
  auto& pending = readable_regst_.at(regst->regst_desc_id());
  const bool was_empty = pending.empty();
  pending.push(regst);
  if (was_empty) { ++readable_regst_cnt_; }
}
// Walks every tracked queue and, for each head register accepted by
// `IsAllowed`, asks `actor` to send it back to its producer and pops it.
// Heads rejected by `IsAllowed` stay at the front of their queue.
// Every queue must be non-empty on entry; this is enforced with CHECK_EQ.
// The non-empty-queue counter is decremented whenever a pop drains a queue.
void NaiveReadableRegstMgr::ReturnToProducerAndPopCurReadable(
    Actor* actor, std::function<bool(Regst*)> IsAllowed) {
  for (auto& pair : readable_regst_) {
    CHECK_EQ(pair.second.empty(), false);
    std::queue<Regst*>& pending = pair.second;
    Regst* head = pending.front();
    if (!IsAllowed(head)) { continue; }
    actor->AsyncSendRegstMsgToProducer(head);
    pending.pop();
    if (pending.empty()) { --readable_regst_cnt_; }
  }
}
// Convenience overload: returns and pops the head of every queue
// unconditionally by delegating to the predicate overload with a filter that
// accepts every register.
void NaiveReadableRegstMgr::ReturnToProducerAndPopCurReadable(Actor* actor) {
  const auto accept_all = [](Regst*) { return true; };
  ReturnToProducerAndPopCurReadable(actor, accept_all);
}
// Returns the register at the front of the queue for `regst_desc_id`, or
// nullptr when the id is not tracked or its queue is currently empty.
// The register is not removed from the queue.
Regst* NaiveReadableRegstMgr::GetCurReadable(int64_t regst_desc_id) {
  const auto found = readable_regst_.find(regst_desc_id);
  if (found == readable_regst_.end()) { return nullptr; }
  if (found->second.empty()) { return nullptr; }
  return found->second.front();
}
// Invokes `func` once with the head register of every queue that currently
// holds at least one register. Empty queues are skipped; nothing is popped.
void NaiveReadableRegstMgr::ForEachCurReadableRegst(
    std::function<void(Regst*)> func) {
  for (const auto& entry : readable_regst_) {
    const std::queue<Regst*>& pending = entry.second;
    if (pending.empty()) { continue; }
    func(pending.front());
  }
}
// True when every tracked queue holds at least one register, i.e. the number
// of non-empty queues equals the number of tracked queues. With no tracked
// queues at all this is trivially true (0 == 0).
bool NaiveReadableRegstMgr::IsReadReady() {
  const size_t tracked_queue_cnt = readable_regst_.size();
  return tracked_queue_cnt == readable_regst_cnt_;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#define ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#include "oneflow/core/actor/actor.h"
namespace oneflow {
// Bookkeeping helper for an actor's consumed ("readable") registers.
//
// For every consumed regst_desc_id it keeps a FIFO queue of Regst* that have
// been pushed and not yet returned, plus a counter of how many of those queues
// are currently non-empty. The head of each queue is the "current readable"
// register for that descriptor. The manager stores raw Regst* only; it never
// deletes them.
class NaiveReadableRegstMgr final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(NaiveReadableRegstMgr);
  NaiveReadableRegstMgr() : readable_regst_cnt_(0) {}
  ~NaiveReadableRegstMgr() = default;

  // Creates an empty queue for each consumed regst_desc_id of `task_proto`
  // and resets the non-empty-queue counter.
  void Init(const TaskProto& task_proto);

  // Enqueues `regst` on the queue of its regst_desc_id (must be tracked).
  void Push(Regst* regst);

  // For every queue whose head satisfies `IsAllowed`, has `actor` send the
  // head back to its producer and pops it. All queues must be non-empty.
  void ReturnToProducerAndPopCurReadable(Actor* actor,
                                         std::function<bool(Regst*)> IsAllowed);

  // Same as above with a predicate that accepts every register.
  void ReturnToProducerAndPopCurReadable(Actor* actor);

  // Head of the queue for `regst_desc_id`, or nullptr if the id is unknown
  // or its queue is empty.
  Regst* GetCurReadable(int64_t regst_desc_id);

  // Head of the queue stored first in the table (whichever entry begin()
  // yields), or nullptr when no queue is tracked or that queue is empty.
  // The emptiness checks avoid dereferencing end() / calling front() on an
  // empty queue, both of which are undefined behaviour; returning nullptr
  // mirrors GetCurReadable.
  Regst* GetFirstCurReadable() {
    auto first = readable_regst_.begin();
    if (first == readable_regst_.end() || first->second.empty()) {
      return nullptr;
    }
    return first->second.front();
  }

  // Calls `func` with the head of every non-empty queue.
  void ForEachCurReadableRegst(std::function<void(Regst*)> func);

  // True when every tracked queue is non-empty.
  bool IsReadReady();

  // True when no tracked queue holds any register.
  bool IsEmpty() const { return readable_regst_cnt_ == 0; }

 private:
  // Pending registers, keyed by regst_desc_id.
  HashMap<int64_t, std::queue<Regst*>> readable_regst_;
  // Number of entries of readable_regst_ whose queue is non-empty.
  size_t readable_regst_cnt_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册