diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp
index ddab7002312c5cee02a638159f492dcd027b033a..15e0900ead2f1a4f5593dac7efad4a2b24b7a73e 100644
--- a/oneflow/core/actor/actor.cpp
+++ b/oneflow/core/actor/actor.cpp
@@ -451,9 +451,19 @@ void Actor::AsyncSendCtrlRegstMsg() {
   for (auto& pair : consumed_ctrl_regst_) {
     CHECK(!pair.second.empty());
     Regst* regst = pair.second.front();
-    AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
-    pair.second.pop_front();
-    if (pair.second.empty()) { --readable_ctrl_regst_desc_cnt_; }
+    // Number of queued ctrl regsts to hand back to the producer in this call.
+    int32_t returned_regst_num =
+        regst->regst_desc()->regst_desc_type().ctrl_regst_desc().returned_regst_num();
+    CHECK_GE(returned_regst_num, 1);
+    CHECK_GE(pair.second.size(), returned_regst_num);
+    while (returned_regst_num--) {
+      // Re-read the front each time: every popped regst must be the one that is sent back.
+      regst = pair.second.front();
+      AsyncSendMsg(
+          ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
+      pair.second.pop_front();
+    }
+    if (pair.second.empty()) { --readable_ctrl_regst_desc_cnt_; }
   }
   for (auto& pair : writeable_produced_ctrl_regst_) {
     CHECK(!pair.second.empty());
diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp
index a74b2f34af893d1c7a1fc8e46231135cd66d4721..1e2e9334f24d643361f73ea8418c9b39d1e9ba5e 100644
--- a/oneflow/core/graph/task_graph.cpp
+++ b/oneflow/core/graph/task_graph.cpp
@@ -1,4 +1,6 @@
 #include "oneflow/core/graph/task_graph.h"
+#include "oneflow/core/graph/normal_forward_compute_task_node.h"
+#include "oneflow/core/graph/normal_model_update_compute_task_node.h"
 #include "oneflow/core/graph/chain_graph.h"
 #include "oneflow/core/graph/boxing_task_node.h"
 #include "oneflow/core/common/balanced_splitter.h"
@@ -184,7 +186,64 @@ bool TaskGraph::IsEndingTaskType(TaskType type) {
 
 void TaskGraph::AddMutexCtrlEdgeInSameChain() { UNIMPLEMENTED(); }
 
-void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() { UNIMPLEMENTED(); }
+// For each H2D copy task in the data-forward or boundary area: walks from it through
+// forward-type out-edge nodes (not expanding past a forward task whose logical node has a model
+// blob), collects the data-parallel ones with model blobs, and for the first candidate of each
+// run of equal chain ids builds a ctrl regst from the copy task to every MdUpdt in-edge node.
+void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() {
+  for (TaskNode* task_node : ordered_task_nodes_) {
+    auto copy_hd_task_node = dynamic_cast(task_node);
+    if (copy_hd_task_node == nullptr) { continue; }
+    if (copy_hd_task_node->copy_type() != CopyHdOpConf::H2D) { continue; }
+    if (copy_hd_task_node->area_id() != static_cast(kDataForwardArea)
+        && copy_hd_task_node->area_id() != static_cast(kBoundaryArea)) {
+      continue;
+    }
+    std::vector candidate_nodes;
+    auto ForEachNextNode = [&](TaskNode* node,
+                               const std::function& TryPushNodeToQueue) {
+      auto fw_task_node = dynamic_cast(node);
+      if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()) { return; }
+      node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
+        if (IsForwardTaskType(node_on_out_edge->GetTaskType())) {
+          TryPushNodeToQueue(node_on_out_edge);
+        }
+      });
+    };
+    auto HandlerAddCandidate = [&](TaskNode* node) {
+      auto fw_task_node = dynamic_cast(node);
+      if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()
+          && fw_task_node->parallel_ctx()->parallel_num() > 1
+          && fw_task_node->parallel_ctx()->policy() == kDataParallel) {
+        candidate_nodes.push_back(node);
+      }
+    };
+    BfsForEachNode({task_node}, ForEachNextNode, HandlerAddCandidate);
+    std::sort(candidate_nodes.begin(), candidate_nodes.end(),
+              [](const TaskNode* a, const TaskNode* b) {
+                return a->order_in_graph() < b->order_in_graph();
+              });
+    int64_t last_chain_id = -1;
+    for (TaskNode* candidate_node : candidate_nodes) {
+      if (candidate_node->chain_id() != last_chain_id) {
+        last_chain_id = candidate_node->chain_id();
+        candidate_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
+          if (IsMdUpdtTaskType(node_on_in_edge->GetTaskType())) {
+            RegstDesc* ctrl_regst = task_node->BuildCtrlRegstDesc(node_on_in_edge);
+            RegstDesc* copy_out_regst = copy_hd_task_node->GetProducedRegst("copy_out").get();
+            int64_t piece_num_in_batch = Global::Get()->NumOfPiecesInBatch();
+            ctrl_regst->UpdtMinRegstNumIfNeed(copy_out_regst->min_register_num()
+                                              + piece_num_in_batch - 1);
+            CtrlRegstDesc* ctrl_regst_desc =
+                ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc();
+            ctrl_regst_desc->set_reliant_regst_desc_id(copy_out_regst->regst_desc_id());
+            ctrl_regst_desc->set_returned_regst_num(piece_num_in_batch);
+          }
+        });
+      }
+    }
+  }
+}
 
 void TaskGraph::CollectAncestorsForEachNode() {
   std::vector ordered_nodes;
diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp
index 9b6d1eff45af4a3d81a6ae69815f07bd25a4c582..5e8cd13f12c55ac31299ea0e9ab5d08e569f8aa9 100644
--- a/oneflow/core/graph/task_node.cpp
+++ b/oneflow/core/graph/task_node.cpp
@@ -148,12 +148,19 @@ void TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node) {
   }
   const auto& dst_ancestors = dst_node->ancestors();
   if (dst_ancestors.find(this) != dst_ancestors.end()) return;
