提交 09761973 编写于 作者: N Niu Chong 提交者: Jinhui Yuan

fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g....

fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g. forward_model_regst (#1274)

* fix(normal_model_update_compute_actor): fix send forward_model_regst_ to consumer

* fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g. forward_model_regst


Former-commit-id: 139c2241
上级 8626f4c2
......@@ -415,6 +415,17 @@ void Actor::HandleProducedNaiveDataRegstToConsumer() {
HandleProducedNaiveDataRegstToConsumer([](Regst*) { return true; });
}
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst) {
AsyncSendRegstMsgToConsumer(regst, [](int64_t) { return true; });
}
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor) {
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, IsAllowedActor);
if (real_consumer_cnt > 0) {
CHECK_EQ(0, naive_produced_rs_.TryPopFrontRegst(regst->regst_desc_id()));
}
}
void Actor::HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst) {
std::vector<int64_t> regst_desc_ids;
naive_consumed_rs_.ForEachFrontRegst([&](Regst* regst) {
......
......@@ -89,12 +89,13 @@ class Actor {
// Util For Derived Actor to Send Msg
void AsyncSendMsg(const ActorMsg&);
int64_t HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer();
void AsyncSendRegstMsgToConsumer(Regst* regst);
void AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
void HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst);
void AsyncSendRegstMsgToProducer(Regst*);
......@@ -171,6 +172,7 @@ class Actor {
virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}
void AsyncSendConsumedCtrlRegstMsgToProducer();
void AsyncSendProducedCtrlRegstMsgToConsumer();
int64_t HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
virtual void AsyncReturnAllCustomizedReadableRegst() {}
int64_t actor_id_;
......
......@@ -206,7 +206,7 @@ void NormalForwardCompActor::SendMsgToForwardModelSaveActor(int64_t batch_id) {
Regst* fw_model_regst = GetNaiveCurWriteable(forward_model_regst_desc_id_);
CHECK(fw_model_regst);
fw_model_regst->set_model_version_id(batch_id);
HandleRegstToConsumer(fw_model_regst, [](int64_t) { return true; });
AsyncSendRegstMsgToConsumer(fw_model_regst, [](int64_t) { return true; });
}
void NormalForwardCompActor::SendConstBufInitMsgToBwActor() {
......
......@@ -78,8 +78,7 @@ void NormalMdUpdtCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related);
});
if (need_save_model && forward_model_regst_ != nullptr) {
HandleProducedNaiveDataRegstToConsumer(
[&](Regst* regst) { return regst == forward_model_regst_; });
AsyncSendRegstMsgToConsumer(forward_model_regst_);
}
next_model_version_id_ += 1;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册