提交 74acbc55 编写于 作者: J Jingwu Chen 提交者: Liu Guo

FW and BP (#190)

* total_piece_num

* fw

* fix for review

* code style

* fix bug

* fix bug for data loader

* kStart in FW for data loader

* fix bug

* bp

* bp && fw

* unexpected_run

* boxing state

* fix bug

* ProcessEord & Boxing fixed

* ProcessEord in Actor and fix boxing

* remove useless override

* total_reading_cnt_

* use mut

* redefine fw

* redefine bp

* remove useless variable
上级 b80bacb8
......@@ -55,6 +55,7 @@ void Actor::ProcessEord() {
} else {
OF_SET_MSG_HANDLE(&Actor::HandleWaitUntilReadingCntEqualZero);
}
AsyncSendEORDMsgForAllProducedRegstDesc();
} else {
OF_SET_MSG_HANDLE(&Actor::HandleWaitUntilNoReadableRegst);
}
......
......@@ -40,7 +40,7 @@ class Actor {
std::unique_ptr<DeviceCtx>& mut_device_ctx() { return device_ctx_; }
KernelCtx GenDefaultKernelCtx() const;
int& mut_num_of_not_eord() { return num_of_not_eord_; }
void set_num_of_not_eord(int val) { num_of_not_eord_ = val; }
int& mut_num_of_read_empty() { return num_of_read_empty_; }
// Msg Handle
......
......@@ -8,7 +8,7 @@ void BoxingActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
int num_of_subscribed_regsts = task_proto.subscribed_regst_desc_id().size();
mut_num_of_not_eord() = num_of_subscribed_regsts;
set_num_of_not_eord(num_of_subscribed_regsts);
mut_num_of_read_empty() = num_of_subscribed_regsts;
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
......
......@@ -12,10 +12,10 @@ void BpDataCompActor::Init(const TaskProto& task_proto,
activation_regst_desc_id_ = RegstDescId4Name("activation");
data_tmp_regst_desc_id_ = RegstDescId4Name("data_tmp");
expected_model_version_id_ = 0;
num_of_read_empty_ =
mut_num_of_read_empty() =
2 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1)
+ (activation_regst_desc_id_ != -1) + (data_tmp_regst_desc_id_ != -1);
num_of_not_eord_ = num_of_read_empty_;
set_num_of_not_eord(mut_num_of_read_empty());
if (thread_ctx.cpu_stream) {
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
} else {
......@@ -27,7 +27,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto,
}
bool BpDataCompActor::IsReadReady() {
if (num_of_read_empty_ || piece_model_id_.empty()) { return false; }
if (mut_num_of_read_empty() || piece_model_id_.empty()) { return false; }
if (model_regst_desc_id_ != -1) {
CHECK_GE(piece_model_id_.front().second, 0);
while (read_regst_.at(model_regst_desc_id_).front()->model_version_id()
......@@ -36,17 +36,31 @@ bool BpDataCompActor::IsReadReady() {
AsyncSendRegstMsgToProducer(read_regst_.at(model_regst_desc_id_).front());
read_regst_.at(model_regst_desc_id_).pop();
}
num_of_read_empty_ += read_regst_.at(model_regst_desc_id_).empty();
mut_num_of_read_empty() += read_regst_.at(model_regst_desc_id_).empty();
}
return !mut_num_of_read_empty();
}
void BpDataCompActor::AsyncSendMsgToModelAndModelTmpProducer() {
while (model_regst_desc_id_ != -1
&& !read_regst_.at(model_regst_desc_id_).empty()) {
AsyncSendRegstMsgToProducer(read_regst_.at(model_regst_desc_id_).front());
read_regst_.at(model_regst_desc_id_).pop();
}
if (model_tmp_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(
read_regst_.at(model_tmp_regst_desc_id_).front());
read_regst_.at(model_tmp_regst_desc_id_).pop();
}
return !num_of_read_empty_;
}
int BpDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_not_eord_ -= 1;
if (!num_of_not_eord_) {
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
if (msg_handle() == &BpDataCompActor::HandleWaitUntilReadingCntEqualZero
|| msg_handle() == nullptr) {
AsyncSendMsgToModelAndModelTmpProducer();
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr())
......@@ -59,42 +73,26 @@ int BpDataCompActor::HandleNormal(const ActorMsg& msg) {
} else {
// do nothing
}
num_of_read_empty_ -= read_regst_[regst_wp->regst_desc_id()].empty();
mut_num_of_read_empty() -= read_regst_[regst_wp->regst_desc_id()].empty();
read_regst_.at(regst_wp->regst_desc_id()).push(regst_wp);
}
ActUntilFail();
} else if (msg.msg_type() == ActorMsgType::kPieceModelIdMsg) {
piece_model_id_.emplace(msg.piece_id(), msg.model_version_id());
} else {
UNEXPECTED_RUN();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int BpDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr()),
0);
ActUntilFail();
if (piece_model_id_.empty()) {
while (model_regst_desc_id_ != -1
&& !read_regst_.at(model_regst_desc_id_).empty()) {
AsyncSendRegstMsgToProducer(read_regst_.at(model_regst_desc_id_).front());
read_regst_.at(model_regst_desc_id_).pop();
}
if (model_tmp_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(
read_regst_.at(model_tmp_regst_desc_id_).front());
read_regst_.at(model_tmp_regst_desc_id_).pop();
}
if (mut_num_of_read_empty()) {
AsyncSendMsgToModelAndModelTmpProducer();
AsyncSendEORDMsgForAllProducedRegstDesc();
num_of_read_empty_ = 6;
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -126,7 +124,7 @@ void BpDataCompActor::Act() {
&& pair.first != model_tmp_regst_desc_id_) {
AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
num_of_read_empty_ += pair.second.empty();
mut_num_of_read_empty() += pair.second.empty();
}
}
}
......
......@@ -19,10 +19,9 @@ class BpDataCompActor final : public Actor {
bool IsReadReady() override;
void Act() override;
void AsyncSendMsgToModelAndModelTmpProducer();
CudaStreamHandle cuda_handle_;
int num_of_read_empty_;
int num_of_not_eord_;
int64_t expected_model_version_id_;
int64_t model_regst_desc_id_;
int64_t model_tmp_regst_desc_id_;
......
......@@ -25,8 +25,9 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
kernel_ctx_.other = reinterpret_cast<void*>(parallel_id());
OF_SET_MSG_HANDLE(&FwDataCompActor::WaitToStart);
} else {
num_of_not_eord_ =
1 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1);
set_num_of_not_eord(1 + (model_regst_desc_id_ != -1)
+ (model_tmp_regst_desc_id_ != -1));
mut_num_of_read_empty() = 1; // only consider "in"regst
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleNormal);
}
bp_actor_id_ = IDMgr::Singleton()->ActorId4TaskId(task_proto.bp_task_id());
......@@ -58,12 +59,24 @@ int FwDataCompActor::WaitToStart(const ActorMsg& msg) {
return 0;
}
void FwDataCompActor::AsyncSendMsgToModelAndModelTmpProducer() {
if (model_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(model_regst_);
model_regst_ = nullptr;
}
if (model_tmp_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(model_tmp_regst_);
model_tmp_regst_ = nullptr;
}
}
int FwDataCompActor::HandleNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
num_of_not_eord_ -= 1;
if (!num_of_not_eord_) {
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilNoReadableRegst);
ProcessEord();
if (msg_handle() == &FwDataCompActor::HandleWaitUntilReadingCntEqualZero
|| msg_handle() == nullptr) {
AsyncSendMsgToModelAndModelTmpProducer();
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_wrapper()->regst_raw_ptr())
......@@ -80,12 +93,13 @@ int FwDataCompActor::HandleNormal(const ActorMsg& msg) {
readable_regst_[model_regst_desc_id_] = regst_wp;
expected_model_version_id_ += 1;
} else {
mut_num_of_read_empty() -= in_.empty();
in_.push(regst_wp);
}
}
ActUntilFail();
}
ActUntilFail();
return 0;
return msg_handle() == nullptr;
}
int FwDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
......@@ -95,22 +109,9 @@ int FwDataCompActor::HandleWaitUntilNoReadableRegst(const ActorMsg& msg) {
int total_piece_num = JobDesc::Singleton()->total_piece_num();
if ((in_desc_id_ != -1 && in_.empty())
|| expected_piece_id() == total_piece_num) {
if (model_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(model_regst_);
model_regst_ = nullptr;
}
if (model_tmp_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(model_tmp_regst_);
model_tmp_regst_ = nullptr;
}
AsyncSendMsgToModelAndModelTmpProducer();
AsyncSendEORDMsgForAllProducedRegstDesc();
if (total_reading_cnt() == 0) {
OF_SET_MSG_HANDLE(nullptr);
return 1;
} else {
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilReadingCntEqualZero);
return 0;
}
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleWaitUntilReadingCntEqualZero);
}
return 0;
}
......@@ -140,6 +141,7 @@ void FwDataCompActor::Act() {
if (!in_.empty()) {
AsyncSendRegstMsgToProducer(in_.front());
in_.pop();
mut_num_of_read_empty() += in_.empty();
}
if (bp_actor_id_ != -1) {
ActorMsg msg;
......
......@@ -20,9 +20,9 @@ class FwDataCompActor final : public CompActor {
bool IsReadReady() override;
void Act() override;
void AsyncSendMsgToModelAndModelTmpProducer();
CudaStreamHandle cuda_handle_;
int num_of_not_eord_;
int64_t expected_model_version_id_;
int64_t in_desc_id_;
int64_t model_regst_desc_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册