diff --git a/oneflow/core/actor/accumulate_compute_actor.cpp b/oneflow/core/actor/accumulate_compute_actor.cpp index 4a9e940ca10ddf1dc337ec337c0948ddea8c4cff..8b48fb5d23e3b1beea2ab1fda4a432c76544fb9f 100644 --- a/oneflow/core/actor/accumulate_compute_actor.cpp +++ b/oneflow/core/actor/accumulate_compute_actor.cpp @@ -25,7 +25,9 @@ void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt, next_piece_id_ = 0; } -int64_t AccumulateCompActor::ActNumForEachOutput() const { return max_acc_cnt_; } +int64_t AccumulateCompActor::ActNumForEachOutput(int64_t regst_desc_id) const { + return regst_desc_id == GetSoleProducedDataRegstDescId() ? max_acc_cnt_ : 1; +} void AccumulateCompActor::Act() { Regst* in_regst = GetNaiveSoleCurReadable(); diff --git a/oneflow/core/actor/accumulate_compute_actor.h b/oneflow/core/actor/accumulate_compute_actor.h index 4b1e265930461e3f4ea9de65116e4743c9f21489..3ee8ab77f042db09822e10b9ad31a2a746718ba4 100644 --- a/oneflow/core/actor/accumulate_compute_actor.h +++ b/oneflow/core/actor/accumulate_compute_actor.h @@ -13,7 +13,7 @@ class AccumulateCompActor : public CompActor { protected: void Init(const TaskProto&, int32_t max_acc_cnt, ColIdOrder order); - int64_t ActNumForEachOutput() const override; + int64_t ActNumForEachOutput(int64_t regst_desc_id) const override; private: void Act() override; diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index d523ea84fadd4f932ca6b5ad13b921c596aadddf..77de924fae3301806f25cbd76611d8921b619d93 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -417,6 +417,11 @@ Regst* Actor::GetSoleProducedRegst(int64_t regst_desc_id) { return it->second.front().get(); } +int64_t Actor::GetSoleProducedDataRegstDescId() const { + CHECK_EQ(produced_data_regsts_.size(), 1); + return produced_data_regsts_.begin()->first; +} + bool Actor::IsReadReady() { return naive_readable_data_regst_.size() == naive_readable_data_regst_cnt_ && IsCustomizedReadReady(); @@ -443,7 +448,7 @@ int Actor::ProcessWriteableCtrlRegstMsg(const ActorMsg& msg) { if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) { CHECK_EQ(regst->act_id(), expected_act_id); } - expected_act_id = regst->act_id() + ActNumForEachOutput(); + expected_act_id = regst->act_id() + ActNumForEachOutput(regst->regst_desc_id()); writeable_it->second.push_back(regst); return 0; } @@ -524,7 +529,7 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) { if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) { CHECK_EQ(regst->act_id(), expected_act_id); } - expected_act_id = regst->act_id() + ActNumForEachOutput(); + expected_act_id = regst->act_id() + ActNumForEachOutput(regst->regst_desc_id()); writeable_it->second.push_back(regst); return 0; } diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index 6132c8351e91a7684ee65e2dd25733a9805a4fe6..150e94f134d49cffb8a858d9968d93e5395a5b8d 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -85,7 +85,7 @@ class Actor { virtual bool IsCustomizedReadAlwaysUnReadyFromNow() { return false; } bool IsWriteReady(); virtual void AsyncReturnAllCustomizedReadableRegst() {} - virtual int64_t ActNumForEachOutput() const { return 1; } + virtual int64_t ActNumForEachOutput(int64_t regst_desc_id) const { return 1; } virtual bool CheckOutputActId(int64_t regst_desc_id) const { return true; // TODO(jiyuan): figure out the ActNumForEachOutput of the model regsts to MdSave // area @@ -117,6 +117,7 @@ class Actor { Regst* GetNaiveSoleCurReadable(); Regst* GetNaiveFirstCurReadable(); Regst* GetSoleProducedRegst(int64_t regst_desc_id); + int64_t GetSoleProducedDataRegstDescId() const; void DecreaseActualWriteableProducedDataRegstDescNum(int64_t amount) { actual_writeable_produced_data_regst_desc_num_ -= amount; diff --git a/oneflow/core/actor/input_wise_compute_actor.cpp b/oneflow/core/actor/input_wise_compute_actor.cpp index cd7b7d9f5a45ebd0e124acecfb60a7aa47212b98..156477f63f822cfef4190e1029309411d21826e4 100644 --- a/oneflow/core/actor/input_wise_compute_actor.cpp +++ b/oneflow/core/actor/input_wise_compute_actor.cpp @@ -29,7 +29,9 @@ void InputWiseCompActor::Init(const TaskProto& task_proto) { OF_SET_MSG_HANDLER(&InputWiseCompActor::HandlerNormal); } -int64_t InputWiseCompActor::ActNumForEachOutput() const { return regst_desc_id2in_bn_id_.size(); } +int64_t InputWiseCompActor::ActNumForEachOutput(int64_t regst_desc_id) const { + return regst_desc_id2in_bn_id_.size(); +} void InputWiseCompActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { Regst* regst = msg.regst(); diff --git a/oneflow/core/actor/input_wise_compute_actor.h b/oneflow/core/actor/input_wise_compute_actor.h index 0f6c5cd5a38f2a311a3be1b207d1753963bc3d82..678d52d2d41953fd06fada430d75cae7e0281397 100644 --- a/oneflow/core/actor/input_wise_compute_actor.h +++ b/oneflow/core/actor/input_wise_compute_actor.h @@ -17,7 +17,7 @@ class InputWiseCompActor : public CompActor { int64_t processed_regst_desc_id_cnt() const { return processed_regst_desc_id_cnt_; } int64_t RegstDescNum() const { return readable_regsts_.size(); } int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); } - int64_t ActNumForEachOutput() const override; + int64_t ActNumForEachOutput(int64_t regst_desc_id) const override; bool EnableInplace() const { return GetDeviceType() == DeviceType::kGPU && Global::Get()->enable_mem_sharing(); }