diff --git a/oneflow/core/actor/copy_actor.cpp b/oneflow/core/actor/copy_actor.cpp index 0215c2a61b6d6dd040480f914f2f2e984d050064..8eb45f6ca98adbd22e2dc6fffd7fa9a5458c8466 100644 --- a/oneflow/core/actor/copy_actor.cpp +++ b/oneflow/core/actor/copy_actor.cpp @@ -1,4 +1,5 @@ #include "oneflow/core/actor/copy_actor.h" +#include "oneflow/core/common/util.h" #include "oneflow/core/register/local_register_warpper.h" namespace oneflow { @@ -13,6 +14,8 @@ void CopyActor::ProcessMsgAndWardKernel(const ActorMsg& msg, waiting_in_regst_.push(std::move(msg.regst_warpper())); } if (!waiting_in_regst_.empty() && IsWriteReady()) { + uint64_t piece_id = expected_piece_id(); + CHECK_EQ(waiting_in_regst.front()->piece_id(), piece_id); WardKernel(kernel_ctx, [this](uint64_t regst_desc_id) -> std::shared_ptr { Regst* regst = GetCurWriteableRegst(regst_desc_id); if (regst == nullptr) { @@ -22,6 +25,9 @@ void CopyActor::ProcessMsgAndWardKernel(const ActorMsg& msg, return std::make_shared (regst); } }); + ForEachCurWriteableRegst([piece_id](Regst* regst) { + regst->set_piece_id(piece_id); + }); CurWriteDone(); std::shared_ptr regst = waiting_in_regst_.front(); ActorMsgBus::Singleton().SendMsg(ActorMsg::BuildMsgForRegstWriter( diff --git a/oneflow/core/actor/copy_actor.h b/oneflow/core/actor/copy_actor.h index ab963bcac5d661009d674aaa8053bf00f40b69bf..c9dc4c5428710c2bfc954001050aec5c1b0696a9 100644 --- a/oneflow/core/actor/copy_actor.h +++ b/oneflow/core/actor/copy_actor.h @@ -19,6 +19,7 @@ protected: private: std::queue> waiting_in_regst_; + uint64_t waiting_piece_id_; };