提交 09761973 编写于 作者: N Niu Chong 提交者: Jinhui Yuan

fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g....

fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g. forward_model_regst (#1274)

* fix(normal_model_update_compute_actor): fix send forward_model_regst_ to consumer

* fix: add AsyncSednRegstMsgToConsumer() for send single produced regst, e.g. forward_model_regst


Former-commit-id: 139c2241
上级 8626f4c2
...@@ -415,6 +415,17 @@ void Actor::HandleProducedNaiveDataRegstToConsumer() { ...@@ -415,6 +415,17 @@ void Actor::HandleProducedNaiveDataRegstToConsumer() {
HandleProducedNaiveDataRegstToConsumer([](Regst*) { return true; }); HandleProducedNaiveDataRegstToConsumer([](Regst*) { return true; });
} }
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst) {
AsyncSendRegstMsgToConsumer(regst, [](int64_t) { return true; });
}
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor) {
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, IsAllowedActor);
if (real_consumer_cnt > 0) {
CHECK_EQ(0, naive_produced_rs_.TryPopFrontRegst(regst->regst_desc_id()));
}
}
void Actor::HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst) { void Actor::HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst) {
std::vector<int64_t> regst_desc_ids; std::vector<int64_t> regst_desc_ids;
naive_consumed_rs_.ForEachFrontRegst([&](Regst* regst) { naive_consumed_rs_.ForEachFrontRegst([&](Regst* regst) {
......
...@@ -89,12 +89,13 @@ class Actor { ...@@ -89,12 +89,13 @@ class Actor {
// Util For Derived Actor to Send Msg // Util For Derived Actor to Send Msg
void AsyncSendMsg(const ActorMsg&); void AsyncSendMsg(const ActorMsg&);
int64_t HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess, void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor); std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess); void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor); void HandleProducedNaiveDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(); void HandleProducedNaiveDataRegstToConsumer();
void AsyncSendRegstMsgToConsumer(Regst* regst);
void AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
void HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst); void HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst);
void AsyncSendRegstMsgToProducer(Regst*); void AsyncSendRegstMsgToProducer(Regst*);
...@@ -171,6 +172,7 @@ class Actor { ...@@ -171,6 +172,7 @@ class Actor {
virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {} virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}
void AsyncSendConsumedCtrlRegstMsgToProducer(); void AsyncSendConsumedCtrlRegstMsgToProducer();
void AsyncSendProducedCtrlRegstMsgToConsumer(); void AsyncSendProducedCtrlRegstMsgToConsumer();
int64_t HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
virtual void AsyncReturnAllCustomizedReadableRegst() {} virtual void AsyncReturnAllCustomizedReadableRegst() {}
int64_t actor_id_; int64_t actor_id_;
......
...@@ -206,7 +206,7 @@ void NormalForwardCompActor::SendMsgToForwardModelSaveActor(int64_t batch_id) { ...@@ -206,7 +206,7 @@ void NormalForwardCompActor::SendMsgToForwardModelSaveActor(int64_t batch_id) {
Regst* fw_model_regst = GetNaiveCurWriteable(forward_model_regst_desc_id_); Regst* fw_model_regst = GetNaiveCurWriteable(forward_model_regst_desc_id_);
CHECK(fw_model_regst); CHECK(fw_model_regst);
fw_model_regst->set_model_version_id(batch_id); fw_model_regst->set_model_version_id(batch_id);
HandleRegstToConsumer(fw_model_regst, [](int64_t) { return true; }); AsyncSendRegstMsgToConsumer(fw_model_regst, [](int64_t) { return true; });
} }
void NormalForwardCompActor::SendConstBufInitMsgToBwActor() { void NormalForwardCompActor::SendConstBufInitMsgToBwActor() {
......
...@@ -78,8 +78,7 @@ void NormalMdUpdtCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { ...@@ -78,8 +78,7 @@ void NormalMdUpdtCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related); return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related);
}); });
if (need_save_model && forward_model_regst_ != nullptr) { if (need_save_model && forward_model_regst_ != nullptr) {
HandleProducedNaiveDataRegstToConsumer( AsyncSendRegstMsgToConsumer(forward_model_regst_);
[&](Regst* regst) { return regst == forward_model_regst_; });
} }
next_model_version_id_ += 1; next_model_version_id_ += 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册