diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index 9564b73a21d63554621f1641f45b5bb44312dca7..376c040a06d86261b57e295ed8c68e3db84f69fd 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -41,6 +41,7 @@ void Actor::Init(const TaskProto& task_proto) { produced_regst2reading_cnt_[regst.get()] = 0; } writeable_produced_regst_desc_num_ = writeable_produced_regst_.size(); + total_reading_cnt_ = 0; } void Actor::AsyncWardKernelAndSendMsgToRegstReader( @@ -64,6 +65,7 @@ void Actor::AsyncWardKernelAndSendMsgToRegstReader( } }); produced_regst2reading_cnt_.at(regst) = regst->subscribers_actor_id().size(); + total_reading_cnt_ += regst->subscribers_actor_id().size(); if (!regst->subscribers_actor_id().empty()) { pair.second.pop(); } if (pair.second.empty()) { writeable_produced_regst_desc_num_ -= 1; } } @@ -74,6 +76,7 @@ int Actor::TryUpdtStateAsFromRegstReader(Regst* regst) { if (reading_cnt_it == produced_regst2reading_cnt_.end()) { return -1; } CHECK_GE(reading_cnt_it->second, 1); reading_cnt_it->second -= 1; + total_reading_cnt_ -= 1; if (reading_cnt_it->second != 0) { return 0; } auto writeable_it = writeable_produced_regst_.find(regst->regst_desc_id()); if (writeable_it == writeable_produced_regst_.end()) { return 0; } diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index 3cdd55ddf16c9e6dcf07240d15f9e12879e9e2d4..2814002a9467d07549e9c100f2096621fa30529a 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -54,6 +54,7 @@ class Actor { if (!it->second.empty()) { writeable_produced_regst_desc_num_ -= 1; } writeable_produced_regst_.erase(it); } + int64_t total_reading_cnt() const { return total_reading_cnt_; } private: uint64_t actor_id_; @@ -67,6 +68,7 @@ class Actor { HashMap> writeable_produced_regst_; // uint64_t writeable_produced_regst_desc_num_; HashMap produced_regst2reading_cnt_; + int64_t total_reading_cnt_; };