+  BuildCtrlRegstDesc(dst_node);
+}
+
+// Creates a ctrl regst produced by this node and consumed by dst_node as "in_ctrl", without the
+// checks done in BuildCtrlRegstDescIfNeed; returns a non-owning pointer to the new regst desc.
+RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node) {
   RegstDescTypeProto regst_desc_type;
   regst_desc_type.mutable_ctrl_regst_desc();
   auto regst = NewProducedRegst(false, 1, kMaxRegisterNum, regst_desc_type);
   std::string name = "out_ctrl_" + std::to_string(regst->regst_desc_id());
   CHECK(produced_regsts_.emplace(name, regst).second);
   dst_node->ConsumeRegst("in_ctrl", regst);
+  return regst.get();
 }
 
 void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h
index 2c6f256bc5ac45b60ff371f0661881ece3d99fe4..9b1bf56501adc5abebba228353fddc3ee4f6fa62 100644
--- a/oneflow/core/graph/task_node.h
+++ b/oneflow/core/graph/task_node.h
@@ -75,6 +75,7 @@ class TaskNode : public Node {
   void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
   virtual int64_t MemZoneId121() const;  // TODO: there is bug for reduce task node
   void BuildCtrlRegstDescIfNeed(TaskNode* dst_node);
+  RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node);
 
  protected:
   std::shared_ptr ProduceRegst(const std::string& name, bool enable_mem_sharing);
diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index 337dc3f540cc2068a8867c35d28166d5aca1bf8b..043b64f1994718ac2e9627f1d1b7151ee7dfe92d 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -63,6 +63,7 @@ Plan Compiler::DoCompile() {
   if (Global::Get()->other_conf().use_ordered_allreduce_in_mdupdt()) {
     task_gph->AddCtrlEdgeInReduceStruct();
   }
+  if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); }
   Plan plan;
   task_gph->ForEachNode([&](TaskNode* task_node) {
     if (task_node->IsMeaningLess()) { return; }
diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp
index dd55ea3433e23a18d5518c25bb02e34d8ff39a97..2e553251050634040c01b234da85815430c8d5b5 100644
--- a/oneflow/core/job/improver.cpp
+++ b/oneflow/core/job/improver.cpp
@@ -182,6 +182,14 @@ std::shared_ptr> MakeRegstDescId2RegstDesc(Pla
   return regst_desc_id2regst_desc;
 }
 
+// Returns a getter that maps a regst_desc_id to that regst's register_num in `plan`.
+std::function MakeGetterGetPlanRegstNum(Plan* plan) {
+  auto regst_desc_id2regst_desc = MakeRegstDescId2RegstDesc(plan);
+  return [regst_desc_id2regst_desc](int64_t regst_desc_id) {
+    return regst_desc_id2regst_desc->at(regst_desc_id)->register_num();
+  };
+}
+
 std::function MakeSetterSetPlanRegstNum(Plan* plan) {
   auto regst_desc_id2regst_desc = MakeRegstDescId2RegstDesc(plan);
   return [regst_desc_id2regst_desc](int64_t regst_desc_id, uint64_t num) {
@@ -374,6 +382,26 @@ void ForEachMemSharingCriticalSection(
   }
 }
 
+// For every produced ctrl regst that names a reliant regst, overwrites its register num with
+// (register num of the reliant regst) + NumOfPiecesInBatch - 1.
+void FixReliantCtrlRegstNum(const Plan& plan, const std::function& GetRegstNum,
+                            const std::function& SetRegstNum) {
+  for (const auto& task_proto : plan.task()) {
+    for (const auto& pair : task_proto.produced_regst_desc()) {
+      const RegstDescProto& regst = pair.second;
+      const RegstDescTypeProto& regst_type = regst.regst_desc_type();
+      if (regst_type.has_ctrl_regst_desc()
+          && regst_type.ctrl_regst_desc().has_reliant_regst_desc_id()) {
+        // set ctrl regst num between copyHd and MdUpdt
+        CHECK(task_proto.task_type() == kCopyHd);
+        uint64_t regst_num = GetRegstNum(regst_type.ctrl_regst_desc().reliant_regst_desc_id())
+                             + Global::Get()->NumOfPiecesInBatch() - 1;
+        SetRegstNum(regst.regst_desc_id(), regst_num);
+      }
+    }
+  }
+}
+
 }  // namespace
 
 uint64_t Improver::AvailableMemSize(int64_t machine_id, int64_t memory_zone_id) const {
@@ -509,6 +537,7 @@ Plan Improver::Improve(const AvailableMemDesc& amd, const Plan& naive_plan,
   Plan plan(mem_shared_plan);
   ForEachImprovedRegstNum(act_graph, mem_shared_plan, true, PathDurations4RegstDescId,
                           PathIIScales4RegstDescId, MakeSetterSetPlanRegstNum(&plan));
+  FixReliantCtrlRegstNum(plan, MakeGetterGetPlanRegstNum(&plan), MakeSetterSetPlanRegstNum(&plan));
   return plan;
 }
 
diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto
index c04dbae8f145524542a760aa03b7dcff81f1864a..2879360fce0940eeab8bc0227895eb80fbd10628 100644
--- a/oneflow/core/register/register_desc.proto
+++ b/oneflow/core/register/register_desc.proto
@@ -16,6 +16,8 @@ message DataRegstDesc {
 }
 
 message CtrlRegstDesc {
+  optional int64 reliant_regst_desc_id = 1;
+  optional int32 returned_regst_num = 2 [default = 1];
 }
 
 message RegstDescTypeProto {
diff --git a/oneflow/core/register/runtime_register_desc.cpp b/oneflow/core/register/runtime_register_desc.cpp
index e7c85374b74c0fcf7f5a3f1dd59ba6086fcfc181..756cecea810136606e4140c35d70200d6b2c4a14 100644
--- a/oneflow/core/register/runtime_register_desc.cpp
+++ b/oneflow/core/register/runtime_register_desc.cpp
@@ -11,6 +11,7 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& proto) {
   consumers_actor_id_ = PbRf2StdVec(proto.consumer_task_id());
   register_num_ = proto.register_num();
   mem_case_ = proto.mem_case();
+  regst_desc_type_ = proto.regst_desc_type();
   if (proto.regst_desc_type().has_data_regst_desc()) {
     const DataRegstDesc& data_regst_desc = proto.regst_desc_type().data_regst_desc();
     for (const LbiBlobDescPair& pair : data_regst_desc.lbi2blob_desc()) {
diff --git a/oneflow/core/register/runtime_register_desc.h b/oneflow/core/register/runtime_register_desc.h
index fcf7adf6e651cb366d75037e7c46b195555e52a7..6db841456425109abd2c3332acfedd613efde0ff 100644
--- a/oneflow/core/register/runtime_register_desc.h
+++ b/oneflow/core/register/runtime_register_desc.h
@@ -20,6 +20,7 @@ class RtRegstDesc {
   const std::vector& consumers_actor_id() const { return consumers_actor_id_; }
   int64_t register_num() const { return register_num_; }
   const MemoryCase& mem_case() const { return mem_case_; }
+  const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }
 
   const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;
   const BlobDesc* packed_blob_desc() const { return &packed_blob_desc_; }
@@ -32,6 +33,7 @@ class RtRegstDesc {
   int64_t producer_actor_id_;
   std::vector consumers_actor_id_;
   int64_t register_num_;
+  RegstDescTypeProto regst_desc_type_;
   MemoryCase mem_case_;
   HashMap> lbi2blob_desc_;
   BlobDesc packed_blob_desc_;