提交 5b33aa57 编写于 作者: J Jingwu Chen 提交者: Will Zhang

Boxing state (#183)

* total_piece_num

* fw

* fix for review

* code style

* fix bug

* fix bug for data loader

* kStart in FW for data loader

* fix bug

* bp

* bp && fw

* unexpected_run

* boxing state

* fix bug

* ProcessEord & Boxing fixed

* ProcessEord in Actor and fix boxing

* remove useless override

* total_reading_cnt_

* use mut
上级 be62127c
......@@ -46,6 +46,23 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
total_reading_cnt_ = 0;
}
void Actor::ProcessEord() {
num_of_not_eord_ -= 1;
if (!num_of_not_eord_) {
if (num_of_read_empty_) {
if (!total_reading_cnt_) {
OF_SET_MSG_HANDLE(nullptr);
} else {
OF_SET_MSG_HANDLE(&Actor::HandleWaitUntilReadingCntEqualZero);
}
} else {
OF_SET_MSG_HANDLE(&Actor::HandleWaitUntilNoReadableRegst);
}
} else {
// do nothing
}
}
int64_t Actor::RegstDescId4Name(const std::string& name) const {
auto find_it = name2regst_desc_id_.find(name);
if (find_it != name2regst_desc_id_.end()) { return find_it->second; }
......
......@@ -40,8 +40,12 @@ class Actor {
std::unique_ptr<DeviceCtx>& mut_device_ctx() { return device_ctx_; }
KernelCtx GenDefaultKernelCtx() const;
int& mut_num_of_not_eord() { return num_of_not_eord_; }
int& mut_num_of_read_empty() { return num_of_read_empty_; }
// Msg Handle
using MsgHandle = int (Actor::*)(const ActorMsg&);
MsgHandle msg_handle() { return msg_handle_; }
void set_msg_handle(MsgHandle val) { msg_handle_ = val; }
#define OF_SET_MSG_HANDLE(val) \
do { \
......@@ -58,7 +62,7 @@ class Actor {
void ActUntilFail();
virtual void Act() = 0;
virtual bool IsReadReady() = 0;
virtual void ProcessEord();
// Async Do on KernelCtx
void AsyncLaunchKernel(
const KernelCtx&,
......@@ -100,6 +104,8 @@ class Actor {
int64_t writeable_produced_regst_desc_num_;
HashMap<Regst*, int64_t> produced_regst2reading_cnt_;
int64_t total_reading_cnt_;
int num_of_not_eord_;
int num_of_read_empty_;
};
} // namespace oneflow
......
......@@ -7,9 +7,9 @@ namespace oneflow {
void BoxingActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
num_of_subscribed_regsts_ = task_proto.subscribed_regst_desc_id().size();
num_of_read_empty_ = num_of_subscribed_regsts_;
num_of_eord_ = 0;
int num_of_subscribed_regsts = task_proto.subscribed_regst_desc_id().size();
mut_num_of_not_eord() = num_of_subscribed_regsts;
mut_num_of_read_empty() = num_of_subscribed_regsts;
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
OF_SET_MSG_HANDLE(&BoxingActor::HandleNormal);
......@@ -18,37 +18,28 @@ void BoxingActor::Init(const TaskProto& task_proto,
int BoxingActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_eord_ += 1;
if (num_of_eord_ == num_of_subscribed_regsts_) {
OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilNoReadableRegst);
}
ProcessEord();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr())
!= 0) {
std::shared_ptr<RegstWrapper> regst_wp = msg.regst_wrapper();
num_of_read_empty_ -= read_regst_[regst_wp->regst_desc_id()].empty();
mut_num_of_read_empty() -= read_regst_[regst_wp->regst_desc_id()].empty();
read_regst_.at(regst_wp->regst_desc_id()).push(regst_wp);
} else {
// do nothing
}
ActUntilFail();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int BoxingActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()),
0);
ActUntilFail();
if (num_of_read_empty_ == num_of_subscribed_regsts_) {
if (mut_num_of_read_empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&BoxingActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -73,7 +64,7 @@ void BoxingActor::Act() {
for (auto& pair : read_regst_) {
AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
num_of_read_empty_ += pair.second.empty();
mut_num_of_read_empty() += pair.second.empty();
}
}
......
......@@ -17,12 +17,9 @@ class BoxingActor final : public Actor {
int HandleNormal(const ActorMsg&) override;
int HandleWaitUntilNoReadableRegst(const ActorMsg&) override;
bool IsReadReady() override { return !num_of_read_empty_; }
bool IsReadReady() override { return !mut_num_of_read_empty(); }
void Act() override;
int num_of_subscribed_regsts_;
int num_of_read_empty_;
int num_of_eord_;
// <regst_desc_id, queue<regst_wp>>
HashMap<int64_t, std::queue<std::shared_ptr<RegstWrapper>>> read_regst_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册